blob: 4af5f162e0763e46a861ec9dcfdd0f01a659c369 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "errors"
9
10 "github.com/golang/protobuf/v2/internal/encoding/wire"
11 "github.com/golang/protobuf/v2/internal/pragma"
12 "github.com/golang/protobuf/v2/reflect/protoreflect"
13)
14
// UnmarshalOptions configures the unmarshaler.
//
// The zero value leaves every option disabled; it is what the package-level
// Unmarshal function uses.
//
// Example usage:
//
//	err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
type UnmarshalOptions struct {
	// If DiscardUnknown is set, unknown fields are ignored.
	DiscardUnknown bool

	// NOTE(review): presumably embedded to forbid positional (unkeyed)
	// composite literals so options can be added later without breaking
	// callers — confirm against internal/pragma.
	pragma.NoUnkeyedLiterals
}
25
26// Unmarshal parses the wire-format message in b and places the result in m.
27func Unmarshal(b []byte, m Message) error {
28 return UnmarshalOptions{}.Unmarshal(b, m)
29}
30
31// Unmarshal parses the wire-format message in b and places the result in m.
32func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
33 // TODO: Reset m?
34 return o.unmarshalMessage(b, m.ProtoReflect())
35}
36
37func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
38 messageType := m.Type()
39 fieldTypes := messageType.Fields()
40 knownFields := m.KnownFields()
41 unknownFields := m.UnknownFields()
42 for len(b) > 0 {
43 // Parse the tag (field number and wire type).
44 num, wtyp, tagLen := wire.ConsumeTag(b)
45 if tagLen < 0 {
46 return wire.ParseError(tagLen)
47 }
48
49 // Parse the field value.
50 fieldType := fieldTypes.ByNumber(num)
51 var err error
52 var valLen int
53 switch {
54 case fieldType == nil:
55 err = errUnknown
56 case fieldType.Cardinality() != protoreflect.Repeated:
57 valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
58 case !fieldType.IsMap():
59 valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
60 default:
61 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
62 }
63 if err == errUnknown {
64 valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
65 if valLen < 0 {
66 return wire.ParseError(valLen)
67 }
68 unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
69 } else if err != nil {
70 return err
71 }
72 b = b[tagLen+valLen:]
73 }
74 // TODO: required field checks
75 return nil
76}
77
78func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
79 v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
80 if err != nil {
81 return 0, err
82 }
83 switch field.Kind() {
84 case protoreflect.GroupKind, protoreflect.MessageKind:
85 // Messages are merged with any existing message value,
86 // unless the message is part of a oneof.
87 //
88 // TODO: C++ merges into oneofs, while v1 does not.
89 // Evaluate which behavior to pick.
90 var m protoreflect.Message
91 if knownFields.Has(num) && field.OneofType() == nil {
92 m = knownFields.Get(num).Message()
93 } else {
94 m = knownFields.NewMessage(num).ProtoReflect()
95 knownFields.Set(num, protoreflect.ValueOf(m))
96 }
97 if err := o.unmarshalMessage(v.Bytes(), m); err != nil {
98 return 0, err
99 }
100 default:
101 // Non-message scalars replace the previous value.
102 knownFields.Set(num, v)
103 }
104 return n, nil
105}
106
107func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
108 if wtyp != wire.BytesType {
109 return 0, errUnknown
110 }
111 b, n = wire.ConsumeBytes(b)
112 if n < 0 {
113 return 0, wire.ParseError(n)
114 }
115 var (
116 keyField = field.MessageType().Fields().ByNumber(1)
117 valField = field.MessageType().Fields().ByNumber(2)
118 key protoreflect.Value
119 val protoreflect.Value
120 haveKey bool
121 haveVal bool
122 )
123 switch valField.Kind() {
124 case protoreflect.GroupKind, protoreflect.MessageKind:
125 val = protoreflect.ValueOf(mapv.NewMessage().ProtoReflect())
126 }
127 // Map entries are represented as a two-element message with fields
128 // containing the key and value.
129 for len(b) > 0 {
130 num, wtyp, n := wire.ConsumeTag(b)
131 if n < 0 {
132 return 0, wire.ParseError(n)
133 }
134 b = b[n:]
135 err = errUnknown
136 switch num {
137 case 1:
138 key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
139 if err != nil {
140 break
141 }
142 haveKey = true
143 case 2:
144 var v protoreflect.Value
145 v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
146 if err != nil {
147 break
148 }
149 switch valField.Kind() {
150 case protoreflect.GroupKind, protoreflect.MessageKind:
151 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
152 return 0, err
153 }
154 default:
155 val = v
156 }
157 haveVal = true
158 }
159 if err == errUnknown {
160 n = wire.ConsumeFieldValue(num, wtyp, b)
161 if n < 0 {
162 return 0, wire.ParseError(n)
163 }
164 } else if err != nil {
165 return 0, err
166 }
167 b = b[n:]
168 }
169 // Every map entry should have entries for key and value, but this is not strictly required.
170 if !haveKey {
171 key = keyField.Default()
172 }
173 if !haveVal {
174 switch valField.Kind() {
175 case protoreflect.GroupKind, protoreflect.MessageKind:
176 // Trigger required field checks by unmarshaling an empty message.
177 if err := o.unmarshalMessage(nil, val.Message()); err != nil {
178 return 0, err
179 }
180 default:
181 val = valField.Default()
182 }
183 }
184 mapv.Set(key.MapKey(), val)
185 return n, nil
186}
187
var (
	// errUnknown is an internal sentinel telling the decoder to handle the
	// current field as unknown instead of decoding it into the message's
	// known fields. Exported functions never return it.
	errUnknown = errors.New("unknown")
)