blob: a43efd1284839f46b452c2524715a3f001c16674 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package jsonpb
6
7import (
8 "encoding/base64"
9 "fmt"
10 "math"
11 "strconv"
12 "strings"
13
14 "github.com/golang/protobuf/v2/internal/encoding/json"
15 "github.com/golang/protobuf/v2/internal/errors"
16 "github.com/golang/protobuf/v2/internal/set"
17 "github.com/golang/protobuf/v2/proto"
18 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
19)
20
21// Unmarshal reads the given []byte into the given proto.Message.
22func Unmarshal(m proto.Message, b []byte) error {
23 return UnmarshalOptions{}.Unmarshal(m, b)
24}
25
// UnmarshalOptions configures the JSON format parser. It currently carries
// no fields; the zero value is ready to use.
type UnmarshalOptions struct{}
28
29// Unmarshal reads the given []byte and populates the given proto.Message using
30// options in UnmarshalOptions object. It will clear the message first before
31// setting the fields. If it returns an error, the given message may be
32// partially set.
33func (o UnmarshalOptions) Unmarshal(m proto.Message, b []byte) error {
34 mr := m.ProtoReflect()
35 // TODO: Determine if we would like to have an option for merging or only
36 // have merging behavior. We should at least be consistent with textproto
37 // marshaling.
38 resetMessage(mr)
39
40 dec := decoder{json.NewDecoder(b)}
41 var nerr errors.NonFatal
42 if err := dec.unmarshalMessage(mr); !nerr.Merge(err) {
43 return err
44 }
45
46 // Check for EOF.
47 val, err := dec.Read()
48 if err != nil {
49 return err
50 }
51 if val.Type() != json.EOF {
52 return unexpectedJSONError{val}
53 }
54 return nerr.E
55}
56
57// resetMessage clears all fields of given protoreflect.Message.
58func resetMessage(m pref.Message) {
59 knownFields := m.KnownFields()
60 knownFields.Range(func(num pref.FieldNumber, _ pref.Value) bool {
61 knownFields.Clear(num)
62 return true
63 })
64 unknownFields := m.UnknownFields()
65 unknownFields.Range(func(num pref.FieldNumber, _ pref.RawFields) bool {
66 unknownFields.Set(num, nil)
67 return true
68 })
69 extTypes := knownFields.ExtensionTypes()
70 extTypes.Range(func(xt pref.ExtensionType) bool {
71 extTypes.Remove(xt)
72 return true
73 })
74}
75
76// unexpectedJSONError is an error that contains the unexpected json.Value. This
77// is used by decoder methods to provide callers the read json.Value that it
78// did not expect.
79// TODO: Consider moving this to internal/encoding/json for consistency with
80// errors that package returns.
81type unexpectedJSONError struct {
82 value json.Value
83}
84
85func (e unexpectedJSONError) Error() string {
86 return newError("unexpected value %s", e.value).Error()
87}
88
89// newError returns an error object. If one of the values passed in is of
90// json.Value type, it produces an error with position info.
91func newError(f string, x ...interface{}) error {
92 var hasValue bool
93 var line, column int
94 for i := 0; i < len(x); i++ {
95 if val, ok := x[i].(json.Value); ok {
96 line, column = val.Position()
97 hasValue = true
98 break
99 }
100 }
101 e := errors.New(f, x...)
102 if hasValue {
103 return errors.New("(line %d:%d): %v", line, column, e)
104 }
105 return e
106}
107
108// decoder decodes JSON into protoreflect values.
109type decoder struct {
110 *json.Decoder
111}
112
113// unmarshalMessage unmarshals a message into the given protoreflect.Message.
114func (d decoder) unmarshalMessage(m pref.Message) error {
115 var nerr errors.NonFatal
116 var reqNums set.Ints
117 var seenNums set.Ints
118
119 msgType := m.Type()
120 knownFields := m.KnownFields()
121 fieldDescs := msgType.Fields()
122
123 jval, err := d.Read()
124 if !nerr.Merge(err) {
125 return err
126 }
127 if jval.Type() != json.StartObject {
128 return unexpectedJSONError{jval}
129 }
130
131Loop:
132 for {
133 // Read field name.
134 jval, err := d.Read()
135 if !nerr.Merge(err) {
136 return err
137 }
138 switch jval.Type() {
139 default:
140 return unexpectedJSONError{jval}
141 case json.EndObject:
142 break Loop
143 case json.Name:
144 // Continue below.
145 }
146
147 name, err := jval.Name()
148 if !nerr.Merge(err) {
149 return err
150 }
151
152 // Get the FieldDescriptor based on the field name. The name can either
153 // be the JSON name for the field or the proto field name.
154 fd := fieldDescs.ByJSONName(name)
155 if fd == nil {
156 fd = fieldDescs.ByName(pref.Name(name))
157 }
158
159 if fd == nil {
160 // Field is unknown.
161 // TODO: Provide option to ignore unknown message fields.
162 return newError("%v contains unknown field %s", msgType.FullName(), jval)
163 }
164
165 // Do not allow duplicate fields.
166 num := uint64(fd.Number())
167 if seenNums.Has(num) {
168 return newError("%v contains repeated field %s", msgType.FullName(), jval)
169 }
170 seenNums.Set(num)
171
172 // No need to set values for JSON null.
173 if d.Peek() == json.Null {
174 d.Read()
175 continue
176 }
177
178 if cardinality := fd.Cardinality(); cardinality == pref.Repeated {
179 // Map or list fields have cardinality of repeated.
180 if err := d.unmarshalRepeated(fd, knownFields); !nerr.Merge(err) {
181 return errors.New("%v|%q: %v", fd.FullName(), name, err)
182 }
183 } else {
184 // Required or optional fields.
185 if err := d.unmarshalSingular(fd, knownFields); !nerr.Merge(err) {
186 return errors.New("%v|%q: %v", fd.FullName(), name, err)
187 }
188 if cardinality == pref.Required {
189 reqNums.Set(num)
190 }
191 }
192 }
193
194 // Check for any missing required fields.
195 allReqNums := msgType.RequiredNumbers()
196 if reqNums.Len() != allReqNums.Len() {
197 for i := 0; i < allReqNums.Len(); i++ {
198 if num := allReqNums.Get(i); !reqNums.Has(uint64(num)) {
199 nerr.AppendRequiredNotSet(string(fieldDescs.ByNumber(num).FullName()))
200 }
201 }
202 }
203
204 return nerr.E
205}
206
207// unmarshalSingular unmarshals to the non-repeated field specified by the given
208// FieldDescriptor.
209func (d decoder) unmarshalSingular(fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
210 var val pref.Value
211 var err error
212 num := fd.Number()
213
214 switch fd.Kind() {
215 case pref.MessageKind, pref.GroupKind:
216 m := knownFields.NewMessage(num)
217 err = d.unmarshalMessage(m)
218 val = pref.ValueOf(m)
219 default:
220 val, err = d.unmarshalScalar(fd)
221 }
222
223 var nerr errors.NonFatal
224 if !nerr.Merge(err) {
225 return err
226 }
227 knownFields.Set(num, val)
228 return nerr.E
229}
230
231// unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
232// the given FieldDescriptor.
233func (d decoder) unmarshalScalar(fd pref.FieldDescriptor) (pref.Value, error) {
234 const b32 int = 32
235 const b64 int = 64
236
237 var nerr errors.NonFatal
238 jval, err := d.Read()
239 if !nerr.Merge(err) {
240 return pref.Value{}, err
241 }
242
243 kind := fd.Kind()
244 switch kind {
245 case pref.BoolKind:
246 return unmarshalBool(jval)
247
248 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
249 return unmarshalInt(jval, b32)
250
251 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
252 return unmarshalInt(jval, b64)
253
254 case pref.Uint32Kind, pref.Fixed32Kind:
255 return unmarshalUint(jval, b32)
256
257 case pref.Uint64Kind, pref.Fixed64Kind:
258 return unmarshalUint(jval, b64)
259
260 case pref.FloatKind:
261 return unmarshalFloat(jval, b32)
262
263 case pref.DoubleKind:
264 return unmarshalFloat(jval, b64)
265
266 case pref.StringKind:
267 pval, err := unmarshalString(jval)
268 if !nerr.Merge(err) {
269 return pval, err
270 }
271 return pval, nerr.E
272
273 case pref.BytesKind:
274 return unmarshalBytes(jval)
275
276 case pref.EnumKind:
277 return unmarshalEnum(jval, fd)
278 }
279
280 panic(fmt.Sprintf("invalid scalar kind %v", kind))
281}
282
283func unmarshalBool(jval json.Value) (pref.Value, error) {
284 if jval.Type() != json.Bool {
285 return pref.Value{}, unexpectedJSONError{jval}
286 }
287 b, err := jval.Bool()
288 return pref.ValueOf(b), err
289}
290
291func unmarshalInt(jval json.Value, bitSize int) (pref.Value, error) {
292 switch jval.Type() {
293 case json.Number:
294 return getInt(jval, bitSize)
295
296 case json.String:
297 // Use another decoder to decode number from string.
298 dec := decoder{json.NewDecoder([]byte(jval.String()))}
299 var nerr errors.NonFatal
300 jval, err := dec.Read()
301 if !nerr.Merge(err) {
302 return pref.Value{}, err
303 }
304 return getInt(jval, bitSize)
305 }
306 return pref.Value{}, unexpectedJSONError{jval}
307}
308
309func getInt(jval json.Value, bitSize int) (pref.Value, error) {
310 n, err := jval.Int(bitSize)
311 if err != nil {
312 return pref.Value{}, err
313 }
314 if bitSize == 32 {
315 return pref.ValueOf(int32(n)), nil
316 }
317 return pref.ValueOf(n), nil
318}
319
320func unmarshalUint(jval json.Value, bitSize int) (pref.Value, error) {
321 switch jval.Type() {
322 case json.Number:
323 return getUint(jval, bitSize)
324
325 case json.String:
326 // Use another decoder to decode number from string.
327 dec := decoder{json.NewDecoder([]byte(jval.String()))}
328 var nerr errors.NonFatal
329 jval, err := dec.Read()
330 if !nerr.Merge(err) {
331 return pref.Value{}, err
332 }
333 return getUint(jval, bitSize)
334 }
335 return pref.Value{}, unexpectedJSONError{jval}
336}
337
338func getUint(jval json.Value, bitSize int) (pref.Value, error) {
339 n, err := jval.Uint(bitSize)
340 if err != nil {
341 return pref.Value{}, err
342 }
343 if bitSize == 32 {
344 return pref.ValueOf(uint32(n)), nil
345 }
346 return pref.ValueOf(n), nil
347}
348
349func unmarshalFloat(jval json.Value, bitSize int) (pref.Value, error) {
350 switch jval.Type() {
351 case json.Number:
352 return getFloat(jval, bitSize)
353
354 case json.String:
355 s := jval.String()
356 switch s {
357 case "NaN":
358 if bitSize == 32 {
359 return pref.ValueOf(float32(math.NaN())), nil
360 }
361 return pref.ValueOf(math.NaN()), nil
362 case "Infinity":
363 if bitSize == 32 {
364 return pref.ValueOf(float32(math.Inf(+1))), nil
365 }
366 return pref.ValueOf(math.Inf(+1)), nil
367 case "-Infinity":
368 if bitSize == 32 {
369 return pref.ValueOf(float32(math.Inf(-1))), nil
370 }
371 return pref.ValueOf(math.Inf(-1)), nil
372 }
373 // Use another decoder to decode number from string.
374 dec := decoder{json.NewDecoder([]byte(s))}
375 var nerr errors.NonFatal
376 jval, err := dec.Read()
377 if !nerr.Merge(err) {
378 return pref.Value{}, err
379 }
380 return getFloat(jval, bitSize)
381 }
382 return pref.Value{}, unexpectedJSONError{jval}
383}
384
385func getFloat(jval json.Value, bitSize int) (pref.Value, error) {
386 n, err := jval.Float(bitSize)
387 if err != nil {
388 return pref.Value{}, err
389 }
390 if bitSize == 32 {
391 return pref.ValueOf(float32(n)), nil
392 }
393 return pref.ValueOf(n), nil
394}
395
396func unmarshalString(jval json.Value) (pref.Value, error) {
397 if jval.Type() != json.String {
398 return pref.Value{}, unexpectedJSONError{jval}
399 }
400 return pref.ValueOf(jval.String()), nil
401}
402
403func unmarshalBytes(jval json.Value) (pref.Value, error) {
404 if jval.Type() != json.String {
405 return pref.Value{}, unexpectedJSONError{jval}
406 }
407
408 s := jval.String()
409 enc := base64.StdEncoding
410 if strings.ContainsAny(s, "-_") {
411 enc = base64.URLEncoding
412 }
413 if len(s)%4 != 0 {
414 enc = enc.WithPadding(base64.NoPadding)
415 }
416 b, err := enc.DecodeString(s)
417 if err != nil {
418 return pref.Value{}, err
419 }
420 return pref.ValueOf(b), nil
421}
422
423func unmarshalEnum(jval json.Value, fd pref.FieldDescriptor) (pref.Value, error) {
424 switch jval.Type() {
425 case json.String:
426 // Lookup EnumNumber based on name.
427 s := jval.String()
428 if enumVal := fd.EnumType().Values().ByName(pref.Name(s)); enumVal != nil {
429 return pref.ValueOf(enumVal.Number()), nil
430 }
431 return pref.Value{}, newError("invalid enum value %q", jval)
432
433 case json.Number:
434 n, err := jval.Int(32)
435 if err != nil {
436 return pref.Value{}, err
437 }
438 return pref.ValueOf(pref.EnumNumber(n)), nil
439 }
440
441 return pref.Value{}, unexpectedJSONError{jval}
442}
443
444// unmarshalRepeated unmarshals into a repeated field.
445func (d decoder) unmarshalRepeated(fd pref.FieldDescriptor, knownFields pref.KnownFields) error {
446 var nerr errors.NonFatal
447 num := fd.Number()
448 val := knownFields.Get(num)
449 if !fd.IsMap() {
450 if err := d.unmarshalList(fd, val.List()); !nerr.Merge(err) {
451 return err
452 }
453 } else {
454 if err := d.unmarshalMap(fd, val.Map()); !nerr.Merge(err) {
455 return err
456 }
457 }
458 return nerr.E
459}
460
461// unmarshalList unmarshals into given protoreflect.List.
462func (d decoder) unmarshalList(fd pref.FieldDescriptor, list pref.List) error {
463 var nerr errors.NonFatal
464 jval, err := d.Read()
465 if !nerr.Merge(err) {
466 return err
467 }
468 if jval.Type() != json.StartArray {
469 return unexpectedJSONError{jval}
470 }
471
472 switch fd.Kind() {
473 case pref.MessageKind, pref.GroupKind:
474 for {
475 m := list.NewMessage()
476 err := d.unmarshalMessage(m)
477 if !nerr.Merge(err) {
478 if e, ok := err.(unexpectedJSONError); ok {
479 if e.value.Type() == json.EndArray {
480 // Done with list.
481 return nerr.E
482 }
483 }
484 return err
485 }
486 list.Append(pref.ValueOf(m))
487 }
488 default:
489 for {
490 val, err := d.unmarshalScalar(fd)
491 if !nerr.Merge(err) {
492 if e, ok := err.(unexpectedJSONError); ok {
493 if e.value.Type() == json.EndArray {
494 // Done with list.
495 return nerr.E
496 }
497 }
498 return err
499 }
500 list.Append(val)
501 }
502 }
503 return nerr.E
504}
505
506// unmarshalMap unmarshals into given protoreflect.Map.
507func (d decoder) unmarshalMap(fd pref.FieldDescriptor, mmap pref.Map) error {
508 var nerr errors.NonFatal
509
510 jval, err := d.Read()
511 if !nerr.Merge(err) {
512 return err
513 }
514 if jval.Type() != json.StartObject {
515 return unexpectedJSONError{jval}
516 }
517
518 fields := fd.MessageType().Fields()
519 keyDesc := fields.ByNumber(1)
520 valDesc := fields.ByNumber(2)
521
522 // Determine ahead whether map entry is a scalar type or a message type in
523 // order to call the appropriate unmarshalMapValue func inside the for loop
524 // below.
525 unmarshalMapValue := func() (pref.Value, error) {
526 return d.unmarshalScalar(valDesc)
527 }
528 switch valDesc.Kind() {
529 case pref.MessageKind, pref.GroupKind:
530 unmarshalMapValue = func() (pref.Value, error) {
531 m := mmap.NewMessage()
532 if err := d.unmarshalMessage(m); err != nil {
533 return pref.Value{}, err
534 }
535 return pref.ValueOf(m), nil
536 }
537 }
538
539Loop:
540 for {
541 // Read field name.
542 jval, err := d.Read()
543 if !nerr.Merge(err) {
544 return err
545 }
546 switch jval.Type() {
547 default:
548 return unexpectedJSONError{jval}
549 case json.EndObject:
550 break Loop
551 case json.Name:
552 // Continue.
553 }
554
555 name, err := jval.Name()
556 if !nerr.Merge(err) {
557 return err
558 }
559
560 // Unmarshal field name.
561 pkey, err := unmarshalMapKey(name, keyDesc)
562 if !nerr.Merge(err) {
563 return err
564 }
565
566 // Check for duplicate field name.
567 if mmap.Has(pkey) {
568 return newError("duplicate map key %q", jval)
569 }
570
571 // Read and unmarshal field value.
572 pval, err := unmarshalMapValue()
573 if !nerr.Merge(err) {
574 return err
575 }
576
577 mmap.Set(pkey, pval)
578 }
579
580 return nerr.E
581}
582
583// unmarshalMapKey converts given string into a protoreflect.MapKey. A map key type is any
584// integral or string type.
585func unmarshalMapKey(name string, fd pref.FieldDescriptor) (pref.MapKey, error) {
586 const b32 = 32
587 const b64 = 64
588 const base10 = 10
589
590 kind := fd.Kind()
591 switch kind {
592 case pref.StringKind:
593 return pref.ValueOf(name).MapKey(), nil
594
595 case pref.BoolKind:
596 switch name {
597 case "true":
598 return pref.ValueOf(true).MapKey(), nil
599 case "false":
600 return pref.ValueOf(false).MapKey(), nil
601 }
602 return pref.MapKey{}, errors.New("invalid value for boolean key %q", name)
603
604 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
605 n, err := strconv.ParseInt(name, base10, b32)
606 if err != nil {
607 return pref.MapKey{}, err
608 }
609 return pref.ValueOf(int32(n)).MapKey(), nil
610
611 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
612 n, err := strconv.ParseInt(name, base10, b64)
613 if err != nil {
614 return pref.MapKey{}, err
615 }
616 return pref.ValueOf(int64(n)).MapKey(), nil
617
618 case pref.Uint32Kind, pref.Fixed32Kind:
619 n, err := strconv.ParseUint(name, base10, b32)
620 if err != nil {
621 return pref.MapKey{}, err
622 }
623 return pref.ValueOf(uint32(n)).MapKey(), nil
624
625 case pref.Uint64Kind, pref.Fixed64Kind:
626 n, err := strconv.ParseUint(name, base10, b64)
627 if err != nil {
628 return pref.MapKey{}, err
629 }
630 return pref.ValueOf(uint64(n)).MapKey(), nil
631 }
632
633 panic(fmt.Sprintf("%s: invalid kind %s for map key", fd.FullName(), kind))
634}