blob: 4928ace736c38948e66d174420c7cde4b7a7b8e5 [file] [log] [blame]
Damien Neilba23aa52018-12-07 14:38:17 -08001// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "errors"
9
10 "github.com/golang/protobuf/v2/internal/encoding/wire"
11 "github.com/golang/protobuf/v2/internal/pragma"
12 "github.com/golang/protobuf/v2/reflect/protoreflect"
Damien Neil0d3e8cc2019-04-01 13:31:55 -070013 "github.com/golang/protobuf/v2/runtime/protoiface"
Damien Neilba23aa52018-12-07 14:38:17 -080014)
15
16// UnmarshalOptions configures the unmarshaler.
17//
18// Example usage:
19// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
20type UnmarshalOptions struct {
21 // If DiscardUnknown is set, unknown fields are ignored.
22 DiscardUnknown bool
23
Damien Neil0d3e8cc2019-04-01 13:31:55 -070024 // Reflection forces use of the reflection-based decoder, even for
25 // messages which implement fast-path deserialization.
26 Reflection bool
27
Damien Neilba23aa52018-12-07 14:38:17 -080028 pragma.NoUnkeyedLiterals
29}
30
Damien Neil0d3e8cc2019-04-01 13:31:55 -070031var _ = protoiface.UnmarshalOptions(UnmarshalOptions{})
32
Damien Neilba23aa52018-12-07 14:38:17 -080033// Unmarshal parses the wire-format message in b and places the result in m.
34func Unmarshal(b []byte, m Message) error {
35 return UnmarshalOptions{}.Unmarshal(b, m)
36}
37
38// Unmarshal parses the wire-format message in b and places the result in m.
39func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
40 // TODO: Reset m?
Damien Neil0d3e8cc2019-04-01 13:31:55 -070041 if err := o.unmarshalMessageFast(b, m); err != errInternalNoFast {
42 return err
43 }
Damien Neilba23aa52018-12-07 14:38:17 -080044 return o.unmarshalMessage(b, m.ProtoReflect())
45}
46
Damien Neil0d3e8cc2019-04-01 13:31:55 -070047func (o UnmarshalOptions) unmarshalMessageFast(b []byte, m Message) error {
48 if o.Reflection {
49 return errInternalNoFast
50 }
51 methods := protoMethods(m)
52 if methods == nil || methods.Unmarshal == nil {
53 return errInternalNoFast
54 }
55 return methods.Unmarshal(b, m, protoiface.UnmarshalOptions(o))
56}
57
Damien Neilba23aa52018-12-07 14:38:17 -080058func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
59 messageType := m.Type()
60 fieldTypes := messageType.Fields()
61 knownFields := m.KnownFields()
62 unknownFields := m.UnknownFields()
63 for len(b) > 0 {
64 // Parse the tag (field number and wire type).
65 num, wtyp, tagLen := wire.ConsumeTag(b)
66 if tagLen < 0 {
67 return wire.ParseError(tagLen)
68 }
69
70 // Parse the field value.
71 fieldType := fieldTypes.ByNumber(num)
Damien Neild068d302018-12-17 14:06:08 -080072 if fieldType == nil {
73 fieldType = knownFields.ExtensionTypes().ByNumber(num)
74 }
Damien Neilba23aa52018-12-07 14:38:17 -080075 var err error
76 var valLen int
77 switch {
78 case fieldType == nil:
79 err = errUnknown
80 case fieldType.Cardinality() != protoreflect.Repeated:
81 valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
82 case !fieldType.IsMap():
83 valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
84 default:
85 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
86 }
87 if err == errUnknown {
88 valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
89 if valLen < 0 {
90 return wire.ParseError(valLen)
91 }
92 unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
93 } else if err != nil {
94 return err
95 }
96 b = b[tagLen+valLen:]
97 }
98 // TODO: required field checks
99 return nil
100}
101
102func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
103 v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
104 if err != nil {
105 return 0, err
106 }
107 switch field.Kind() {
108 case protoreflect.GroupKind, protoreflect.MessageKind:
109 // Messages are merged with any existing message value,
110 // unless the message is part of a oneof.
111 //
112 // TODO: C++ merges into oneofs, while v1 does not.
113 // Evaluate which behavior to pick.
114 var m protoreflect.Message
115 if knownFields.Has(num) && field.OneofType() == nil {
116 m = knownFields.Get(num).Message()
117 } else {
Joe Tsai3bc7d6f2019-01-09 02:57:13 -0800118 m = knownFields.NewMessage(num)
Damien Neilba23aa52018-12-07 14:38:17 -0800119 knownFields.Set(num, protoreflect.ValueOf(m))
120 }
121 if err := o.unmarshalMessage(v.Bytes(), m); err != nil {
122 return 0, err
123 }
124 default:
125 // Non-message scalars replace the previous value.
126 knownFields.Set(num, v)
127 }
128 return n, nil
129}
130
131func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
132 if wtyp != wire.BytesType {
133 return 0, errUnknown
134 }
135 b, n = wire.ConsumeBytes(b)
136 if n < 0 {
137 return 0, wire.ParseError(n)
138 }
139 var (
140 keyField = field.MessageType().Fields().ByNumber(1)
141 valField = field.MessageType().Fields().ByNumber(2)
142 key protoreflect.Value
143 val protoreflect.Value
144 haveKey bool
145 haveVal bool
146 )
147 switch valField.Kind() {
148 case protoreflect.GroupKind, protoreflect.MessageKind:
Joe Tsai3bc7d6f2019-01-09 02:57:13 -0800149 val = protoreflect.ValueOf(mapv.NewMessage())
Damien Neilba23aa52018-12-07 14:38:17 -0800150 }
151 // Map entries are represented as a two-element message with fields
152 // containing the key and value.
153 for len(b) > 0 {
154 num, wtyp, n := wire.ConsumeTag(b)
155 if n < 0 {
156 return 0, wire.ParseError(n)
157 }
158 b = b[n:]
159 err = errUnknown
160 switch num {
161 case 1:
162 key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
163 if err != nil {
164 break
165 }
166 haveKey = true
167 case 2:
168 var v protoreflect.Value
169 v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
170 if err != nil {
171 break
172 }
173 switch valField.Kind() {
174 case protoreflect.GroupKind, protoreflect.MessageKind:
175 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
176 return 0, err
177 }
178 default:
179 val = v
180 }
181 haveVal = true
182 }
183 if err == errUnknown {
184 n = wire.ConsumeFieldValue(num, wtyp, b)
185 if n < 0 {
186 return 0, wire.ParseError(n)
187 }
188 } else if err != nil {
189 return 0, err
190 }
191 b = b[n:]
192 }
193 // Every map entry should have entries for key and value, but this is not strictly required.
194 if !haveKey {
195 key = keyField.Default()
196 }
197 if !haveVal {
198 switch valField.Kind() {
199 case protoreflect.GroupKind, protoreflect.MessageKind:
200 // Trigger required field checks by unmarshaling an empty message.
201 if err := o.unmarshalMessage(nil, val.Message()); err != nil {
202 return 0, err
203 }
204 default:
205 val = valField.Default()
206 }
207 }
208 mapv.Set(key.MapKey(), val)
209 return n, nil
210}
211
// errUnknown is an internal sentinel: a decoder returns it to signal that the
// field being parsed belongs in the message's unknown-field set instead of
// being decoded. It is never returned from an exported function.
var (
	errUnknown = errors.New("unknown")
)