blob: e6b3d233932f6b927f56f8e9ab614975ad5f2fb4 [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
15 "google.golang.org/protobuf/internal/strs"
16 pref "google.golang.org/protobuf/reflect/protoreflect"
17 preg "google.golang.org/protobuf/reflect/protoregistry"
18 piface "google.golang.org/protobuf/runtime/protoiface"
19)
20
// ValidationStatus is the result of validating the wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown indicates that unmarshaling the message might succeed or fail.
	// The validator was unable to render a judgement.
	//
	// This status is caused by an aberrant message type appearing somewhere in the
	// message, or by an unexpected error from a type lookup (the extension resolver,
	// or the global registry when resolving a weak field).
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid indicates that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized indicates that unmarshaling the message will succeed
	// and IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized indicates unmarshaling the message will succeed,
	// but the output of IsInitialized on the result is unknown.
	//
	// This status may be returned for an initialized message when a message value
	// is split across multiple fields.
	ValidationValidMaybeUninitalized
)

// String returns the Go identifier of a known status, or "ValidationStatus(N)"
// for any value outside the defined set.
func (v ValidationStatus) String() string {
	names := [...]string{
		ValidationUnknown:                "ValidationUnknown",
		ValidationInvalid:                "ValidationInvalid",
		ValidationValidInitialized:       "ValidationValidInitialized",
		ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
	}
	if v < ValidationUnknown || int(v) >= len(names) {
		return fmt.Sprintf("ValidationStatus(%d)", int(v))
	}
	return names[v]
}
61
62// Validate determines whether the contents of the buffer are a valid wire encoding
63// of the message type.
64//
65// This function is exposed for testing.
66func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
67 mi, ok := mt.(*MessageInfo)
68 if !ok {
69 return ValidationUnknown
70 }
71 return mi.validate(b, 0, newUnmarshalOptions(opts))
72}
73
74type validationInfo struct {
75 mi *MessageInfo
76 typ validationType
77 keyType, valType validationType
78
79 // For non-required fields, requiredIndex is 0.
80 //
81 // For required fields, requiredIndex is unique index in the range
82 // (0, MessageInfo.numRequiredFields].
83 requiredIndex uint8
84}
85
// validationType classifies how the validator checks a field's wire data.
type validationType uint8

const (
	// validationTypeOther fields get no type-specific checks; their values
	// are only required to be well-formed for whatever wire type they carry.
	validationTypeOther validationType = iota

	// Fields whose contents are validated recursively: length-delimited
	// messages, groups, and map entries.
	validationTypeMessage
	validationTypeGroup
	validationTypeMap

	// Repeated scalar fields, which may also appear in packed form.
	validationTypeRepeatedVarint
	validationTypeRepeatedFixed32
	validationTypeRepeatedFixed64

	// Singular scalar fields.
	validationTypeVarint
	validationTypeFixed32
	validationTypeFixed64

	// String and bytes fields. validationTypeUTF8String additionally requires
	// the contents to be valid UTF-8.
	validationTypeBytes
	validationTypeUTF8String
)
102
103func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
104 var vi validationInfo
105 switch {
106 case fd.ContainingOneof() != nil:
107 switch fd.Kind() {
108 case pref.MessageKind:
109 vi.typ = validationTypeMessage
110 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
111 vi.mi = getMessageInfo(ot.Field(0).Type)
112 }
113 case pref.GroupKind:
114 vi.typ = validationTypeGroup
115 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
116 vi.mi = getMessageInfo(ot.Field(0).Type)
117 }
118 case pref.StringKind:
119 if strs.EnforceUTF8(fd) {
120 vi.typ = validationTypeUTF8String
121 }
122 }
123 default:
124 vi = newValidationInfo(fd, ft)
125 }
126 if fd.Cardinality() == pref.Required {
127 // Avoid overflow. The required field check is done with a 64-bit mask, with
128 // any message containing more than 64 required fields always reported as
129 // potentially uninitialized, so it is not important to get a precise count
130 // of the required fields past 64.
131 if mi.numRequiredFields < math.MaxUint8 {
132 mi.numRequiredFields++
133 vi.requiredIndex = mi.numRequiredFields
134 }
135 }
136 return vi
137}
138
139func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
140 var vi validationInfo
141 switch {
142 case fd.IsList():
143 switch fd.Kind() {
144 case pref.MessageKind:
145 vi.typ = validationTypeMessage
146 if ft.Kind() == reflect.Slice {
147 vi.mi = getMessageInfo(ft.Elem())
148 }
149 case pref.GroupKind:
150 vi.typ = validationTypeGroup
151 if ft.Kind() == reflect.Slice {
152 vi.mi = getMessageInfo(ft.Elem())
153 }
154 case pref.StringKind:
155 vi.typ = validationTypeBytes
156 if strs.EnforceUTF8(fd) {
157 vi.typ = validationTypeUTF8String
158 }
159 default:
160 switch wireTypes[fd.Kind()] {
161 case wire.VarintType:
162 vi.typ = validationTypeRepeatedVarint
163 case wire.Fixed32Type:
164 vi.typ = validationTypeRepeatedFixed32
165 case wire.Fixed64Type:
166 vi.typ = validationTypeRepeatedFixed64
167 }
168 }
169 case fd.IsMap():
170 vi.typ = validationTypeMap
171 switch fd.MapKey().Kind() {
172 case pref.StringKind:
173 if strs.EnforceUTF8(fd) {
174 vi.keyType = validationTypeUTF8String
175 }
176 }
177 switch fd.MapValue().Kind() {
178 case pref.MessageKind:
179 vi.valType = validationTypeMessage
180 if ft.Kind() == reflect.Map {
181 vi.mi = getMessageInfo(ft.Elem())
182 }
183 case pref.StringKind:
184 if strs.EnforceUTF8(fd) {
185 vi.valType = validationTypeUTF8String
186 }
187 }
188 default:
189 switch fd.Kind() {
190 case pref.MessageKind:
191 vi.typ = validationTypeMessage
192 if !fd.IsWeak() {
193 vi.mi = getMessageInfo(ft)
194 }
195 case pref.GroupKind:
196 vi.typ = validationTypeGroup
197 vi.mi = getMessageInfo(ft)
198 case pref.StringKind:
199 vi.typ = validationTypeBytes
200 if strs.EnforceUTF8(fd) {
201 vi.typ = validationTypeUTF8String
202 }
203 default:
204 switch wireTypes[fd.Kind()] {
205 case wire.VarintType:
206 vi.typ = validationTypeVarint
207 case wire.Fixed32Type:
208 vi.typ = validationTypeFixed32
209 case wire.Fixed64Type:
210 vi.typ = validationTypeFixed64
211 }
212 }
213 }
214 return vi
215}
216
217func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
218 type validationState struct {
219 typ validationType
220 keyType, valType validationType
221 endGroup wire.Number
222 mi *MessageInfo
223 tail []byte
224 requiredMask uint64
225 }
226
227 // Pre-allocate some slots to avoid repeated slice reallocation.
228 states := make([]validationState, 0, 16)
229 states = append(states, validationState{
230 typ: validationTypeMessage,
231 mi: mi,
232 })
233 if groupTag > 0 {
234 states[0].typ = validationTypeGroup
235 states[0].endGroup = groupTag
236 }
237 initialized := true
238State:
239 for len(states) > 0 {
240 st := &states[len(states)-1]
241 if st.mi != nil {
242 st.mi.init()
243 }
244 Field:
245 for len(b) > 0 {
246 num, wtyp, n := wire.ConsumeTag(b)
247 if n < 0 {
248 return ValidationInvalid
249 }
250 b = b[n:]
251 if num > wire.MaxValidNumber {
252 return ValidationInvalid
253 }
254 if wtyp == wire.EndGroupType {
255 if st.endGroup == num {
256 goto PopState
257 }
258 return ValidationInvalid
259 }
260 var vi validationInfo
261 switch st.typ {
262 case validationTypeMap:
263 switch num {
264 case 1:
265 vi.typ = st.keyType
266 case 2:
267 vi.typ = st.valType
268 vi.mi = st.mi
269 }
270 default:
271 var f *coderFieldInfo
272 if int(num) < len(st.mi.denseCoderFields) {
273 f = st.mi.denseCoderFields[num]
274 } else {
275 f = st.mi.coderFields[num]
276 }
277 if f != nil {
278 vi = f.validation
279 if vi.typ == validationTypeMessage && vi.mi == nil {
280 // Probable weak field.
281 //
282 // TODO: Consider storing the results of this lookup somewhere
283 // rather than recomputing it on every validation.
284 fd := st.mi.Desc.Fields().ByNumber(num)
285 if fd == nil || !fd.IsWeak() {
286 break
287 }
288 messageName := fd.Message().FullName()
289 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
290 switch err {
291 case nil:
292 vi.mi, _ = messageType.(*MessageInfo)
293 case preg.NotFound:
294 vi.typ = validationTypeBytes
295 default:
296 return ValidationUnknown
297 }
298 }
299 break
300 }
301 // Possible extension field.
302 //
303 // TODO: We should return ValidationUnknown when:
304 // 1. The resolver is not frozen. (More extensions may be added to it.)
305 // 2. The resolver returns preg.NotFound.
306 // In this case, a type added to the resolver in the future could cause
307 // unmarshaling to begin failing. Supporting this requires some way to
308 // determine if the resolver is frozen.
309 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
310 if err != nil && err != preg.NotFound {
311 return ValidationUnknown
312 }
313 if err == nil {
314 vi = getExtensionFieldInfo(xt).validation
315 }
316 }
317 if vi.requiredIndex > 0 {
318 // Check that the field has a compatible wire type.
319 // We only need to consider non-repeated field types,
320 // since repeated fields (and maps) can never be required.
321 ok := false
322 switch vi.typ {
323 case validationTypeVarint:
324 ok = wtyp == wire.VarintType
325 case validationTypeFixed32:
326 ok = wtyp == wire.Fixed32Type
327 case validationTypeFixed64:
328 ok = wtyp == wire.Fixed64Type
329 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
330 ok = wtyp == wire.BytesType
331 }
332 if ok {
333 st.requiredMask |= 1 << (vi.requiredIndex - 1)
334 }
335 }
336 switch vi.typ {
337 case validationTypeMessage, validationTypeMap:
338 if wtyp != wire.BytesType {
339 break
340 }
341 if vi.mi == nil && vi.typ == validationTypeMessage {
342 return ValidationUnknown
343 }
344 size, n := wire.ConsumeVarint(b)
345 if n < 0 {
346 return ValidationInvalid
347 }
348 b = b[n:]
349 if uint64(len(b)) < size {
350 return ValidationInvalid
351 }
352 states = append(states, validationState{
353 typ: vi.typ,
354 keyType: vi.keyType,
355 valType: vi.valType,
356 mi: vi.mi,
357 tail: b[size:],
358 })
359 b = b[:size]
360 continue State
361 case validationTypeGroup:
362 if wtyp != wire.StartGroupType {
363 break
364 }
365 if vi.mi == nil {
366 return ValidationUnknown
367 }
368 states = append(states, validationState{
369 typ: validationTypeGroup,
370 mi: vi.mi,
371 endGroup: num,
372 })
373 continue State
374 case validationTypeRepeatedVarint:
375 if wtyp != wire.BytesType {
376 break
377 }
378 // Packed field.
379 v, n := wire.ConsumeBytes(b)
380 if n < 0 {
381 return ValidationInvalid
382 }
383 b = b[n:]
384 for len(v) > 0 {
385 _, n := wire.ConsumeVarint(v)
386 if n < 0 {
387 return ValidationInvalid
388 }
389 v = v[n:]
390 }
391 continue Field
392 case validationTypeRepeatedFixed32:
393 if wtyp != wire.BytesType {
394 break
395 }
396 // Packed field.
397 v, n := wire.ConsumeBytes(b)
398 if n < 0 || len(v)%4 != 0 {
399 return ValidationInvalid
400 }
401 b = b[n:]
402 continue Field
403 case validationTypeRepeatedFixed64:
404 if wtyp != wire.BytesType {
405 break
406 }
407 // Packed field.
408 v, n := wire.ConsumeBytes(b)
409 if n < 0 || len(v)%8 != 0 {
410 return ValidationInvalid
411 }
412 b = b[n:]
413 continue Field
414 case validationTypeUTF8String:
415 if wtyp != wire.BytesType {
416 break
417 }
418 v, n := wire.ConsumeBytes(b)
419 if n < 0 || !utf8.Valid(v) {
420 return ValidationInvalid
421 }
422 b = b[n:]
423 continue Field
424 }
425 n = wire.ConsumeFieldValue(num, wtyp, b)
426 if n < 0 {
427 return ValidationInvalid
428 }
429 b = b[n:]
430 }
431 if st.endGroup != 0 {
432 return ValidationInvalid
433 }
434 if len(b) != 0 {
435 return ValidationInvalid
436 }
437 b = st.tail
438 PopState:
439 switch st.typ {
440 case validationTypeMessage, validationTypeGroup:
441 // If there are more than 64 required fields, this check will
442 // always fail and we will report that the message is potentially
443 // uninitialized.
444 if st.mi.numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != int(st.mi.numRequiredFields) {
445 initialized = false
446 }
447 }
448 states = states[:len(states)-1]
449 }
450 if !initialized {
451 return ValidationValidMaybeUninitalized
452 }
453 return ValidationValidInitialized
454}