blob: eab8ec0fc5b52b685779f2f18d40083a7899d1a8 [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil0bf97b72020-01-24 09:00:33 -080015 "google.golang.org/protobuf/internal/flags"
Damien Neilb0c26f12019-12-16 09:37:59 -080016 "google.golang.org/protobuf/internal/strs"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
18 preg "google.golang.org/protobuf/reflect/protoregistry"
19 piface "google.golang.org/protobuf/runtime/protoiface"
20)
21
// ValidationStatus describes the outcome of checking the wire-format
// encoding of a message.
type ValidationStatus int

const (
	// Skip zero so that the first declared status has the value 1.
	_ ValidationStatus = iota

	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message may either succeed or fail.
	//
	// This status arises only when an aberrant message type appears somewhere
	// in the message, or when the extension resolver fails.
	ValidationUnknown

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized means that unmarshaling the message will
	// succeed and that IsInitialized will report success on the result.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized means that unmarshaling the message
	// will succeed, but what IsInitialized reports on the result is unknown.
	//
	// An initialized message can still yield this status when one of its
	// message values is split across multiple fields.
	ValidationValidMaybeUninitalized
)
47
48func (v ValidationStatus) String() string {
49 switch v {
50 case ValidationUnknown:
51 return "ValidationUnknown"
52 case ValidationInvalid:
53 return "ValidationInvalid"
54 case ValidationValidInitialized:
55 return "ValidationValidInitialized"
56 case ValidationValidMaybeUninitalized:
57 return "ValidationValidMaybeUninitalized"
58 default:
59 return fmt.Sprintf("ValidationStatus(%d)", int(v))
60 }
61}
62
63// Validate determines whether the contents of the buffer are a valid wire encoding
64// of the message type.
65//
66// This function is exposed for testing.
67func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
68 mi, ok := mt.(*MessageInfo)
69 if !ok {
70 return ValidationUnknown
71 }
Damien Neil524c6062020-01-28 13:32:01 -080072 return mi.validate(b, 0, unmarshalOptions(opts))
Damien Neilb0c26f12019-12-16 09:37:59 -080073}
74
75type validationInfo struct {
76 mi *MessageInfo
77 typ validationType
78 keyType, valType validationType
79
Damien Neil170b2bf2020-01-24 16:42:42 -080080 // For non-required fields, requiredBit is 0.
Damien Neilb0c26f12019-12-16 09:37:59 -080081 //
Damien Neil170b2bf2020-01-24 16:42:42 -080082 // For required fields, requiredBit's nth bit is set, where n is a
83 // unique index in the range [0, MessageInfo.numRequiredFields).
84 //
85 // If there are more than 64 required fields, requiredBit is 0.
86 requiredBit uint64
Damien Neilb0c26f12019-12-16 09:37:59 -080087}
88
// validationType classifies a field for the purposes of wire-format validation.
type validationType uint8

const (
	// validationTypeOther is the zero value, used when no more specific
	// classification applies.
	validationTypeOther = validationType(iota)

	// Nested values whose contents are validated in turn.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// Repeated scalar fields, which may appear in packed form.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalar fields.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64

	// Length-delimited payloads; the UTF8String variant must be valid UTF-8.
	validationTypeBytes
	validationTypeUTF8String
)
105
106func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
107 var vi validationInfo
108 switch {
109 case fd.ContainingOneof() != nil:
110 switch fd.Kind() {
111 case pref.MessageKind:
112 vi.typ = validationTypeMessage
113 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
114 vi.mi = getMessageInfo(ot.Field(0).Type)
115 }
116 case pref.GroupKind:
117 vi.typ = validationTypeGroup
118 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
119 vi.mi = getMessageInfo(ot.Field(0).Type)
120 }
121 case pref.StringKind:
122 if strs.EnforceUTF8(fd) {
123 vi.typ = validationTypeUTF8String
124 }
125 }
126 default:
127 vi = newValidationInfo(fd, ft)
128 }
129 if fd.Cardinality() == pref.Required {
130 // Avoid overflow. The required field check is done with a 64-bit mask, with
131 // any message containing more than 64 required fields always reported as
132 // potentially uninitialized, so it is not important to get a precise count
133 // of the required fields past 64.
134 if mi.numRequiredFields < math.MaxUint8 {
135 mi.numRequiredFields++
Damien Neil170b2bf2020-01-24 16:42:42 -0800136 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
Damien Neilb0c26f12019-12-16 09:37:59 -0800137 }
138 }
139 return vi
140}
141
142func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
143 var vi validationInfo
144 switch {
145 case fd.IsList():
146 switch fd.Kind() {
147 case pref.MessageKind:
148 vi.typ = validationTypeMessage
149 if ft.Kind() == reflect.Slice {
150 vi.mi = getMessageInfo(ft.Elem())
151 }
152 case pref.GroupKind:
153 vi.typ = validationTypeGroup
154 if ft.Kind() == reflect.Slice {
155 vi.mi = getMessageInfo(ft.Elem())
156 }
157 case pref.StringKind:
158 vi.typ = validationTypeBytes
159 if strs.EnforceUTF8(fd) {
160 vi.typ = validationTypeUTF8String
161 }
162 default:
163 switch wireTypes[fd.Kind()] {
164 case wire.VarintType:
165 vi.typ = validationTypeRepeatedVarint
166 case wire.Fixed32Type:
167 vi.typ = validationTypeRepeatedFixed32
168 case wire.Fixed64Type:
169 vi.typ = validationTypeRepeatedFixed64
170 }
171 }
172 case fd.IsMap():
173 vi.typ = validationTypeMap
174 switch fd.MapKey().Kind() {
175 case pref.StringKind:
176 if strs.EnforceUTF8(fd) {
177 vi.keyType = validationTypeUTF8String
178 }
179 }
180 switch fd.MapValue().Kind() {
181 case pref.MessageKind:
182 vi.valType = validationTypeMessage
183 if ft.Kind() == reflect.Map {
184 vi.mi = getMessageInfo(ft.Elem())
185 }
186 case pref.StringKind:
187 if strs.EnforceUTF8(fd) {
188 vi.valType = validationTypeUTF8String
189 }
190 }
191 default:
192 switch fd.Kind() {
193 case pref.MessageKind:
194 vi.typ = validationTypeMessage
195 if !fd.IsWeak() {
196 vi.mi = getMessageInfo(ft)
197 }
198 case pref.GroupKind:
199 vi.typ = validationTypeGroup
200 vi.mi = getMessageInfo(ft)
201 case pref.StringKind:
202 vi.typ = validationTypeBytes
203 if strs.EnforceUTF8(fd) {
204 vi.typ = validationTypeUTF8String
205 }
206 default:
207 switch wireTypes[fd.Kind()] {
208 case wire.VarintType:
209 vi.typ = validationTypeVarint
210 case wire.Fixed32Type:
211 vi.typ = validationTypeFixed32
212 case wire.Fixed64Type:
213 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800214 case wire.BytesType:
215 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800216 }
217 }
218 }
219 return vi
220}
221
222func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
Damien Neilcb0bfd02020-01-28 09:11:12 -0800223 mi.init()
Damien Neilb0c26f12019-12-16 09:37:59 -0800224 type validationState struct {
225 typ validationType
226 keyType, valType validationType
227 endGroup wire.Number
228 mi *MessageInfo
229 tail []byte
230 requiredMask uint64
231 }
232
233 // Pre-allocate some slots to avoid repeated slice reallocation.
234 states := make([]validationState, 0, 16)
235 states = append(states, validationState{
236 typ: validationTypeMessage,
237 mi: mi,
238 })
239 if groupTag > 0 {
240 states[0].typ = validationTypeGroup
241 states[0].endGroup = groupTag
242 }
243 initialized := true
244State:
245 for len(states) > 0 {
246 st := &states[len(states)-1]
247 if st.mi != nil {
Damien Neil0bf97b72020-01-24 09:00:33 -0800248 if flags.ProtoLegacy && st.mi.isMessageSet {
249 return ValidationUnknown
250 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800251 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800252 for len(b) > 0 {
Damien Neil5d828832020-01-28 08:06:12 -0800253 // Parse the tag (field number and wire type).
254 var tag uint64
255 if b[0] < 0x80 {
256 tag = uint64(b[0])
257 b = b[1:]
258 } else if len(b) >= 2 && b[1] < 128 {
259 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
260 b = b[2:]
261 } else {
262 var n int
263 tag, n = wire.ConsumeVarint(b)
264 if n < 0 {
265 return ValidationInvalid
266 }
267 b = b[n:]
Damien Neilb0c26f12019-12-16 09:37:59 -0800268 }
Damien Neil5d828832020-01-28 08:06:12 -0800269 var num wire.Number
270 if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
Damien Neilb0c26f12019-12-16 09:37:59 -0800271 return ValidationInvalid
Damien Neil5d828832020-01-28 08:06:12 -0800272 } else {
273 num = wire.Number(n)
Damien Neilb0c26f12019-12-16 09:37:59 -0800274 }
Damien Neil5d828832020-01-28 08:06:12 -0800275 wtyp := wire.Type(tag & 7)
276
Damien Neilb0c26f12019-12-16 09:37:59 -0800277 if wtyp == wire.EndGroupType {
278 if st.endGroup == num {
279 goto PopState
280 }
281 return ValidationInvalid
282 }
283 var vi validationInfo
284 switch st.typ {
285 case validationTypeMap:
286 switch num {
287 case 1:
288 vi.typ = st.keyType
289 case 2:
290 vi.typ = st.valType
291 vi.mi = st.mi
Damien Neil170b2bf2020-01-24 16:42:42 -0800292 vi.requiredBit = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800293 }
294 default:
295 var f *coderFieldInfo
296 if int(num) < len(st.mi.denseCoderFields) {
297 f = st.mi.denseCoderFields[num]
298 } else {
299 f = st.mi.coderFields[num]
300 }
301 if f != nil {
302 vi = f.validation
303 if vi.typ == validationTypeMessage && vi.mi == nil {
304 // Probable weak field.
305 //
306 // TODO: Consider storing the results of this lookup somewhere
307 // rather than recomputing it on every validation.
308 fd := st.mi.Desc.Fields().ByNumber(num)
309 if fd == nil || !fd.IsWeak() {
310 break
311 }
312 messageName := fd.Message().FullName()
313 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
314 switch err {
315 case nil:
316 vi.mi, _ = messageType.(*MessageInfo)
317 case preg.NotFound:
318 vi.typ = validationTypeBytes
319 default:
320 return ValidationUnknown
321 }
322 }
323 break
324 }
325 // Possible extension field.
326 //
327 // TODO: We should return ValidationUnknown when:
328 // 1. The resolver is not frozen. (More extensions may be added to it.)
329 // 2. The resolver returns preg.NotFound.
330 // In this case, a type added to the resolver in the future could cause
331 // unmarshaling to begin failing. Supporting this requires some way to
332 // determine if the resolver is frozen.
Damien Neil524c6062020-01-28 13:32:01 -0800333 xt, err := opts.Resolver.FindExtensionByNumber(st.mi.Desc.FullName(), num)
Damien Neilb0c26f12019-12-16 09:37:59 -0800334 if err != nil && err != preg.NotFound {
335 return ValidationUnknown
336 }
337 if err == nil {
338 vi = getExtensionFieldInfo(xt).validation
339 }
340 }
Damien Neil170b2bf2020-01-24 16:42:42 -0800341 if vi.requiredBit != 0 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800342 // Check that the field has a compatible wire type.
343 // We only need to consider non-repeated field types,
344 // since repeated fields (and maps) can never be required.
345 ok := false
346 switch vi.typ {
347 case validationTypeVarint:
348 ok = wtyp == wire.VarintType
349 case validationTypeFixed32:
350 ok = wtyp == wire.Fixed32Type
351 case validationTypeFixed64:
352 ok = wtyp == wire.Fixed64Type
353 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
354 ok = wtyp == wire.BytesType
355 }
356 if ok {
Damien Neil170b2bf2020-01-24 16:42:42 -0800357 st.requiredMask |= vi.requiredBit
Damien Neilb0c26f12019-12-16 09:37:59 -0800358 }
359 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800360
361 switch wtyp {
362 case wire.VarintType:
Damien Neil4d918162020-02-01 10:39:11 -0800363 if len(b) >= 9 {
Damien Neil8fa11b12020-01-28 08:31:04 -0800364 switch {
365 case b[0] < 0x80:
366 b = b[1:]
367 case b[1] < 0x80:
368 b = b[2:]
369 case b[2] < 0x80:
370 b = b[3:]
371 case b[3] < 0x80:
372 b = b[4:]
373 case b[4] < 0x80:
374 b = b[5:]
375 case b[5] < 0x80:
376 b = b[6:]
377 case b[6] < 0x80:
378 b = b[7:]
379 case b[7] < 0x80:
380 b = b[8:]
381 case b[8] < 0x80:
382 b = b[9:]
Damien Neil4d918162020-02-01 10:39:11 -0800383 case b[9] < 0x80 && b[9] < 2:
Damien Neil8fa11b12020-01-28 08:31:04 -0800384 b = b[10:]
385 default:
386 return ValidationInvalid
387 }
388 } else {
389 switch {
390 case len(b) > 0 && b[0] < 0x80:
391 b = b[1:]
392 case len(b) > 1 && b[1] < 0x80:
393 b = b[2:]
394 case len(b) > 2 && b[2] < 0x80:
395 b = b[3:]
396 case len(b) > 3 && b[3] < 0x80:
397 b = b[4:]
398 case len(b) > 4 && b[4] < 0x80:
399 b = b[5:]
400 case len(b) > 5 && b[5] < 0x80:
401 b = b[6:]
402 case len(b) > 6 && b[6] < 0x80:
403 b = b[7:]
404 case len(b) > 7 && b[7] < 0x80:
405 b = b[8:]
406 case len(b) > 8 && b[8] < 0x80:
407 b = b[9:]
Damien Neil4d918162020-02-01 10:39:11 -0800408 case len(b) > 9 && b[9] < 2:
Damien Neil8fa11b12020-01-28 08:31:04 -0800409 b = b[10:]
410 default:
411 return ValidationInvalid
412 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800413 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800414 continue State
Damien Neil8fa11b12020-01-28 08:31:04 -0800415 case wire.BytesType:
416 var size uint64
Damien Neil6f297792020-01-29 15:55:53 -0800417 if len(b) >= 1 && b[0] < 0x80 {
Damien Neil8fa11b12020-01-28 08:31:04 -0800418 size = uint64(b[0])
419 b = b[1:]
420 } else if len(b) >= 2 && b[1] < 128 {
421 size = uint64(b[0]&0x7f) + uint64(b[1])<<7
422 b = b[2:]
423 } else {
424 var n int
425 size, n = wire.ConsumeVarint(b)
Damien Neilb0c26f12019-12-16 09:37:59 -0800426 if n < 0 {
427 return ValidationInvalid
428 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800429 b = b[n:]
Damien Neilb0c26f12019-12-16 09:37:59 -0800430 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800431 if size > uint64(len(b)) {
Damien Neilb0c26f12019-12-16 09:37:59 -0800432 return ValidationInvalid
433 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800434 v := b[:size]
435 b = b[size:]
436 switch vi.typ {
Damien Neilcb0bfd02020-01-28 09:11:12 -0800437 case validationTypeMessage:
438 if vi.mi == nil {
Damien Neil8fa11b12020-01-28 08:31:04 -0800439 return ValidationUnknown
440 }
Damien Neilcb0bfd02020-01-28 09:11:12 -0800441 vi.mi.init()
442 fallthrough
443 case validationTypeMap:
Damien Neil8fa11b12020-01-28 08:31:04 -0800444 states = append(states, validationState{
445 typ: vi.typ,
446 keyType: vi.keyType,
447 valType: vi.valType,
448 mi: vi.mi,
449 tail: b,
450 })
451 b = v
452 continue State
453 case validationTypeRepeatedVarint:
454 // Packed field.
455 for len(v) > 0 {
456 _, n := wire.ConsumeVarint(v)
457 if n < 0 {
458 return ValidationInvalid
459 }
460 v = v[n:]
461 }
462 case validationTypeRepeatedFixed32:
463 // Packed field.
464 if len(v)%4 != 0 {
465 return ValidationInvalid
466 }
467 case validationTypeRepeatedFixed64:
468 // Packed field.
469 if len(v)%8 != 0 {
470 return ValidationInvalid
471 }
472 case validationTypeUTF8String:
473 if !utf8.Valid(v) {
474 return ValidationInvalid
475 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800476 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800477 case wire.Fixed32Type:
478 if len(b) < 4 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800479 return ValidationInvalid
480 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800481 b = b[4:]
482 case wire.Fixed64Type:
483 if len(b) < 8 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800484 return ValidationInvalid
485 }
Damien Neil8fa11b12020-01-28 08:31:04 -0800486 b = b[8:]
487 case wire.StartGroupType:
488 switch vi.typ {
489 case validationTypeGroup:
490 if vi.mi == nil {
491 return ValidationUnknown
492 }
Damien Neilcb0bfd02020-01-28 09:11:12 -0800493 vi.mi.init()
Damien Neil8fa11b12020-01-28 08:31:04 -0800494 states = append(states, validationState{
495 typ: validationTypeGroup,
496 mi: vi.mi,
497 endGroup: num,
498 })
499 continue State
500 default:
501 n := wire.ConsumeFieldValue(num, wtyp, b)
502 if n < 0 {
503 return ValidationInvalid
504 }
505 b = b[n:]
506 }
507 default:
Damien Neilb0c26f12019-12-16 09:37:59 -0800508 return ValidationInvalid
509 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800510 }
511 if st.endGroup != 0 {
512 return ValidationInvalid
513 }
514 if len(b) != 0 {
515 return ValidationInvalid
516 }
517 b = st.tail
518 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800519 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800520 switch st.typ {
521 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800522 numRequiredFields = int(st.mi.numRequiredFields)
523 case validationTypeMap:
524 // If this is a map field with a message value that contains
525 // required fields, require that the value be present.
526 if st.mi != nil && st.mi.numRequiredFields > 0 {
527 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800528 }
529 }
Damien Neil54a0a042020-01-08 17:53:16 -0800530 // If there are more than 64 required fields, this check will
531 // always fail and we will report that the message is potentially
532 // uninitialized.
533 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
534 initialized = false
535 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800536 states = states[:len(states)-1]
537 }
538 if !initialized {
539 return ValidationValidMaybeUninitalized
540 }
541 return ValidationValidInitialized
542}