blob: cb6a820f10b92e9cb20448c32f768b8d0c15a53c [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil0bf97b72020-01-24 09:00:33 -080015 "google.golang.org/protobuf/internal/flags"
Damien Neilb0c26f12019-12-16 09:37:59 -080016 "google.golang.org/protobuf/internal/strs"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
18 preg "google.golang.org/protobuf/reflect/protoregistry"
19 piface "google.golang.org/protobuf/runtime/protoiface"
20)
21
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

// The named statuses start at 1; the zero value is not a named status.
const (
	_ ValidationStatus = iota

	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// The only causes of this status are an aberrant message type appearing somewhere
	// in the message or a failure in the extension resolver.
	ValidationUnknown

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized indicates that unmarshaling the message will succeed
	// and IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
	// but the output of IsInitialized on the result is unknown.
	//
	// This status may be returned for an initialized message when a message value
	// is split across multiple fields.
	ValidationValidMaybeUninitalized
)

// String returns the name of the status constant, or "ValidationStatus(N)"
// for a value that does not correspond to any named status.
func (v ValidationStatus) String() string {
	// Names indexed by status value; index 0 is intentionally empty and is
	// never returned because of the range check below.
	names := [...]string{
		ValidationUnknown:                "ValidationUnknown",
		ValidationInvalid:                "ValidationInvalid",
		ValidationValidInitialized:       "ValidationValidInitialized",
		ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
	}
	if v >= ValidationUnknown && v <= ValidationValidMaybeUninitalized {
		return names[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
62
63// Validate determines whether the contents of the buffer are a valid wire encoding
64// of the message type.
65//
66// This function is exposed for testing.
67func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
68 mi, ok := mt.(*MessageInfo)
69 if !ok {
70 return ValidationUnknown
71 }
72 return mi.validate(b, 0, newUnmarshalOptions(opts))
73}
74
75type validationInfo struct {
76 mi *MessageInfo
77 typ validationType
78 keyType, valType validationType
79
80 // For non-required fields, requiredIndex is 0.
81 //
82 // For required fields, requiredIndex is unique index in the range
83 // (0, MessageInfo.numRequiredFields].
84 requiredIndex uint8
85}
86
// validationType classifies a field by how its wire encoding is validated.
type validationType uint8

const (
	// validationTypeOther is the zero value, used for fields that need no
	// type-specific handling.
	validationTypeOther = validationType(iota)

	// Nested and composite values.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// Repeated scalars, classified by the wire type of their elements.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalars, classified by wire type.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64
	validationTypeBytes

	// validationTypeUTF8String marks string data that must be valid UTF-8.
	validationTypeUTF8String
)
103
104func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
105 var vi validationInfo
106 switch {
107 case fd.ContainingOneof() != nil:
108 switch fd.Kind() {
109 case pref.MessageKind:
110 vi.typ = validationTypeMessage
111 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
112 vi.mi = getMessageInfo(ot.Field(0).Type)
113 }
114 case pref.GroupKind:
115 vi.typ = validationTypeGroup
116 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
117 vi.mi = getMessageInfo(ot.Field(0).Type)
118 }
119 case pref.StringKind:
120 if strs.EnforceUTF8(fd) {
121 vi.typ = validationTypeUTF8String
122 }
123 }
124 default:
125 vi = newValidationInfo(fd, ft)
126 }
127 if fd.Cardinality() == pref.Required {
128 // Avoid overflow. The required field check is done with a 64-bit mask, with
129 // any message containing more than 64 required fields always reported as
130 // potentially uninitialized, so it is not important to get a precise count
131 // of the required fields past 64.
132 if mi.numRequiredFields < math.MaxUint8 {
133 mi.numRequiredFields++
134 vi.requiredIndex = mi.numRequiredFields
135 }
136 }
137 return vi
138}
139
140func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
141 var vi validationInfo
142 switch {
143 case fd.IsList():
144 switch fd.Kind() {
145 case pref.MessageKind:
146 vi.typ = validationTypeMessage
147 if ft.Kind() == reflect.Slice {
148 vi.mi = getMessageInfo(ft.Elem())
149 }
150 case pref.GroupKind:
151 vi.typ = validationTypeGroup
152 if ft.Kind() == reflect.Slice {
153 vi.mi = getMessageInfo(ft.Elem())
154 }
155 case pref.StringKind:
156 vi.typ = validationTypeBytes
157 if strs.EnforceUTF8(fd) {
158 vi.typ = validationTypeUTF8String
159 }
160 default:
161 switch wireTypes[fd.Kind()] {
162 case wire.VarintType:
163 vi.typ = validationTypeRepeatedVarint
164 case wire.Fixed32Type:
165 vi.typ = validationTypeRepeatedFixed32
166 case wire.Fixed64Type:
167 vi.typ = validationTypeRepeatedFixed64
168 }
169 }
170 case fd.IsMap():
171 vi.typ = validationTypeMap
172 switch fd.MapKey().Kind() {
173 case pref.StringKind:
174 if strs.EnforceUTF8(fd) {
175 vi.keyType = validationTypeUTF8String
176 }
177 }
178 switch fd.MapValue().Kind() {
179 case pref.MessageKind:
180 vi.valType = validationTypeMessage
181 if ft.Kind() == reflect.Map {
182 vi.mi = getMessageInfo(ft.Elem())
183 }
184 case pref.StringKind:
185 if strs.EnforceUTF8(fd) {
186 vi.valType = validationTypeUTF8String
187 }
188 }
189 default:
190 switch fd.Kind() {
191 case pref.MessageKind:
192 vi.typ = validationTypeMessage
193 if !fd.IsWeak() {
194 vi.mi = getMessageInfo(ft)
195 }
196 case pref.GroupKind:
197 vi.typ = validationTypeGroup
198 vi.mi = getMessageInfo(ft)
199 case pref.StringKind:
200 vi.typ = validationTypeBytes
201 if strs.EnforceUTF8(fd) {
202 vi.typ = validationTypeUTF8String
203 }
204 default:
205 switch wireTypes[fd.Kind()] {
206 case wire.VarintType:
207 vi.typ = validationTypeVarint
208 case wire.Fixed32Type:
209 vi.typ = validationTypeFixed32
210 case wire.Fixed64Type:
211 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800212 case wire.BytesType:
213 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800214 }
215 }
216 }
217 return vi
218}
219
220func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
221 type validationState struct {
222 typ validationType
223 keyType, valType validationType
224 endGroup wire.Number
225 mi *MessageInfo
226 tail []byte
227 requiredMask uint64
228 }
229
230 // Pre-allocate some slots to avoid repeated slice reallocation.
231 states := make([]validationState, 0, 16)
232 states = append(states, validationState{
233 typ: validationTypeMessage,
234 mi: mi,
235 })
236 if groupTag > 0 {
237 states[0].typ = validationTypeGroup
238 states[0].endGroup = groupTag
239 }
240 initialized := true
241State:
242 for len(states) > 0 {
243 st := &states[len(states)-1]
244 if st.mi != nil {
245 st.mi.init()
Damien Neil0bf97b72020-01-24 09:00:33 -0800246 if flags.ProtoLegacy && st.mi.isMessageSet {
247 return ValidationUnknown
248 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800249 }
250 Field:
251 for len(b) > 0 {
252 num, wtyp, n := wire.ConsumeTag(b)
253 if n < 0 {
254 return ValidationInvalid
255 }
256 b = b[n:]
257 if num > wire.MaxValidNumber {
258 return ValidationInvalid
259 }
260 if wtyp == wire.EndGroupType {
261 if st.endGroup == num {
262 goto PopState
263 }
264 return ValidationInvalid
265 }
266 var vi validationInfo
267 switch st.typ {
268 case validationTypeMap:
269 switch num {
270 case 1:
271 vi.typ = st.keyType
272 case 2:
273 vi.typ = st.valType
274 vi.mi = st.mi
Damien Neil54a0a042020-01-08 17:53:16 -0800275 vi.requiredIndex = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800276 }
277 default:
278 var f *coderFieldInfo
279 if int(num) < len(st.mi.denseCoderFields) {
280 f = st.mi.denseCoderFields[num]
281 } else {
282 f = st.mi.coderFields[num]
283 }
284 if f != nil {
285 vi = f.validation
286 if vi.typ == validationTypeMessage && vi.mi == nil {
287 // Probable weak field.
288 //
289 // TODO: Consider storing the results of this lookup somewhere
290 // rather than recomputing it on every validation.
291 fd := st.mi.Desc.Fields().ByNumber(num)
292 if fd == nil || !fd.IsWeak() {
293 break
294 }
295 messageName := fd.Message().FullName()
296 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
297 switch err {
298 case nil:
299 vi.mi, _ = messageType.(*MessageInfo)
300 case preg.NotFound:
301 vi.typ = validationTypeBytes
302 default:
303 return ValidationUnknown
304 }
305 }
306 break
307 }
308 // Possible extension field.
309 //
310 // TODO: We should return ValidationUnknown when:
311 // 1. The resolver is not frozen. (More extensions may be added to it.)
312 // 2. The resolver returns preg.NotFound.
313 // In this case, a type added to the resolver in the future could cause
314 // unmarshaling to begin failing. Supporting this requires some way to
315 // determine if the resolver is frozen.
316 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
317 if err != nil && err != preg.NotFound {
318 return ValidationUnknown
319 }
320 if err == nil {
321 vi = getExtensionFieldInfo(xt).validation
322 }
323 }
324 if vi.requiredIndex > 0 {
325 // Check that the field has a compatible wire type.
326 // We only need to consider non-repeated field types,
327 // since repeated fields (and maps) can never be required.
328 ok := false
329 switch vi.typ {
330 case validationTypeVarint:
331 ok = wtyp == wire.VarintType
332 case validationTypeFixed32:
333 ok = wtyp == wire.Fixed32Type
334 case validationTypeFixed64:
335 ok = wtyp == wire.Fixed64Type
336 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
337 ok = wtyp == wire.BytesType
338 }
339 if ok {
340 st.requiredMask |= 1 << (vi.requiredIndex - 1)
341 }
342 }
343 switch vi.typ {
344 case validationTypeMessage, validationTypeMap:
345 if wtyp != wire.BytesType {
346 break
347 }
348 if vi.mi == nil && vi.typ == validationTypeMessage {
349 return ValidationUnknown
350 }
351 size, n := wire.ConsumeVarint(b)
352 if n < 0 {
353 return ValidationInvalid
354 }
355 b = b[n:]
356 if uint64(len(b)) < size {
357 return ValidationInvalid
358 }
359 states = append(states, validationState{
360 typ: vi.typ,
361 keyType: vi.keyType,
362 valType: vi.valType,
363 mi: vi.mi,
364 tail: b[size:],
365 })
366 b = b[:size]
367 continue State
368 case validationTypeGroup:
369 if wtyp != wire.StartGroupType {
370 break
371 }
372 if vi.mi == nil {
373 return ValidationUnknown
374 }
375 states = append(states, validationState{
376 typ: validationTypeGroup,
377 mi: vi.mi,
378 endGroup: num,
379 })
380 continue State
381 case validationTypeRepeatedVarint:
382 if wtyp != wire.BytesType {
383 break
384 }
385 // Packed field.
386 v, n := wire.ConsumeBytes(b)
387 if n < 0 {
388 return ValidationInvalid
389 }
390 b = b[n:]
391 for len(v) > 0 {
392 _, n := wire.ConsumeVarint(v)
393 if n < 0 {
394 return ValidationInvalid
395 }
396 v = v[n:]
397 }
398 continue Field
399 case validationTypeRepeatedFixed32:
400 if wtyp != wire.BytesType {
401 break
402 }
403 // Packed field.
404 v, n := wire.ConsumeBytes(b)
405 if n < 0 || len(v)%4 != 0 {
406 return ValidationInvalid
407 }
408 b = b[n:]
409 continue Field
410 case validationTypeRepeatedFixed64:
411 if wtyp != wire.BytesType {
412 break
413 }
414 // Packed field.
415 v, n := wire.ConsumeBytes(b)
416 if n < 0 || len(v)%8 != 0 {
417 return ValidationInvalid
418 }
419 b = b[n:]
420 continue Field
421 case validationTypeUTF8String:
422 if wtyp != wire.BytesType {
423 break
424 }
425 v, n := wire.ConsumeBytes(b)
426 if n < 0 || !utf8.Valid(v) {
427 return ValidationInvalid
428 }
429 b = b[n:]
430 continue Field
431 }
432 n = wire.ConsumeFieldValue(num, wtyp, b)
433 if n < 0 {
434 return ValidationInvalid
435 }
436 b = b[n:]
437 }
438 if st.endGroup != 0 {
439 return ValidationInvalid
440 }
441 if len(b) != 0 {
442 return ValidationInvalid
443 }
444 b = st.tail
445 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800446 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800447 switch st.typ {
448 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800449 numRequiredFields = int(st.mi.numRequiredFields)
450 case validationTypeMap:
451 // If this is a map field with a message value that contains
452 // required fields, require that the value be present.
453 if st.mi != nil && st.mi.numRequiredFields > 0 {
454 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800455 }
456 }
Damien Neil54a0a042020-01-08 17:53:16 -0800457 // If there are more than 64 required fields, this check will
458 // always fail and we will report that the message is potentially
459 // uninitialized.
460 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
461 initialized = false
462 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800463 states = states[:len(states)-1]
464 }
465 if !initialized {
466 return ValidationValidMaybeUninitalized
467 }
468 return ValidationValidInitialized
469}