blob: 5f3c3f5a3a4227194c6783a75d2286f18721edd4 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
15 "google.golang.org/protobuf/internal/strs"
16 pref "google.golang.org/protobuf/reflect/protoreflect"
17 preg "google.golang.org/protobuf/reflect/protoregistry"
18 piface "google.golang.org/protobuf/runtime/protoiface"
19)
20
// ValidationStatus reports the outcome of checking whether a byte slice is a
// well-formed wire-format encoding of a particular message type.
type ValidationStatus int

// The defined ValidationStatus values. Numbering starts at 1, leaving the zero
// value of ValidationStatus undefined.
const (
	// ValidationUnknown means the validator could not decide whether
	// unmarshaling the message would succeed or fail.
	//
	// This status arises only when an aberrant message type appears somewhere
	// in the message or when the extension resolver fails.
	ValidationUnknown ValidationStatus = 1

	// ValidationInvalid means unmarshaling the message is certain to fail.
	ValidationInvalid ValidationStatus = 2

	// ValidationValidInitialized means unmarshaling the message will succeed
	// and IsInitialized will report success on the result.
	ValidationValidInitialized ValidationStatus = 3

	// ValidationValidMaybeUninitalized means unmarshaling the message will
	// succeed, but whether IsInitialized reports success on the result is not
	// known.
	//
	// An initialized message may still receive this status, for example when
	// a message value is split across multiple fields.
	ValidationValidMaybeUninitalized ValidationStatus = 4
)

// String returns the identifier of the status constant, or
// "ValidationStatus(N)" when v is not one of the defined statuses.
func (v ValidationStatus) String() string {
	var name string
	switch v {
	case ValidationUnknown:
		name = "ValidationUnknown"
	case ValidationInvalid:
		name = "ValidationInvalid"
	case ValidationValidInitialized:
		name = "ValidationValidInitialized"
	case ValidationValidMaybeUninitalized:
		name = "ValidationValidMaybeUninitalized"
	default:
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	return name
}
61
62// Validate determines whether the contents of the buffer are a valid wire encoding
63// of the message type.
64//
65// This function is exposed for testing.
66func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
67 mi, ok := mt.(*MessageInfo)
68 if !ok {
69 return ValidationUnknown
70 }
71 return mi.validate(b, 0, newUnmarshalOptions(opts))
72}
73
74type validationInfo struct {
75 mi *MessageInfo
76 typ validationType
77 keyType, valType validationType
78
79 // For non-required fields, requiredIndex is 0.
80 //
81 // For required fields, requiredIndex is unique index in the range
82 // (0, MessageInfo.numRequiredFields].
83 requiredIndex uint8
84}
85
// validationType identifies how the validator checks a field's encoded value.
type validationType uint8

// The validation types. Each constant is explicitly typed so that it remains a
// validationType (not an untyped integer) in every context.
const (
	// validationTypeOther values receive no type-specific checks; they are
	// only skipped over.
	validationTypeOther validationType = 0

	// Nested structures that the validator descends into.
	validationTypeMessage validationType = 1
	validationTypeGroup   validationType = 2
	validationTypeMap     validationType = 3

	// Repeated scalars, which may also appear in packed form.
	validationTypeRepeatedVarint  validationType = 4
	validationTypeRepeatedFixed32 validationType = 5
	validationTypeRepeatedFixed64 validationType = 6

	// Singular scalars and length-delimited values.
	validationTypeVarint     validationType = 7
	validationTypeFixed32    validationType = 8
	validationTypeFixed64    validationType = 9
	validationTypeBytes      validationType = 10
	validationTypeUTF8String validationType = 11
)
102
103func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
104 var vi validationInfo
105 switch {
106 case fd.ContainingOneof() != nil:
107 switch fd.Kind() {
108 case pref.MessageKind:
109 vi.typ = validationTypeMessage
110 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
111 vi.mi = getMessageInfo(ot.Field(0).Type)
112 }
113 case pref.GroupKind:
114 vi.typ = validationTypeGroup
115 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
116 vi.mi = getMessageInfo(ot.Field(0).Type)
117 }
118 case pref.StringKind:
119 if strs.EnforceUTF8(fd) {
120 vi.typ = validationTypeUTF8String
121 }
122 }
123 default:
124 vi = newValidationInfo(fd, ft)
125 }
126 if fd.Cardinality() == pref.Required {
127 // Avoid overflow. The required field check is done with a 64-bit mask, with
128 // any message containing more than 64 required fields always reported as
129 // potentially uninitialized, so it is not important to get a precise count
130 // of the required fields past 64.
131 if mi.numRequiredFields < math.MaxUint8 {
132 mi.numRequiredFields++
133 vi.requiredIndex = mi.numRequiredFields
134 }
135 }
136 return vi
137}
138
139func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
140 var vi validationInfo
141 switch {
142 case fd.IsList():
143 switch fd.Kind() {
144 case pref.MessageKind:
145 vi.typ = validationTypeMessage
146 if ft.Kind() == reflect.Slice {
147 vi.mi = getMessageInfo(ft.Elem())
148 }
149 case pref.GroupKind:
150 vi.typ = validationTypeGroup
151 if ft.Kind() == reflect.Slice {
152 vi.mi = getMessageInfo(ft.Elem())
153 }
154 case pref.StringKind:
155 vi.typ = validationTypeBytes
156 if strs.EnforceUTF8(fd) {
157 vi.typ = validationTypeUTF8String
158 }
159 default:
160 switch wireTypes[fd.Kind()] {
161 case wire.VarintType:
162 vi.typ = validationTypeRepeatedVarint
163 case wire.Fixed32Type:
164 vi.typ = validationTypeRepeatedFixed32
165 case wire.Fixed64Type:
166 vi.typ = validationTypeRepeatedFixed64
167 }
168 }
169 case fd.IsMap():
170 vi.typ = validationTypeMap
171 switch fd.MapKey().Kind() {
172 case pref.StringKind:
173 if strs.EnforceUTF8(fd) {
174 vi.keyType = validationTypeUTF8String
175 }
176 }
177 switch fd.MapValue().Kind() {
178 case pref.MessageKind:
179 vi.valType = validationTypeMessage
180 if ft.Kind() == reflect.Map {
181 vi.mi = getMessageInfo(ft.Elem())
182 }
183 case pref.StringKind:
184 if strs.EnforceUTF8(fd) {
185 vi.valType = validationTypeUTF8String
186 }
187 }
188 default:
189 switch fd.Kind() {
190 case pref.MessageKind:
191 vi.typ = validationTypeMessage
192 if !fd.IsWeak() {
193 vi.mi = getMessageInfo(ft)
194 }
195 case pref.GroupKind:
196 vi.typ = validationTypeGroup
197 vi.mi = getMessageInfo(ft)
198 case pref.StringKind:
199 vi.typ = validationTypeBytes
200 if strs.EnforceUTF8(fd) {
201 vi.typ = validationTypeUTF8String
202 }
203 default:
204 switch wireTypes[fd.Kind()] {
205 case wire.VarintType:
206 vi.typ = validationTypeVarint
207 case wire.Fixed32Type:
208 vi.typ = validationTypeFixed32
209 case wire.Fixed64Type:
210 vi.typ = validationTypeFixed64
211 }
212 }
213 }
214 return vi
215}
216
217func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
218 type validationState struct {
219 typ validationType
220 keyType, valType validationType
221 endGroup wire.Number
222 mi *MessageInfo
223 tail []byte
224 requiredMask uint64
225 }
226
227 // Pre-allocate some slots to avoid repeated slice reallocation.
228 states := make([]validationState, 0, 16)
229 states = append(states, validationState{
230 typ: validationTypeMessage,
231 mi: mi,
232 })
233 if groupTag > 0 {
234 states[0].typ = validationTypeGroup
235 states[0].endGroup = groupTag
236 }
237 initialized := true
238State:
239 for len(states) > 0 {
240 st := &states[len(states)-1]
241 if st.mi != nil {
242 st.mi.init()
243 }
244 Field:
245 for len(b) > 0 {
246 num, wtyp, n := wire.ConsumeTag(b)
247 if n < 0 {
248 return ValidationInvalid
249 }
250 b = b[n:]
251 if num > wire.MaxValidNumber {
252 return ValidationInvalid
253 }
254 if wtyp == wire.EndGroupType {
255 if st.endGroup == num {
256 goto PopState
257 }
258 return ValidationInvalid
259 }
260 var vi validationInfo
261 switch st.typ {
262 case validationTypeMap:
263 switch num {
264 case 1:
265 vi.typ = st.keyType
266 case 2:
267 vi.typ = st.valType
268 vi.mi = st.mi
Damien Neil54a0a042020-01-08 17:53:16 -0800269 vi.requiredIndex = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800270 }
271 default:
272 var f *coderFieldInfo
273 if int(num) < len(st.mi.denseCoderFields) {
274 f = st.mi.denseCoderFields[num]
275 } else {
276 f = st.mi.coderFields[num]
277 }
278 if f != nil {
279 vi = f.validation
280 if vi.typ == validationTypeMessage && vi.mi == nil {
281 // Probable weak field.
282 //
283 // TODO: Consider storing the results of this lookup somewhere
284 // rather than recomputing it on every validation.
285 fd := st.mi.Desc.Fields().ByNumber(num)
286 if fd == nil || !fd.IsWeak() {
287 break
288 }
289 messageName := fd.Message().FullName()
290 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
291 switch err {
292 case nil:
293 vi.mi, _ = messageType.(*MessageInfo)
294 case preg.NotFound:
295 vi.typ = validationTypeBytes
296 default:
297 return ValidationUnknown
298 }
299 }
300 break
301 }
302 // Possible extension field.
303 //
304 // TODO: We should return ValidationUnknown when:
305 // 1. The resolver is not frozen. (More extensions may be added to it.)
306 // 2. The resolver returns preg.NotFound.
307 // In this case, a type added to the resolver in the future could cause
308 // unmarshaling to begin failing. Supporting this requires some way to
309 // determine if the resolver is frozen.
310 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
311 if err != nil && err != preg.NotFound {
312 return ValidationUnknown
313 }
314 if err == nil {
315 vi = getExtensionFieldInfo(xt).validation
316 }
317 }
318 if vi.requiredIndex > 0 {
319 // Check that the field has a compatible wire type.
320 // We only need to consider non-repeated field types,
321 // since repeated fields (and maps) can never be required.
322 ok := false
323 switch vi.typ {
324 case validationTypeVarint:
325 ok = wtyp == wire.VarintType
326 case validationTypeFixed32:
327 ok = wtyp == wire.Fixed32Type
328 case validationTypeFixed64:
329 ok = wtyp == wire.Fixed64Type
330 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
331 ok = wtyp == wire.BytesType
332 }
333 if ok {
334 st.requiredMask |= 1 << (vi.requiredIndex - 1)
335 }
336 }
337 switch vi.typ {
338 case validationTypeMessage, validationTypeMap:
339 if wtyp != wire.BytesType {
340 break
341 }
342 if vi.mi == nil && vi.typ == validationTypeMessage {
343 return ValidationUnknown
344 }
345 size, n := wire.ConsumeVarint(b)
346 if n < 0 {
347 return ValidationInvalid
348 }
349 b = b[n:]
350 if uint64(len(b)) < size {
351 return ValidationInvalid
352 }
353 states = append(states, validationState{
354 typ: vi.typ,
355 keyType: vi.keyType,
356 valType: vi.valType,
357 mi: vi.mi,
358 tail: b[size:],
359 })
360 b = b[:size]
361 continue State
362 case validationTypeGroup:
363 if wtyp != wire.StartGroupType {
364 break
365 }
366 if vi.mi == nil {
367 return ValidationUnknown
368 }
369 states = append(states, validationState{
370 typ: validationTypeGroup,
371 mi: vi.mi,
372 endGroup: num,
373 })
374 continue State
375 case validationTypeRepeatedVarint:
376 if wtyp != wire.BytesType {
377 break
378 }
379 // Packed field.
380 v, n := wire.ConsumeBytes(b)
381 if n < 0 {
382 return ValidationInvalid
383 }
384 b = b[n:]
385 for len(v) > 0 {
386 _, n := wire.ConsumeVarint(v)
387 if n < 0 {
388 return ValidationInvalid
389 }
390 v = v[n:]
391 }
392 continue Field
393 case validationTypeRepeatedFixed32:
394 if wtyp != wire.BytesType {
395 break
396 }
397 // Packed field.
398 v, n := wire.ConsumeBytes(b)
399 if n < 0 || len(v)%4 != 0 {
400 return ValidationInvalid
401 }
402 b = b[n:]
403 continue Field
404 case validationTypeRepeatedFixed64:
405 if wtyp != wire.BytesType {
406 break
407 }
408 // Packed field.
409 v, n := wire.ConsumeBytes(b)
410 if n < 0 || len(v)%8 != 0 {
411 return ValidationInvalid
412 }
413 b = b[n:]
414 continue Field
415 case validationTypeUTF8String:
416 if wtyp != wire.BytesType {
417 break
418 }
419 v, n := wire.ConsumeBytes(b)
420 if n < 0 || !utf8.Valid(v) {
421 return ValidationInvalid
422 }
423 b = b[n:]
424 continue Field
425 }
426 n = wire.ConsumeFieldValue(num, wtyp, b)
427 if n < 0 {
428 return ValidationInvalid
429 }
430 b = b[n:]
431 }
432 if st.endGroup != 0 {
433 return ValidationInvalid
434 }
435 if len(b) != 0 {
436 return ValidationInvalid
437 }
438 b = st.tail
439 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800440 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800441 switch st.typ {
442 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800443 numRequiredFields = int(st.mi.numRequiredFields)
444 case validationTypeMap:
445 // If this is a map field with a message value that contains
446 // required fields, require that the value be present.
447 if st.mi != nil && st.mi.numRequiredFields > 0 {
448 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800449 }
450 }
Damien Neil54a0a042020-01-08 17:53:16 -0800451 // If there are more than 64 required fields, this check will
452 // always fail and we will report that the message is potentially
453 // uninitialized.
454 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
455 initialized = false
456 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800457 states = states[:len(states)-1]
458 }
459 if !initialized {
460 return ValidationValidMaybeUninitalized
461 }
462 return ValidationValidInitialized
463}