blob: dc41ea699dbdae9745fd4dc2dcb8b8e67af43fc6 [file] [log] [blame]
Damien Neilba23aa52018-12-07 14:38:17 -08001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style.
3// license that can be found in the LICENSE file.
4
5package proto
6
7import (
8 "errors"
9
10 "github.com/golang/protobuf/v2/internal/encoding/wire"
11 "github.com/golang/protobuf/v2/internal/pragma"
12 "github.com/golang/protobuf/v2/reflect/protoreflect"
13)
14
15// UnmarshalOptions configures the unmarshaler.
16//
17// Example usage:
18// err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
19type UnmarshalOptions struct {
20 // If DiscardUnknown is set, unknown fields are ignored.
21 DiscardUnknown bool
22
23 pragma.NoUnkeyedLiterals
24}
25
26// Unmarshal parses the wire-format message in b and places the result in m.
27func Unmarshal(b []byte, m Message) error {
28 return UnmarshalOptions{}.Unmarshal(b, m)
29}
30
31// Unmarshal parses the wire-format message in b and places the result in m.
32func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
33 // TODO: Reset m?
34 return o.unmarshalMessage(b, m.ProtoReflect())
35}
36
37func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
38 messageType := m.Type()
39 fieldTypes := messageType.Fields()
40 knownFields := m.KnownFields()
41 unknownFields := m.UnknownFields()
42 for len(b) > 0 {
43 // Parse the tag (field number and wire type).
44 num, wtyp, tagLen := wire.ConsumeTag(b)
45 if tagLen < 0 {
46 return wire.ParseError(tagLen)
47 }
48
49 // Parse the field value.
50 fieldType := fieldTypes.ByNumber(num)
Damien Neild068d302018-12-17 14:06:08 -080051 if fieldType == nil {
52 fieldType = knownFields.ExtensionTypes().ByNumber(num)
53 }
Damien Neilba23aa52018-12-07 14:38:17 -080054 var err error
55 var valLen int
56 switch {
57 case fieldType == nil:
58 err = errUnknown
59 case fieldType.Cardinality() != protoreflect.Repeated:
60 valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
61 case !fieldType.IsMap():
62 valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
63 default:
64 valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
65 }
66 if err == errUnknown {
67 valLen = wire.ConsumeFieldValue(num, wtyp, b[tagLen:])
68 if valLen < 0 {
69 return wire.ParseError(valLen)
70 }
71 unknownFields.Set(num, append(unknownFields.Get(num), b[:tagLen+valLen]...))
72 } else if err != nil {
73 return err
74 }
75 b = b[tagLen+valLen:]
76 }
77 // TODO: required field checks
78 return nil
79}
80
81func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
82 v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
83 if err != nil {
84 return 0, err
85 }
86 switch field.Kind() {
87 case protoreflect.GroupKind, protoreflect.MessageKind:
88 // Messages are merged with any existing message value,
89 // unless the message is part of a oneof.
90 //
91 // TODO: C++ merges into oneofs, while v1 does not.
92 // Evaluate which behavior to pick.
93 var m protoreflect.Message
94 if knownFields.Has(num) && field.OneofType() == nil {
95 m = knownFields.Get(num).Message()
96 } else {
97 m = knownFields.NewMessage(num).ProtoReflect()
98 knownFields.Set(num, protoreflect.ValueOf(m))
99 }
100 if err := o.unmarshalMessage(v.Bytes(), m); err != nil {
101 return 0, err
102 }
103 default:
104 // Non-message scalars replace the previous value.
105 knownFields.Set(num, v)
106 }
107 return n, nil
108}
109
110func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
111 if wtyp != wire.BytesType {
112 return 0, errUnknown
113 }
114 b, n = wire.ConsumeBytes(b)
115 if n < 0 {
116 return 0, wire.ParseError(n)
117 }
118 var (
119 keyField = field.MessageType().Fields().ByNumber(1)
120 valField = field.MessageType().Fields().ByNumber(2)
121 key protoreflect.Value
122 val protoreflect.Value
123 haveKey bool
124 haveVal bool
125 )
126 switch valField.Kind() {
127 case protoreflect.GroupKind, protoreflect.MessageKind:
128 val = protoreflect.ValueOf(mapv.NewMessage().ProtoReflect())
129 }
130 // Map entries are represented as a two-element message with fields
131 // containing the key and value.
132 for len(b) > 0 {
133 num, wtyp, n := wire.ConsumeTag(b)
134 if n < 0 {
135 return 0, wire.ParseError(n)
136 }
137 b = b[n:]
138 err = errUnknown
139 switch num {
140 case 1:
141 key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
142 if err != nil {
143 break
144 }
145 haveKey = true
146 case 2:
147 var v protoreflect.Value
148 v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
149 if err != nil {
150 break
151 }
152 switch valField.Kind() {
153 case protoreflect.GroupKind, protoreflect.MessageKind:
154 if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
155 return 0, err
156 }
157 default:
158 val = v
159 }
160 haveVal = true
161 }
162 if err == errUnknown {
163 n = wire.ConsumeFieldValue(num, wtyp, b)
164 if n < 0 {
165 return 0, wire.ParseError(n)
166 }
167 } else if err != nil {
168 return 0, err
169 }
170 b = b[n:]
171 }
172 // Every map entry should have entries for key and value, but this is not strictly required.
173 if !haveKey {
174 key = keyField.Default()
175 }
176 if !haveVal {
177 switch valField.Kind() {
178 case protoreflect.GroupKind, protoreflect.MessageKind:
179 // Trigger required field checks by unmarshaling an empty message.
180 if err := o.unmarshalMessage(nil, val.Message()); err != nil {
181 return 0, err
182 }
183 default:
184 val = valField.Default()
185 }
186 }
187 mapv.Set(key.MapKey(), val)
188 return n, nil
189}
190
// errUnknown is a package-internal sentinel. Field decoders return it to say
// "treat this field as unknown": the caller skips over the field's value and,
// where applicable, records it in the message's unknown field set. No
// exported function ever returns it.
var errUnknown = errors.New("unknown")