blob: 1d3406364ad95eb3736394ee6692d08894b7a002 [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
15 "google.golang.org/protobuf/internal/strs"
16 pref "google.golang.org/protobuf/reflect/protoreflect"
17 preg "google.golang.org/protobuf/reflect/protoregistry"
18 piface "google.golang.org/protobuf/runtime/protoiface"
19)
20
// ValidationStatus reports the outcome of checking whether a byte slice is a
// well-formed wire encoding of a particular message type.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not decide whether
	// unmarshaling would succeed or fail.
	//
	// It is produced only when an aberrant message type is encountered
	// somewhere in the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means unmarshaling the message is certain to fail.
	ValidationInvalid

	// ValidationValidInitialized means unmarshaling will succeed and
	// IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized means unmarshaling will succeed, but
	// whether IsInitialized on the result succeeds is not known.
	//
	// An initialized message may still get this status, for instance when a
	// message value is split across several occurrences of the same field.
	ValidationValidMaybeUninitalized
)

// String returns the identifier of a known status, or "ValidationStatus(N)"
// for any other value.
func (v ValidationStatus) String() string {
	names := [...]string{
		ValidationUnknown:                "ValidationUnknown",
		ValidationInvalid:                "ValidationInvalid",
		ValidationValidInitialized:       "ValidationValidInitialized",
		ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
	}
	// Slot 0 is unused (the constants start at 1), so require v > 0.
	if v > 0 && int(v) < len(names) {
		return names[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
61
62// Validate determines whether the contents of the buffer are a valid wire encoding
63// of the message type.
64//
65// This function is exposed for testing.
66func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
67 mi, ok := mt.(*MessageInfo)
68 if !ok {
69 return ValidationUnknown
70 }
71 return mi.validate(b, 0, newUnmarshalOptions(opts))
72}
73
74type validationInfo struct {
75 mi *MessageInfo
76 typ validationType
77 keyType, valType validationType
78
79 // For non-required fields, requiredIndex is 0.
80 //
81 // For required fields, requiredIndex is unique index in the range
82 // (0, MessageInfo.numRequiredFields].
83 requiredIndex uint8
84}
85
// validationType tells the validator which check to apply to a field's
// encoded bytes.
type validationType uint8

const (
	// validationTypeOther, the zero value, applies no type-specific check;
	// the field value is only required to be well-formed wire data.
	validationTypeOther validationType = iota

	// Message-like fields, validated by descending into their contents.
	validationTypeMessage // length-delimited nested message
	validationTypeGroup   // nested message delimited by start/end-group tags
	validationTypeMap     // map field; each entry has a key (1) and value (2)

	// Repeated scalars; a packed (length-delimited) encoding must split
	// cleanly into elements of the given kind.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalar fields, distinguished by wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64

	// Length-delimited string/bytes data.
	validationTypeBytes
	validationTypeUTF8String // contents must also be valid UTF-8
)
102
103func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
104 var vi validationInfo
105 switch {
106 case fd.ContainingOneof() != nil:
107 switch fd.Kind() {
108 case pref.MessageKind:
109 vi.typ = validationTypeMessage
110 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
111 vi.mi = getMessageInfo(ot.Field(0).Type)
112 }
113 case pref.GroupKind:
114 vi.typ = validationTypeGroup
115 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
116 vi.mi = getMessageInfo(ot.Field(0).Type)
117 }
118 case pref.StringKind:
119 if strs.EnforceUTF8(fd) {
120 vi.typ = validationTypeUTF8String
121 }
122 }
123 default:
124 vi = newValidationInfo(fd, ft)
125 }
126 if fd.Cardinality() == pref.Required {
127 // Avoid overflow. The required field check is done with a 64-bit mask, with
128 // any message containing more than 64 required fields always reported as
129 // potentially uninitialized, so it is not important to get a precise count
130 // of the required fields past 64.
131 if mi.numRequiredFields < math.MaxUint8 {
132 mi.numRequiredFields++
133 vi.requiredIndex = mi.numRequiredFields
134 }
135 }
136 return vi
137}
138
139func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
140 var vi validationInfo
141 switch {
142 case fd.IsList():
143 switch fd.Kind() {
144 case pref.MessageKind:
145 vi.typ = validationTypeMessage
146 if ft.Kind() == reflect.Slice {
147 vi.mi = getMessageInfo(ft.Elem())
148 }
149 case pref.GroupKind:
150 vi.typ = validationTypeGroup
151 if ft.Kind() == reflect.Slice {
152 vi.mi = getMessageInfo(ft.Elem())
153 }
154 case pref.StringKind:
155 vi.typ = validationTypeBytes
156 if strs.EnforceUTF8(fd) {
157 vi.typ = validationTypeUTF8String
158 }
159 default:
160 switch wireTypes[fd.Kind()] {
161 case wire.VarintType:
162 vi.typ = validationTypeRepeatedVarint
163 case wire.Fixed32Type:
164 vi.typ = validationTypeRepeatedFixed32
165 case wire.Fixed64Type:
166 vi.typ = validationTypeRepeatedFixed64
167 }
168 }
169 case fd.IsMap():
170 vi.typ = validationTypeMap
171 switch fd.MapKey().Kind() {
172 case pref.StringKind:
173 if strs.EnforceUTF8(fd) {
174 vi.keyType = validationTypeUTF8String
175 }
176 }
177 switch fd.MapValue().Kind() {
178 case pref.MessageKind:
179 vi.valType = validationTypeMessage
180 if ft.Kind() == reflect.Map {
181 vi.mi = getMessageInfo(ft.Elem())
182 }
183 case pref.StringKind:
184 if strs.EnforceUTF8(fd) {
185 vi.valType = validationTypeUTF8String
186 }
187 }
188 default:
189 switch fd.Kind() {
190 case pref.MessageKind:
191 vi.typ = validationTypeMessage
192 if !fd.IsWeak() {
193 vi.mi = getMessageInfo(ft)
194 }
195 case pref.GroupKind:
196 vi.typ = validationTypeGroup
197 vi.mi = getMessageInfo(ft)
198 case pref.StringKind:
199 vi.typ = validationTypeBytes
200 if strs.EnforceUTF8(fd) {
201 vi.typ = validationTypeUTF8String
202 }
203 default:
204 switch wireTypes[fd.Kind()] {
205 case wire.VarintType:
206 vi.typ = validationTypeVarint
207 case wire.Fixed32Type:
208 vi.typ = validationTypeFixed32
209 case wire.Fixed64Type:
210 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800211 case wire.BytesType:
212 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800213 }
214 }
215 }
216 return vi
217}
218
219func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
220 type validationState struct {
221 typ validationType
222 keyType, valType validationType
223 endGroup wire.Number
224 mi *MessageInfo
225 tail []byte
226 requiredMask uint64
227 }
228
229 // Pre-allocate some slots to avoid repeated slice reallocation.
230 states := make([]validationState, 0, 16)
231 states = append(states, validationState{
232 typ: validationTypeMessage,
233 mi: mi,
234 })
235 if groupTag > 0 {
236 states[0].typ = validationTypeGroup
237 states[0].endGroup = groupTag
238 }
239 initialized := true
240State:
241 for len(states) > 0 {
242 st := &states[len(states)-1]
243 if st.mi != nil {
244 st.mi.init()
245 }
246 Field:
247 for len(b) > 0 {
248 num, wtyp, n := wire.ConsumeTag(b)
249 if n < 0 {
250 return ValidationInvalid
251 }
252 b = b[n:]
253 if num > wire.MaxValidNumber {
254 return ValidationInvalid
255 }
256 if wtyp == wire.EndGroupType {
257 if st.endGroup == num {
258 goto PopState
259 }
260 return ValidationInvalid
261 }
262 var vi validationInfo
263 switch st.typ {
264 case validationTypeMap:
265 switch num {
266 case 1:
267 vi.typ = st.keyType
268 case 2:
269 vi.typ = st.valType
270 vi.mi = st.mi
Damien Neil54a0a042020-01-08 17:53:16 -0800271 vi.requiredIndex = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800272 }
273 default:
274 var f *coderFieldInfo
275 if int(num) < len(st.mi.denseCoderFields) {
276 f = st.mi.denseCoderFields[num]
277 } else {
278 f = st.mi.coderFields[num]
279 }
280 if f != nil {
281 vi = f.validation
282 if vi.typ == validationTypeMessage && vi.mi == nil {
283 // Probable weak field.
284 //
285 // TODO: Consider storing the results of this lookup somewhere
286 // rather than recomputing it on every validation.
287 fd := st.mi.Desc.Fields().ByNumber(num)
288 if fd == nil || !fd.IsWeak() {
289 break
290 }
291 messageName := fd.Message().FullName()
292 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
293 switch err {
294 case nil:
295 vi.mi, _ = messageType.(*MessageInfo)
296 case preg.NotFound:
297 vi.typ = validationTypeBytes
298 default:
299 return ValidationUnknown
300 }
301 }
302 break
303 }
304 // Possible extension field.
305 //
306 // TODO: We should return ValidationUnknown when:
307 // 1. The resolver is not frozen. (More extensions may be added to it.)
308 // 2. The resolver returns preg.NotFound.
309 // In this case, a type added to the resolver in the future could cause
310 // unmarshaling to begin failing. Supporting this requires some way to
311 // determine if the resolver is frozen.
312 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
313 if err != nil && err != preg.NotFound {
314 return ValidationUnknown
315 }
316 if err == nil {
317 vi = getExtensionFieldInfo(xt).validation
318 }
319 }
320 if vi.requiredIndex > 0 {
321 // Check that the field has a compatible wire type.
322 // We only need to consider non-repeated field types,
323 // since repeated fields (and maps) can never be required.
324 ok := false
325 switch vi.typ {
326 case validationTypeVarint:
327 ok = wtyp == wire.VarintType
328 case validationTypeFixed32:
329 ok = wtyp == wire.Fixed32Type
330 case validationTypeFixed64:
331 ok = wtyp == wire.Fixed64Type
332 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
333 ok = wtyp == wire.BytesType
334 }
335 if ok {
336 st.requiredMask |= 1 << (vi.requiredIndex - 1)
337 }
338 }
339 switch vi.typ {
340 case validationTypeMessage, validationTypeMap:
341 if wtyp != wire.BytesType {
342 break
343 }
344 if vi.mi == nil && vi.typ == validationTypeMessage {
345 return ValidationUnknown
346 }
347 size, n := wire.ConsumeVarint(b)
348 if n < 0 {
349 return ValidationInvalid
350 }
351 b = b[n:]
352 if uint64(len(b)) < size {
353 return ValidationInvalid
354 }
355 states = append(states, validationState{
356 typ: vi.typ,
357 keyType: vi.keyType,
358 valType: vi.valType,
359 mi: vi.mi,
360 tail: b[size:],
361 })
362 b = b[:size]
363 continue State
364 case validationTypeGroup:
365 if wtyp != wire.StartGroupType {
366 break
367 }
368 if vi.mi == nil {
369 return ValidationUnknown
370 }
371 states = append(states, validationState{
372 typ: validationTypeGroup,
373 mi: vi.mi,
374 endGroup: num,
375 })
376 continue State
377 case validationTypeRepeatedVarint:
378 if wtyp != wire.BytesType {
379 break
380 }
381 // Packed field.
382 v, n := wire.ConsumeBytes(b)
383 if n < 0 {
384 return ValidationInvalid
385 }
386 b = b[n:]
387 for len(v) > 0 {
388 _, n := wire.ConsumeVarint(v)
389 if n < 0 {
390 return ValidationInvalid
391 }
392 v = v[n:]
393 }
394 continue Field
395 case validationTypeRepeatedFixed32:
396 if wtyp != wire.BytesType {
397 break
398 }
399 // Packed field.
400 v, n := wire.ConsumeBytes(b)
401 if n < 0 || len(v)%4 != 0 {
402 return ValidationInvalid
403 }
404 b = b[n:]
405 continue Field
406 case validationTypeRepeatedFixed64:
407 if wtyp != wire.BytesType {
408 break
409 }
410 // Packed field.
411 v, n := wire.ConsumeBytes(b)
412 if n < 0 || len(v)%8 != 0 {
413 return ValidationInvalid
414 }
415 b = b[n:]
416 continue Field
417 case validationTypeUTF8String:
418 if wtyp != wire.BytesType {
419 break
420 }
421 v, n := wire.ConsumeBytes(b)
422 if n < 0 || !utf8.Valid(v) {
423 return ValidationInvalid
424 }
425 b = b[n:]
426 continue Field
427 }
428 n = wire.ConsumeFieldValue(num, wtyp, b)
429 if n < 0 {
430 return ValidationInvalid
431 }
432 b = b[n:]
433 }
434 if st.endGroup != 0 {
435 return ValidationInvalid
436 }
437 if len(b) != 0 {
438 return ValidationInvalid
439 }
440 b = st.tail
441 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800442 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800443 switch st.typ {
444 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800445 numRequiredFields = int(st.mi.numRequiredFields)
446 case validationTypeMap:
447 // If this is a map field with a message value that contains
448 // required fields, require that the value be present.
449 if st.mi != nil && st.mi.numRequiredFields > 0 {
450 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800451 }
452 }
Damien Neil54a0a042020-01-08 17:53:16 -0800453 // If there are more than 64 required fields, this check will
454 // always fail and we will report that the message is potentially
455 // uninitialized.
456 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
457 initialized = false
458 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800459 states = states[:len(states)-1]
460 }
461 if !initialized {
462 return ValidationValidMaybeUninitalized
463 }
464 return ValidationValidInitialized
465}