blob: 093f0784ea6989f4fe707c0661e2152a8217e0ba [file] [log] [blame]
Damien Neilb0c26f12019-12-16 09:37:59 -08001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "math/bits"
11 "reflect"
12 "unicode/utf8"
13
14 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil0bf97b72020-01-24 09:00:33 -080015 "google.golang.org/protobuf/internal/flags"
Damien Neilb0c26f12019-12-16 09:37:59 -080016 "google.golang.org/protobuf/internal/strs"
17 pref "google.golang.org/protobuf/reflect/protoreflect"
18 preg "google.golang.org/protobuf/reflect/protoregistry"
19 piface "google.golang.org/protobuf/runtime/protoiface"
20)
21
// ValidationStatus is the outcome of validating the wire-format encoding of a message.
type ValidationStatus int

const (
	// ValidationUnknown means the validator could not reach a verdict:
	// unmarshaling the message might succeed or fail.
	//
	// The only causes of this status are an aberrant message type appearing
	// somewhere in the message or a failure in the extension resolver.
	ValidationUnknown ValidationStatus = iota + 1

	// ValidationInvalid means that unmarshaling the message will fail.
	ValidationInvalid

	// ValidationValidInitialized means that unmarshaling the message will
	// succeed and that IsInitialized on the result will report success.
	ValidationValidInitialized

	// ValidationValidMaybeUninitalized means that unmarshaling the message
	// will succeed, but the outcome of IsInitialized on the result is unknown.
	//
	// This status may be returned for an initialized message when a message
	// value is split across multiple fields.
	ValidationValidMaybeUninitalized
)

// String returns the name of the status constant. A value that is not one of
// the declared constants is rendered as "ValidationStatus(N)".
func (v ValidationStatus) String() string {
	// Index 0 is intentionally left empty: the constants start at 1.
	names := [...]string{
		ValidationUnknown:                "ValidationUnknown",
		ValidationInvalid:                "ValidationInvalid",
		ValidationValidInitialized:       "ValidationValidInitialized",
		ValidationValidMaybeUninitalized: "ValidationValidMaybeUninitalized",
	}
	if v >= ValidationUnknown && int(v) < len(names) {
		return names[v]
	}
	return fmt.Sprintf("ValidationStatus(%d)", int(v))
}
62
63// Validate determines whether the contents of the buffer are a valid wire encoding
64// of the message type.
65//
66// This function is exposed for testing.
67func Validate(b []byte, mt pref.MessageType, opts piface.UnmarshalOptions) ValidationStatus {
68 mi, ok := mt.(*MessageInfo)
69 if !ok {
70 return ValidationUnknown
71 }
72 return mi.validate(b, 0, newUnmarshalOptions(opts))
73}
74
75type validationInfo struct {
76 mi *MessageInfo
77 typ validationType
78 keyType, valType validationType
79
Damien Neil170b2bf2020-01-24 16:42:42 -080080 // For non-required fields, requiredBit is 0.
Damien Neilb0c26f12019-12-16 09:37:59 -080081 //
Damien Neil170b2bf2020-01-24 16:42:42 -080082 // For required fields, requiredBit's nth bit is set, where n is a
83 // unique index in the range [0, MessageInfo.numRequiredFields).
84 //
85 // If there are more than 64 required fields, requiredBit is 0.
86 requiredBit uint64
Damien Neilb0c26f12019-12-16 09:37:59 -080087}
88
// validationType classifies a field by the checks the validator applies to it.
type validationType uint8

const (
	// validationTypeOther is the zero value: no type-specific checks apply.
	validationTypeOther validationType = iota

	// validationTypeMessage is a message-valued field.
	validationTypeMessage

	// validationTypeGroup is a group-valued field.
	validationTypeGroup

	// validationTypeMap is a map field.
	validationTypeMap

	// validationTypeRepeatedVarint is a repeated field of a varint-encoded kind.
	validationTypeRepeatedVarint

	// validationTypeRepeatedFixed32 is a repeated field of a 32-bit fixed-width kind.
	validationTypeRepeatedFixed32

	// validationTypeRepeatedFixed64 is a repeated field of a 64-bit fixed-width kind.
	validationTypeRepeatedFixed64

	// validationTypeVarint is a non-repeated field of a varint-encoded kind.
	validationTypeVarint

	// validationTypeFixed32 is a non-repeated field of a 32-bit fixed-width kind.
	validationTypeFixed32

	// validationTypeFixed64 is a non-repeated field of a 64-bit fixed-width kind.
	validationTypeFixed64

	// validationTypeBytes is a length-delimited field whose contents are not
	// checked for UTF-8 validity.
	validationTypeBytes

	// validationTypeUTF8String is a string field whose contents must be valid UTF-8.
	validationTypeUTF8String
)
105
106func newFieldValidationInfo(mi *MessageInfo, si structInfo, fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
107 var vi validationInfo
108 switch {
109 case fd.ContainingOneof() != nil:
110 switch fd.Kind() {
111 case pref.MessageKind:
112 vi.typ = validationTypeMessage
113 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
114 vi.mi = getMessageInfo(ot.Field(0).Type)
115 }
116 case pref.GroupKind:
117 vi.typ = validationTypeGroup
118 if ot, ok := si.oneofWrappersByNumber[fd.Number()]; ok {
119 vi.mi = getMessageInfo(ot.Field(0).Type)
120 }
121 case pref.StringKind:
122 if strs.EnforceUTF8(fd) {
123 vi.typ = validationTypeUTF8String
124 }
125 }
126 default:
127 vi = newValidationInfo(fd, ft)
128 }
129 if fd.Cardinality() == pref.Required {
130 // Avoid overflow. The required field check is done with a 64-bit mask, with
131 // any message containing more than 64 required fields always reported as
132 // potentially uninitialized, so it is not important to get a precise count
133 // of the required fields past 64.
134 if mi.numRequiredFields < math.MaxUint8 {
135 mi.numRequiredFields++
Damien Neil170b2bf2020-01-24 16:42:42 -0800136 vi.requiredBit = 1 << (mi.numRequiredFields - 1)
Damien Neilb0c26f12019-12-16 09:37:59 -0800137 }
138 }
139 return vi
140}
141
142func newValidationInfo(fd pref.FieldDescriptor, ft reflect.Type) validationInfo {
143 var vi validationInfo
144 switch {
145 case fd.IsList():
146 switch fd.Kind() {
147 case pref.MessageKind:
148 vi.typ = validationTypeMessage
149 if ft.Kind() == reflect.Slice {
150 vi.mi = getMessageInfo(ft.Elem())
151 }
152 case pref.GroupKind:
153 vi.typ = validationTypeGroup
154 if ft.Kind() == reflect.Slice {
155 vi.mi = getMessageInfo(ft.Elem())
156 }
157 case pref.StringKind:
158 vi.typ = validationTypeBytes
159 if strs.EnforceUTF8(fd) {
160 vi.typ = validationTypeUTF8String
161 }
162 default:
163 switch wireTypes[fd.Kind()] {
164 case wire.VarintType:
165 vi.typ = validationTypeRepeatedVarint
166 case wire.Fixed32Type:
167 vi.typ = validationTypeRepeatedFixed32
168 case wire.Fixed64Type:
169 vi.typ = validationTypeRepeatedFixed64
170 }
171 }
172 case fd.IsMap():
173 vi.typ = validationTypeMap
174 switch fd.MapKey().Kind() {
175 case pref.StringKind:
176 if strs.EnforceUTF8(fd) {
177 vi.keyType = validationTypeUTF8String
178 }
179 }
180 switch fd.MapValue().Kind() {
181 case pref.MessageKind:
182 vi.valType = validationTypeMessage
183 if ft.Kind() == reflect.Map {
184 vi.mi = getMessageInfo(ft.Elem())
185 }
186 case pref.StringKind:
187 if strs.EnforceUTF8(fd) {
188 vi.valType = validationTypeUTF8String
189 }
190 }
191 default:
192 switch fd.Kind() {
193 case pref.MessageKind:
194 vi.typ = validationTypeMessage
195 if !fd.IsWeak() {
196 vi.mi = getMessageInfo(ft)
197 }
198 case pref.GroupKind:
199 vi.typ = validationTypeGroup
200 vi.mi = getMessageInfo(ft)
201 case pref.StringKind:
202 vi.typ = validationTypeBytes
203 if strs.EnforceUTF8(fd) {
204 vi.typ = validationTypeUTF8String
205 }
206 default:
207 switch wireTypes[fd.Kind()] {
208 case wire.VarintType:
209 vi.typ = validationTypeVarint
210 case wire.Fixed32Type:
211 vi.typ = validationTypeFixed32
212 case wire.Fixed64Type:
213 vi.typ = validationTypeFixed64
Damien Neil6635e7d2020-01-15 15:08:57 -0800214 case wire.BytesType:
215 vi.typ = validationTypeBytes
Damien Neilb0c26f12019-12-16 09:37:59 -0800216 }
217 }
218 }
219 return vi
220}
221
222func (mi *MessageInfo) validate(b []byte, groupTag wire.Number, opts unmarshalOptions) (result ValidationStatus) {
223 type validationState struct {
224 typ validationType
225 keyType, valType validationType
226 endGroup wire.Number
227 mi *MessageInfo
228 tail []byte
229 requiredMask uint64
230 }
231
232 // Pre-allocate some slots to avoid repeated slice reallocation.
233 states := make([]validationState, 0, 16)
234 states = append(states, validationState{
235 typ: validationTypeMessage,
236 mi: mi,
237 })
238 if groupTag > 0 {
239 states[0].typ = validationTypeGroup
240 states[0].endGroup = groupTag
241 }
242 initialized := true
243State:
244 for len(states) > 0 {
245 st := &states[len(states)-1]
246 if st.mi != nil {
247 st.mi.init()
Damien Neil0bf97b72020-01-24 09:00:33 -0800248 if flags.ProtoLegacy && st.mi.isMessageSet {
249 return ValidationUnknown
250 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800251 }
252 Field:
253 for len(b) > 0 {
254 num, wtyp, n := wire.ConsumeTag(b)
255 if n < 0 {
256 return ValidationInvalid
257 }
258 b = b[n:]
259 if num > wire.MaxValidNumber {
260 return ValidationInvalid
261 }
262 if wtyp == wire.EndGroupType {
263 if st.endGroup == num {
264 goto PopState
265 }
266 return ValidationInvalid
267 }
268 var vi validationInfo
269 switch st.typ {
270 case validationTypeMap:
271 switch num {
272 case 1:
273 vi.typ = st.keyType
274 case 2:
275 vi.typ = st.valType
276 vi.mi = st.mi
Damien Neil170b2bf2020-01-24 16:42:42 -0800277 vi.requiredBit = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800278 }
279 default:
280 var f *coderFieldInfo
281 if int(num) < len(st.mi.denseCoderFields) {
282 f = st.mi.denseCoderFields[num]
283 } else {
284 f = st.mi.coderFields[num]
285 }
286 if f != nil {
287 vi = f.validation
288 if vi.typ == validationTypeMessage && vi.mi == nil {
289 // Probable weak field.
290 //
291 // TODO: Consider storing the results of this lookup somewhere
292 // rather than recomputing it on every validation.
293 fd := st.mi.Desc.Fields().ByNumber(num)
294 if fd == nil || !fd.IsWeak() {
295 break
296 }
297 messageName := fd.Message().FullName()
298 messageType, err := preg.GlobalTypes.FindMessageByName(messageName)
299 switch err {
300 case nil:
301 vi.mi, _ = messageType.(*MessageInfo)
302 case preg.NotFound:
303 vi.typ = validationTypeBytes
304 default:
305 return ValidationUnknown
306 }
307 }
308 break
309 }
310 // Possible extension field.
311 //
312 // TODO: We should return ValidationUnknown when:
313 // 1. The resolver is not frozen. (More extensions may be added to it.)
314 // 2. The resolver returns preg.NotFound.
315 // In this case, a type added to the resolver in the future could cause
316 // unmarshaling to begin failing. Supporting this requires some way to
317 // determine if the resolver is frozen.
318 xt, err := opts.Resolver().FindExtensionByNumber(st.mi.Desc.FullName(), num)
319 if err != nil && err != preg.NotFound {
320 return ValidationUnknown
321 }
322 if err == nil {
323 vi = getExtensionFieldInfo(xt).validation
324 }
325 }
Damien Neil170b2bf2020-01-24 16:42:42 -0800326 if vi.requiredBit != 0 {
Damien Neilb0c26f12019-12-16 09:37:59 -0800327 // Check that the field has a compatible wire type.
328 // We only need to consider non-repeated field types,
329 // since repeated fields (and maps) can never be required.
330 ok := false
331 switch vi.typ {
332 case validationTypeVarint:
333 ok = wtyp == wire.VarintType
334 case validationTypeFixed32:
335 ok = wtyp == wire.Fixed32Type
336 case validationTypeFixed64:
337 ok = wtyp == wire.Fixed64Type
338 case validationTypeBytes, validationTypeUTF8String, validationTypeMessage, validationTypeGroup:
339 ok = wtyp == wire.BytesType
340 }
341 if ok {
Damien Neil170b2bf2020-01-24 16:42:42 -0800342 st.requiredMask |= vi.requiredBit
Damien Neilb0c26f12019-12-16 09:37:59 -0800343 }
344 }
345 switch vi.typ {
346 case validationTypeMessage, validationTypeMap:
347 if wtyp != wire.BytesType {
348 break
349 }
350 if vi.mi == nil && vi.typ == validationTypeMessage {
351 return ValidationUnknown
352 }
353 size, n := wire.ConsumeVarint(b)
354 if n < 0 {
355 return ValidationInvalid
356 }
357 b = b[n:]
358 if uint64(len(b)) < size {
359 return ValidationInvalid
360 }
361 states = append(states, validationState{
362 typ: vi.typ,
363 keyType: vi.keyType,
364 valType: vi.valType,
365 mi: vi.mi,
366 tail: b[size:],
367 })
368 b = b[:size]
369 continue State
370 case validationTypeGroup:
371 if wtyp != wire.StartGroupType {
372 break
373 }
374 if vi.mi == nil {
375 return ValidationUnknown
376 }
377 states = append(states, validationState{
378 typ: validationTypeGroup,
379 mi: vi.mi,
380 endGroup: num,
381 })
382 continue State
383 case validationTypeRepeatedVarint:
384 if wtyp != wire.BytesType {
385 break
386 }
387 // Packed field.
388 v, n := wire.ConsumeBytes(b)
389 if n < 0 {
390 return ValidationInvalid
391 }
392 b = b[n:]
393 for len(v) > 0 {
394 _, n := wire.ConsumeVarint(v)
395 if n < 0 {
396 return ValidationInvalid
397 }
398 v = v[n:]
399 }
400 continue Field
401 case validationTypeRepeatedFixed32:
402 if wtyp != wire.BytesType {
403 break
404 }
405 // Packed field.
406 v, n := wire.ConsumeBytes(b)
407 if n < 0 || len(v)%4 != 0 {
408 return ValidationInvalid
409 }
410 b = b[n:]
411 continue Field
412 case validationTypeRepeatedFixed64:
413 if wtyp != wire.BytesType {
414 break
415 }
416 // Packed field.
417 v, n := wire.ConsumeBytes(b)
418 if n < 0 || len(v)%8 != 0 {
419 return ValidationInvalid
420 }
421 b = b[n:]
422 continue Field
423 case validationTypeUTF8String:
424 if wtyp != wire.BytesType {
425 break
426 }
427 v, n := wire.ConsumeBytes(b)
428 if n < 0 || !utf8.Valid(v) {
429 return ValidationInvalid
430 }
431 b = b[n:]
432 continue Field
433 }
434 n = wire.ConsumeFieldValue(num, wtyp, b)
435 if n < 0 {
436 return ValidationInvalid
437 }
438 b = b[n:]
439 }
440 if st.endGroup != 0 {
441 return ValidationInvalid
442 }
443 if len(b) != 0 {
444 return ValidationInvalid
445 }
446 b = st.tail
447 PopState:
Damien Neil54a0a042020-01-08 17:53:16 -0800448 numRequiredFields := 0
Damien Neilb0c26f12019-12-16 09:37:59 -0800449 switch st.typ {
450 case validationTypeMessage, validationTypeGroup:
Damien Neil54a0a042020-01-08 17:53:16 -0800451 numRequiredFields = int(st.mi.numRequiredFields)
452 case validationTypeMap:
453 // If this is a map field with a message value that contains
454 // required fields, require that the value be present.
455 if st.mi != nil && st.mi.numRequiredFields > 0 {
456 numRequiredFields = 1
Damien Neilb0c26f12019-12-16 09:37:59 -0800457 }
458 }
Damien Neil54a0a042020-01-08 17:53:16 -0800459 // If there are more than 64 required fields, this check will
460 // always fail and we will report that the message is potentially
461 // uninitialized.
462 if numRequiredFields > 0 && bits.OnesCount64(st.requiredMask) != numRequiredFields {
463 initialized = false
464 }
Damien Neilb0c26f12019-12-16 09:37:59 -0800465 states = states[:len(states)-1]
466 }
467 if !initialized {
468 return ValidationValidMaybeUninitalized
469 }
470 return ValidationValidInitialized
471}