blob: 40c6a7a10ca92fdeb02fa9abdbd0530be6764365 [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil0bf97b72020-01-24 09:00:33 -080015 "google.golang.org/protobuf/internal/flags"
Damien Neilb0c26f12019-12-16 09:37:59 -080016 "google.golang.org/protobuf/internal/strs"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
18 preg "google.golang.org/protobuf/reflect/protoregistry"
19 piface "google.golang.org/protobuf/runtime/protoiface"
20)
21
// ValidationStatus describes the outcome of checking the wire-format encoding
// of a message.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message might succeed or fail.
	//
	// The only causes of this status are an aberrant message type appearing
	// somewhere in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized means unmarshaling the message will succeed
	// and IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized means unmarshaling the message will
	// succeed, but the output of IsInitialized on the result is unknown.
	//
	// This status may be returned for an initialized message when a message
	// value is split across multiple fields.
	ValidationValidMaybeUninitalized
)

// String returns the identifier of a known status constant. Any other value
// is rendered as "ValidationStatus(N)", where N is its decimal value.
func (v ValidationStatus) String() string {
	// Index 0 is intentionally left empty: the constants start at 1.
	names := [...]string{
		ValidationUnknown:                "ValidationUnknown",
		ValidationInvalid:                "ValidationInvalid",
		ValidationValidInitialized:       "ValidationValidInitialized",
		ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
	}
	if v >= ValidationUnknown && int(v) < len(names) {
		return names[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
62
63// Validate determines whether the contents of the buffer are a valid wire encoding
64// of the message type.
65//
66// This function is exposed for testing.
67func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
68 mi, ok := mt.(*MessageInfo)
69 if !ok {
70 return ValidationUnknown
71 }
72 return mi.validate(b, 0, newUnmarshalOptions(opts))
73}
74
75type validationInfo struct {
76 mi *MessageInfo
77 typ validationType
78 keyType, valType validationType
79
Damien Neil170b2bf2020-01-24 16:42:42 -080080 // For non-required fields, requiredBit is 0.
Damien Neilb0c26f12019-12-16 09:37:59 -080081 //
Damien Neil170b2bf2020-01-24 16:42:42 -080082 // For required fields, requiredBit's nth bit is set, where n is a
83 // unique index in the range [0, MessageInfo.numRequiredFields).
84 //
85 // If there are more than 64 required fields, requiredBit is 0.
86 requiredBit uint64
Damien Neilb0c26f12019-12-16 09:37:59 -080087}
88
// validationType classifies a field by the kind of checking its value needs.
type validationType uint8

const (
	// validationTypeOther is the zero value: the field gets no type-specific
	// checks and is only consumed as a generic wire value.
	validationTypeOther validationType = iota

	// validationTypeMessage is a length-delimited nested message.
	validationTypeMessage

	// validationTypeGroup is a nested message delimited by start/end group tags.
	validationTypeGroup

	// validationTypeMap is a map field; its entries are length-delimited.
	validationTypeMap

	// validationTypeRepeatedVarint, validationTypeRepeatedFixed32, and
	// validationTypeRepeatedFixed64 are list fields of scalars with the
	// corresponding wire type. Such fields may arrive in packed form.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// validationTypeVarint, validationTypeFixed32, and validationTypeFixed64
	// are singular scalars with the corresponding wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64

	// validationTypeBytes is a length-delimited value with no content checks.
	validationTypeBytes

	// validationTypeUTF8String is a length-delimited value that must be
	// valid UTF-8.
	validationTypeUTF8String
)
105
106func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
107 var vi validationInfo
108 switch {
109 case fd.ContainingOneof() != nil:
110 switch fd.Kind() {
111 case pref.MessageKind:
112 vi.typ = validationTypeMessage
113 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
114 vi.mi = getMessageInfo(ot.Field(0).Type)
115 }
116 case pref.GroupKind:
117 vi.typ = validationTypeGroup
118 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
119 vi.mi = getMessageInfo(ot.Field(0).Type)
120 }
121 case pref.StringKind:
122 if strs.EnforceUTF8(fd) {
123 vi.typ = validationTypeUTF8String
124 }
125 }
126 default:
127 vi = newValidationInfo(fd, ft)
128 }
129 if fd.Cardinality() == pref.Required {
130 // Avoid overflow. The required field check is done with a 64-bit mask, with
131 // any message containing more than 64 required fields always reported as
132 // potentially uninitialized, so it is not important to get a precise count
133 // of the required fields past 64.
134 if mi.numRequiredFields < math.MaxUint8 {
135 mi.numRequiredFields++
Damien Neil170b2bf2020-01-24 16:42:42 -0800136 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
Damien Neilb0c26f12019-12-16 09:37:59 -0800137 }
138 }
139 return vi
140}
141
142func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
143 var vi validationInfo
144 switch {
145 case fd.IsList():
146 switch fd.Kind() {
147 case pref.MessageKind:
148 vi.typ = validationTypeMessage
149 if ft.Kind() == reflect.Slice {
150 vi.mi = getMessageInfo(ft.Elem())
151 }
152 case pref.GroupKind:
153 vi.typ = validationTypeGroup
154 if ft.Kind() == reflect.Slice {
155 vi.mi = getMessageInfo(ft.Elem())
156 }
157 case pref.StringKind:
158 vi.typ = validationTypeBytes
159 if strs.EnforceUTF8(fd) {
160 vi.typ = validationTypeUTF8String
161 }
162 default:
163 switch wireTypes[fd.Kind()] {
164 case wire.VarintType:
165 vi.typ = validationTypeRepeatedVarint
166 case wire.Fixed32Type:
167 vi.typ = validationTypeRepeatedFixed32
168 case wire.Fixed64Type:
169 vi.typ = validationTypeRepeatedFixed64
170 }
171 }
172 case fd.IsMap():
173 vi.typ = validationTypeMap
174 switch fd.MapKey().Kind() {
175 case pref.StringKind:
176 if strs.EnforceUTF8(fd) {
177 vi.keyType = validationTypeUTF8String
178 }
179 }
180 switch fd.MapValue().Kind() {
181 case pref.MessageKind:
182 vi.valType = validationTypeMessage
183 if ft.Kind() == reflect.Map {
184 vi.mi = getMessageInfo(ft.Elem())
185 }
186 case pref.StringKind:
187 if strs.EnforceUTF8(fd) {
188 vi.valType = validationTypeUTF8String
189 }
190 }
191 default:
192 switch fd.Kind() {
193 case pref.MessageKind:
194 vi.typ = validationTypeMessage
195 if !fd.IsWeak() {
196 vi.mi = getMessageInfo(ft)
197 }
198 case pref.GroupKind:
199 vi.typ = validationTypeGroup
200 vi.mi = getMessageInfo(ft)
201 case pref.StringKind:
202 vi.typ = validationTypeBytes
203 if strs.EnforceUTF8(fd) {
204 vi.typ = validationTypeUTF8String
205 }
206 default:
207 switch wireTypes[fd.Kind()] {
208 case wire.VarintType:
209 vi.typ = validationTypeVarint
210 case wire.Fixed32Type:
211 vi.typ = validationTypeFixed32
212 case wire.Fixed64Type:
213 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800214 case wire.BytesType:
215 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800216 }
217 }
218 }
219 return vi
220}
221
222func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
223 type validationState struct {
224 typ validationType
225 keyType, valType validationType
226 endGroup wire.Number
227 mi *MessageInfo
228 tail []byte
229 requiredMask uint64
230 }
231
232 // Pre-allocate some slots to avoid repeated slice reallocation.
233 states := make([]validationState, 0, 16)
234 states = append(states, validationState{
235 typ: validationTypeMessage,
236 mi: mi,
237 })
238 if groupTag > 0 {
239 states[0].typ = validationTypeGroup
240 states[0].endGroup = groupTag
241 }
242 initialized := true
243State:
244 for len(states) > 0 {
245 st := &states[len(states)-1]
246 if st.mi != nil {
247 st.mi.init()
Damien Neil0bf97b72020-01-24 09:00:33 -0800248 if flags.ProtoLegacy && st.mi.isMessageSet {
249 return ValidationUnknown
250 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800251 }
252 Field:
253 for len(b) > 0 {
Damien Neil5d828832020-01-28 08:06:12 -0800254 // Parse the tag (field number and wire type).
255 var tag uint64
256 if b[0] < 0x80 {
257 tag = uint64(b[0])
258 b = b[1:]
259 } else if len(b) >= 2 && b[1] < 128 {
260 tag = uint64(b[0]&0x7f) + uint64(b[1])<<7
261 b = b[2:]
262 } else {
263 var n int
264 tag, n = wire.ConsumeVarint(b)
265 if n < 0 {
266 return ValidationInvalid
267 }
268 b = b[n:]
Damien Neilb0c26f12019-12-16 09:37:59 -0800269 }
Damien Neil5d828832020-01-28 08:06:12 -0800270 var num wire.Number
271 if n := tag >> 3; n < uint64(wire.MinValidNumber) || n > uint64(wire.MaxValidNumber) {
Damien Neilb0c26f12019-12-16 09:37:59 -0800272 return ValidationInvalid
Damien Neil5d828832020-01-28 08:06:12 -0800273 } else {
274 num = wire.Number(n)
Damien Neilb0c26f12019-12-16 09:37:59 -0800275 }
Damien Neil5d828832020-01-28 08:06:12 -0800276 wtyp := wire.Type(tag & 7)
277
Damien Neilb0c26f12019-12-16 09:37:59 -0800278 if wtyp == wire.EndGroupType {
279 if st.endGroup == num {
280 goto PopState
281 }
282 return ValidationInvalid
283 }
284 var vi validationInfo
285 switch st.typ {
286 case validationTypeMap:
287 switch num {
288 case 1:
289 vi.typ = st.keyType
290 case 2:
291 vi.typ = st.valType
292 vi.mi = st.mi
Damien Neil170b2bf2020-01-24 16:42:42 -0800293 vi.requiredBit = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800294 }
295 default:
296 var f *coderFieldInfo
297 if int(num) < len(st.mi.denseCoderFields) {
298 f = st.mi.denseCoderFields[num]
299 } else {
300 f = st.mi.coderFields[num]
301 }
302 if f != nil {
303 vi = f.validation
304 if vi.typ == validationTypeMessage && vi.mi == nil {
305 // Probable weak field.
306 //
307 // TODO: Consider storing the results of this lookup somewhere
308 // rather than recomputing it on every validation.
309 fd := st.mi.Desc.Fields().ByNumber(num)
310 if fd == nil || !fd.IsWeak() {
311 break
312 }
313 messageName := fd.Message().FullName()
314 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
315 switch err {
316 case nil:
317 vi.mi, _ = messageType.(*MessageInfo)
318 case preg.NotFound:
319 vi.typ = validationTypeBytes
320 default:
321 return ValidationUnknown
322 }
323 }
324 break
325 }
326 // Possible extension field.
327 //
328 // TODO: We should return ValidationUnknown when:
329 // 1. The resolver is not frozen. (More extensions may be added to it.)
330 // 2. The resolver returns preg.NotFound.
331 // In this case, a type added to the resolver in the future could cause
332 // unmarshaling to begin failing. Supporting this requires some way to
333 // determine if the resolver is frozen.
334 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
335 if err != nil && err != preg.NotFound {
336 return ValidationUnknown
337 }
338 if err == nil {
339 vi = getExtensionFieldInfo(xt).validation
340 }
341 }
Damien Neil170b2bf2020-01-24 16:42:42 -0800342 if vi.requiredBit != 0 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800343 // Check that the field has a compatible wire type.
344 // We only need to consider non-repeated field types,
345 // since repeated fields (and maps) can never be required.
346 ok := false
347 switch vi.typ {
348 case validationTypeVarint:
349 ok = wtyp == wire.VarintType
350 case validationTypeFixed32:
351 ok = wtyp == wire.Fixed32Type
352 case validationTypeFixed64:
353 ok = wtyp == wire.Fixed64Type
354 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
355 ok = wtyp == wire.BytesType
356 }
357 if ok {
Damien Neil170b2bf2020-01-24 16:42:42 -0800358 st.requiredMask |= vi.requiredBit
Damien Neilb0c26f12019-12-16 09:37:59 -0800359 }
360 }
361 switch vi.typ {
362 case validationTypeMessage, validationTypeMap:
363 if wtyp != wire.BytesType {
364 break
365 }
366 if vi.mi == nil && vi.typ == validationTypeMessage {
367 return ValidationUnknown
368 }
369 size, n := wire.ConsumeVarint(b)
370 if n < 0 {
371 return ValidationInvalid
372 }
373 b = b[n:]
374 if uint64(len(b)) < size {
375 return ValidationInvalid
376 }
377 states = append(states, validationState{
378 typ: vi.typ,
379 keyType: vi.keyType,
380 valType: vi.valType,
381 mi: vi.mi,
382 tail: b[size:],
383 })
384 b = b[:size]
385 continue State
386 case validationTypeGroup:
387 if wtyp != wire.StartGroupType {
388 break
389 }
390 if vi.mi == nil {
391 return ValidationUnknown
392 }
393 states = append(states, validationState{
394 typ: validationTypeGroup,
395 mi: vi.mi,
396 endGroup: num,
397 })
398 continue State
399 case validationTypeRepeatedVarint:
400 if wtyp != wire.BytesType {
401 break
402 }
403 // Packed field.
404 v, n := wire.ConsumeBytes(b)
405 if n < 0 {
406 return ValidationInvalid
407 }
408 b = b[n:]
409 for len(v) > 0 {
410 _, n := wire.ConsumeVarint(v)
411 if n < 0 {
412 return ValidationInvalid
413 }
414 v = v[n:]
415 }
416 continue Field
417 case validationTypeRepeatedFixed32:
418 if wtyp != wire.BytesType {
419 break
420 }
421 // Packed field.
422 v, n := wire.ConsumeBytes(b)
423 if n < 0 || len(v)%4 != 0 {
424 return ValidationInvalid
425 }
426 b = b[n:]
427 continue Field
428 case validationTypeRepeatedFixed64:
429 if wtyp != wire.BytesType {
430 break
431 }
432 // Packed field.
433 v, n := wire.ConsumeBytes(b)
434 if n < 0 || len(v)%8 != 0 {
435 return ValidationInvalid
436 }
437 b = b[n:]
438 continue Field
439 case validationTypeUTF8String:
440 if wtyp != wire.BytesType {
441 break
442 }
443 v, n := wire.ConsumeBytes(b)
444 if n < 0 || !utf8.Valid(v) {
445 return ValidationInvalid
446 }
447 b = b[n:]
448 continue Field
449 }
Damien Neil5d828832020-01-28 08:06:12 -0800450 n := wire.ConsumeFieldValue(num, wtyp, b)
Damien Neilb0c26f12019-12-16 09:37:59 -0800451 if n < 0 {
452 return ValidationInvalid
453 }
454 b = b[n:]
455 }
456 if st.endGroup != 0 {
457 return ValidationInvalid
458 }
459 if len(b) != 0 {
460 return ValidationInvalid
461 }
462 b = st.tail
463 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800464 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800465 switch st.typ {
466 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800467 numRequiredFields = int(st.mi.numRequiredFields)
468 case validationTypeMap:
469 // If this is a map field with a message value that contains
470 // required fields, require that the value be present.
471 if st.mi != nil && st.mi.numRequiredFields > 0 {
472 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800473 }
474 }
Damien Neil54a0a042020-01-08 17:53:16 -0800475 // If there are more than 64 required fields, this check will
476 // always fail and we will report that the message is potentially
477 // uninitialized.
478 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
479 initialized = false
480 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800481 states = states[:len(states)-1]
482 }
483 if !initialized {
484 return ValidationValidMaybeUninitalized
485 }
486 return ValidationValidInitialized
487}