blob: bb6d47d2adb08f3e22c5fde44c792af57b5fa5e1 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil0bf97b72020-01-24 09:00:33 -080015 "google.golang.org/protobuf/internal/flags"
Damien Neilb0c26f12019-12-16 09:37:59 -080016 "google.golang.org/protobuf/internal/strs"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
18 preg "google.golang.org/protobuf/reflect/protoregistry"
19 piface "google.golang.org/protobuf/runtime/protoiface"
20)
21
// ValidationStatus reports the outcome of checking a message's wire-format encoding.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not decide whether
	// unmarshaling the message would succeed or fail.
	//
	// This happens only when an aberrant message type appears somewhere in
	// the message, or when the extension resolver fails.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized means unmarshaling the message will succeed
	// and IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized means unmarshaling the message will
	// succeed, but the output of IsInitialized on the result is unknown.
	//
	// An initialized message may still receive this status when one of its
	// message values is split across multiple fields.
	ValidationValidMaybeUninitalized
)

// validationStatusNames maps each defined status to its identifier.
// Index 0 is unused because the first status has the value 1.
var validationStatusNames = [...]string{
	ValidationUnknown:                "ValidationUnknown",
	ValidationInvalid:                "ValidationInvalid",
	ValidationValidInitialized:       "ValidationValidInitialized",
	ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
}

// String returns the identifier of a defined status, or a string of the
// form "ValidationStatus(N)" for any other value.
func (v ValidationStatus) String() string {
	if v >= ValidationUnknown && v <= ValidationValidMaybeUninitalized {
		return validationStatusNames[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
62
63// Validate determines whether the contents of the buffer are a valid wire encoding
64// of the message type.
65//
66// This function is exposed for testing.
67func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
68 mi, ok := mt.(*MessageInfo)
69 if !ok {
70 return ValidationUnknown
71 }
72 return mi.validate(b, 0, newUnmarshalOptions(opts))
73}
74
// validationInfo holds the per-field information the validator consults
// when it encounters that field on the wire.
type validationInfo struct {
	// mi is the message info used to validate message, group, and
	// message-valued map fields. It may be nil.
	mi *MessageInfo
	// typ is the validation category of the field.
	typ validationType
	// keyType and valType are the validation categories of a map field's
	// key and value.
	keyType, valType validationType

	// For non-required fields, requiredBit is 0.
	//
	// For required fields, requiredBit's nth bit is set, where n is a
	// unique index in the range [0, MessageInfo.numRequiredFields).
	//
	// If there are more than 64 required fields, requiredBit is 0.
	requiredBit uint64
}
88
// validationType classifies a field by the kind of check the validator
// applies to its encoded value.
type validationType uint8

const (
	// validationTypeOther is the zero value, used when no more specific
	// category has been assigned.
	validationTypeOther validationType = iota
	// Message, group, and map fields are validated by descending into
	// their contents.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap
	// Repeated scalar categories; a length-delimited value of one of these
	// is checked as a packed field.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64
	// Singular scalar categories, named after their wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes
	// validationTypeUTF8String marks string data that must be valid UTF-8.
	validationTypeUTF8String
)
105
106func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
107 var vi validationInfo
108 switch {
109 case fd.ContainingOneof() != nil:
110 switch fd.Kind() {
111 case pref.MessageKind:
112 vi.typ = validationTypeMessage
113 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
114 vi.mi = getMessageInfo(ot.Field(0).Type)
115 }
116 case pref.GroupKind:
117 vi.typ = validationTypeGroup
118 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
119 vi.mi = getMessageInfo(ot.Field(0).Type)
120 }
121 case pref.StringKind:
122 if strs.EnforceUTF8(fd) {
123 vi.typ = validationTypeUTF8String
124 }
125 }
126 default:
127 vi = newValidationInfo(fd, ft)
128 }
129 if fd.Cardinality() == pref.Required {
130 // Avoid overflow. The required field check is done with a 64-bit mask, with
131 // any message containing more than 64 required fields always reported as
132 // potentially uninitialized, so it is not important to get a precise count
133 // of the required fields past 64.
134 if mi.numRequiredFields < math.MaxUint8 {
135 mi.numRequiredFields++
Damien Neil170b2bf2020-01-24 16:42:42 -0800136 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
Damien Neilb0c26f12019-12-16 09:37:59 -0800137 }
138 }
139 return vi
140}
141
142func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
143 var vi validationInfo
144 switch {
145 case fd.IsList():
146 switch fd.Kind() {
147 case pref.MessageKind:
148 vi.typ = validationTypeMessage
149 if ft.Kind() == reflect.Slice {
150 vi.mi = getMessageInfo(ft.Elem())
151 }
152 case pref.GroupKind:
153 vi.typ = validationTypeGroup
154 if ft.Kind() == reflect.Slice {
155 vi.mi = getMessageInfo(ft.Elem())
156 }
157 case pref.StringKind:
158 vi.typ = validationTypeBytes
159 if strs.EnforceUTF8(fd) {
160 vi.typ = validationTypeUTF8String
161 }
162 default:
163 switch wireTypes[fd.Kind()] {
164 case wire.VarintType:
165 vi.typ = validationTypeRepeatedVarint
166 case wire.Fixed32Type:
167 vi.typ = validationTypeRepeatedFixed32
168 case wire.Fixed64Type:
169 vi.typ = validationTypeRepeatedFixed64
170 }
171 }
172 case fd.IsMap():
173 vi.typ = validationTypeMap
174 switch fd.MapKey().Kind() {
175 case pref.StringKind:
176 if strs.EnforceUTF8(fd) {
177 vi.keyType = validationTypeUTF8String
178 }
179 }
180 switch fd.MapValue().Kind() {
181 case pref.MessageKind:
182 vi.valType = validationTypeMessage
183 if ft.Kind() == reflect.Map {
184 vi.mi = getMessageInfo(ft.Elem())
185 }
186 case pref.StringKind:
187 if strs.EnforceUTF8(fd) {
188 vi.valType = validationTypeUTF8String
189 }
190 }
191 default:
192 switch fd.Kind() {
193 case pref.MessageKind:
194 vi.typ = validationTypeMessage
195 if !fd.IsWeak() {
196 vi.mi = getMessageInfo(ft)
197 }
198 case pref.GroupKind:
199 vi.typ = validationTypeGroup
200 vi.mi = getMessageInfo(ft)
201 case pref.StringKind:
202 vi.typ = validationTypeBytes
203 if strs.EnforceUTF8(fd) {
204 vi.typ = validationTypeUTF8String
205 }
206 default:
207 switch wireTypes[fd.Kind()] {
208 case wire.VarintType:
209 vi.typ = validationTypeVarint
210 case wire.Fixed32Type:
211 vi.typ = validationTypeFixed32
212 case wire.Fixed64Type:
213 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800214 case wire.BytesType:
215 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800216 }
217 }
218 }
219 return vi
220}
221
222func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
223 type validationState struct {
224 typ validationType
225 keyType, valType validationType
226 endGroup wire.Number
227 mi *MessageInfo
228 tail []byte
229 requiredMask uint64
230 }
231
232 // Pre-allocate some slots to avoid repeated slice reallocation.
233 states := make([]validationState, 0, 16)
234 states = append(states, validationState{
235 typ: validationTypeMessage,
236 mi: mi,
237 })
238 if groupTag > 0 {
239 states[0].typ = validationTypeGroup
240 states[0].endGroup = groupTag
241 }
242 initialized := true
243State:
244 for len(states) > 0 {
245 st := &states[len(states)-1]
246 if st.mi != nil {
247 st.mi.init()
Damien Neil0bf97b72020-01-24 09:00:33 -0800248 if flags.ProtoLegacy && st.mi.isMessageSet {
249 return ValidationUnknown
250 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800251 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800252 for len(b) > 0 {
Damien Neil5d828832020-01-28 08:06:12 -0800253 // Parse the tag (field number and wire type).
254 var tag uint64
255 if b[0] < 0x80 {
256 tag = uint64(b[0])
257 b = b[1:]
258 } else if len(b) >= 2 && b[1] < 128 {
259 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
260 b = b[2:]
261 } else {
262 var n int
263 tag, n = wire.ConsumeVarint(b)
264 if n < 0 {
265 return ValidationInvalid
266 }
267 b = b[n:]
Damien Neilb0c26f12019-12-16 09:37:59 -0800268 }
Damien Neil5d828832020-01-28 08:06:12 -0800269 var num wire.Number
270 if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
Damien Neilb0c26f12019-12-16 09:37:59 -0800271 return ValidationInvalid
Damien Neil5d828832020-01-28 08:06:12 -0800272 } else {
273 num = wire.Number(n)
Damien Neilb0c26f12019-12-16 09:37:59 -0800274 }
Damien Neil5d828832020-01-28 08:06:12 -0800275 wtyp := wire.Type(tag & 7)
276
Damien Neilb0c26f12019-12-16 09:37:59 -0800277 if wtyp == wire.EndGroupType {
278 if st.endGroup == num {
279 goto PopState
280 }
281 return ValidationInvalid
282 }
283 var vi validationInfo
284 switch st.typ {
285 case validationTypeMap:
286 switch num {
287 case 1:
288 vi.typ = st.keyType
289 case 2:
290 vi.typ = st.valType
291 vi.mi = st.mi
Damien Neil170b2bf2020-01-24 16:42:42 -0800292 vi.requiredBit = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800293 }
294 default:
295 var f *coderFieldInfo
296 if int(num) < len(st.mi.denseCoderFields) {
297 f = st.mi.denseCoderFields[num]
298 } else {
299 f = st.mi.coderFields[num]
300 }
301 if f != nil {
302 vi = f.validation
303 if vi.typ == validationTypeMessage && vi.mi == nil {
304 // Probable weak field.
305 //
306 // TODO: Consider storing the results of this lookup somewhere
307 // rather than recomputing it on every validation.
308 fd := st.mi.Desc.Fields().ByNumber(num)
309 if fd == nil || !fd.IsWeak() {
310 break
311 }
312 messageName := fd.Message().FullName()
313 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
314 switch err {
315 case nil:
316 vi.mi, _ = messageType.(*MessageInfo)
317 case preg.NotFound:
318 vi.typ = validationTypeBytes
319 default:
320 return ValidationUnknown
321 }
322 }
323 break
324 }
325 // Possible extension field.
326 //
327 // TODO: We should return ValidationUnknown when:
328 // 1. The resolver is not frozen. (More extensions may be added to it.)
329 // 2. The resolver returns preg.NotFound.
330 // In this case, a type added to the resolver in the future could cause
331 // unmarshaling to begin failing. Supporting this requires some way to
332 // determine if the resolver is frozen.
333 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
334 if err != nil && err != preg.NotFound {
335 return ValidationUnknown
336 }
337 if err == nil {
338 vi = getExtensionFieldInfo(xt).validation
339 }
340 }
Damien Neil170b2bf2020-01-24 16:42:42 -0800341 if vi.requiredBit != 0 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800342 // Check that the field has a compatible wire type.
343 // We only need to consider non-repeated field types,
344 // since repeated fields (and maps) can never be required.
345 ok := false
346 switch vi.typ {
347 case validationTypeVarint:
348 ok = wtyp == wire.VarintType
349 case validationTypeFixed32:
350 ok = wtyp == wire.Fixed32Type
351 case validationTypeFixed64:
352 ok = wtyp == wire.Fixed64Type
353 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
354 ok = wtyp == wire.BytesType
355 }
356 if ok {
Damien Neil170b2bf2020-01-24 16:42:42 -0800357 st.requiredMask |= vi.requiredBit
Damien Neilb0c26f12019-12-16 09:37:59 -0800358 }
359 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800360
361 switch wtyp {
362 case wire.VarintType:
363 if len(b) >= 10 {
364 switch {
365 case b[0] < 0x80:
366 b = b[1:]
367 case b[1] < 0x80:
368 b = b[2:]
369 case b[2] < 0x80:
370 b = b[3:]
371 case b[3] < 0x80:
372 b = b[4:]
373 case b[4] < 0x80:
374 b = b[5:]
375 case b[5] < 0x80:
376 b = b[6:]
377 case b[6] < 0x80:
378 b = b[7:]
379 case b[7] < 0x80:
380 b = b[8:]
381 case b[8] < 0x80:
382 b = b[9:]
383 case b[9] < 0x80:
384 b = b[10:]
385 default:
386 return ValidationInvalid
387 }
388 } else {
389 switch {
390 case len(b) > 0 && b[0] < 0x80:
391 b = b[1:]
392 case len(b) > 1 && b[1] < 0x80:
393 b = b[2:]
394 case len(b) > 2 && b[2] < 0x80:
395 b = b[3:]
396 case len(b) > 3 && b[3] < 0x80:
397 b = b[4:]
398 case len(b) > 4 && b[4] < 0x80:
399 b = b[5:]
400 case len(b) > 5 && b[5] < 0x80:
401 b = b[6:]
402 case len(b) > 6 && b[6] < 0x80:
403 b = b[7:]
404 case len(b) > 7 && b[7] < 0x80:
405 b = b[8:]
406 case len(b) > 8 && b[8] < 0x80:
407 b = b[9:]
408 case len(b) > 9 && b[9] < 0x80:
409 b = b[10:]
410 default:
411 return ValidationInvalid
412 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800413 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800414 continue State
Damien Neil8fa11b12020-01-28 08:31:04 -0800415 case wire.BytesType:
416 var size uint64
417 if b[0] < 0x80 {
418 size = uint64(b[0])
419 b = b[1:]
420 } else if len(b) >= 2 && b[1] < 128 {
421 size = uint64(b[0]&0x7f) + uint64(b[1])<<7
422 b = b[2:]
423 } else {
424 var n int
425 size, n = wire.ConsumeVarint(b)
Damien Neilb0c26f12019-12-16 09:37:59 -0800426 if n < 0 {
427 return ValidationInvalid
428 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800429 b = b[n:]
Damien Neilb0c26f12019-12-16 09:37:59 -0800430 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800431 if size > uint64(len(b)) {
Damien Neilb0c26f12019-12-16 09:37:59 -0800432 return ValidationInvalid
433 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800434 v := b[:size]
435 b = b[size:]
436 switch vi.typ {
437 case validationTypeMessage, validationTypeMap:
438 if vi.mi == nil && vi.typ == validationTypeMessage {
439 return ValidationUnknown
440 }
441 states = append(states, validationState{
442 typ: vi.typ,
443 keyType: vi.keyType,
444 valType: vi.valType,
445 mi: vi.mi,
446 tail: b,
447 })
448 b = v
449 continue State
450 case validationTypeRepeatedVarint:
451 // Packed field.
452 for len(v) > 0 {
453 _, n := wire.ConsumeVarint(v)
454 if n < 0 {
455 return ValidationInvalid
456 }
457 v = v[n:]
458 }
459 case validationTypeRepeatedFixed32:
460 // Packed field.
461 if len(v)%4 != 0 {
462 return ValidationInvalid
463 }
464 case validationTypeRepeatedFixed64:
465 // Packed field.
466 if len(v)%8 != 0 {
467 return ValidationInvalid
468 }
469 case validationTypeUTF8String:
470 if !utf8.Valid(v) {
471 return ValidationInvalid
472 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800473 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800474 case wire.Fixed32Type:
475 if len(b) < 4 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800476 return ValidationInvalid
477 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800478 b = b[4:]
479 case wire.Fixed64Type:
480 if len(b) < 8 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800481 return ValidationInvalid
482 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800483 b = b[8:]
484 case wire.StartGroupType:
485 switch vi.typ {
486 case validationTypeGroup:
487 if vi.mi == nil {
488 return ValidationUnknown
489 }
490 states = append(states, validationState{
491 typ: validationTypeGroup,
492 mi: vi.mi,
493 endGroup: num,
494 })
495 continue State
496 default:
497 n := wire.ConsumeFieldValue(num, wtyp, b)
498 if n < 0 {
499 return ValidationInvalid
500 }
501 b = b[n:]
502 }
503 default:
Damien Neilb0c26f12019-12-16 09:37:59 -0800504 return ValidationInvalid
505 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800506 }
507 if st.endGroup != 0 {
508 return ValidationInvalid
509 }
510 if len(b) != 0 {
511 return ValidationInvalid
512 }
513 b = st.tail
514 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800515 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800516 switch st.typ {
517 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800518 numRequiredFields = int(st.mi.numRequiredFields)
519 case validationTypeMap:
520 // If this is a map field with a message value that contains
521 // required fields, require that the value be present.
522 if st.mi != nil && st.mi.numRequiredFields > 0 {
523 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800524 }
525 }
Damien Neil54a0a042020-01-08 17:53:16 -0800526 // If there are more than 64 required fields, this check will
527 // always fail and we will report that the message is potentially
528 // uninitialized.
529 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
530 initialized = false
531 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800532 states = states[:len(states)-1]
533 }
534 if !initialized {
535 return ValidationValidMaybeUninitalized
536 }
537 return ValidationValidInitialized
538}