blob: 05d1ecd15a9afe72e0a6d6df992f7e7333c790b6 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilf2427c02019-12-20 09:43:20 -08008 "errors"
Damien Neilc37adef2019-04-01 13:49:56 -07009 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -080010 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070011
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
Damien Neile91877d2019-06-27 10:54:42 -070016type mapInfo struct {
Damien Neil3e42b662019-12-17 11:39:17 -080017 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs valueCoderFuncs
21 valFuncs valueCoderFuncs
22 keyZero pref.Value
23 keyKind pref.Kind
24 valMessageInfo *MessageInfo
25 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070026}
27
Damien Neilc37adef2019-04-01 13:49:56 -070028func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
29 // TODO: Consider generating specialized map coders.
30 keyField := fd.MapKey()
31 valField := fd.MapValue()
32 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
33 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070034 keyFuncs := encoderFuncsForValue(keyField)
35 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080036 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070037
Damien Neile91877d2019-06-27 10:54:42 -070038 mapi := &mapInfo{
39 goType: ft,
40 keyWiretag: keyWiretag,
41 valWiretag: valWiretag,
42 keyFuncs: keyFuncs,
43 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070044 keyZero: keyField.Default(),
45 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080046 conv: conv,
47 }
48 if valField.Kind() == pref.MessageKind {
49 mapi.valMessageInfo = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070050 }
51
Damien Neil5322bdb2019-04-09 15:57:05 -070052 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070053 size: func(p pointer, tagsize int, opts marshalOptions) int {
Damien Neil3e42b662019-12-17 11:39:17 -080054 return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070055 },
56 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -080057 return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070058 },
Damien Neile91877d2019-06-27 10:54:42 -070059 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070060 mp := p.AsValueOf(ft)
61 if mp.Elem().IsNil() {
62 mp.Elem().Set(reflect.MakeMap(mapi.goType))
63 }
Damien Neil3e42b662019-12-17 11:39:17 -080064 if mapi.valMessageInfo == nil {
65 return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
66 } else {
67 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
68 }
Damien Neile91877d2019-06-27 10:54:42 -070069 },
Damien Neilc37adef2019-04-01 13:49:56 -070070 }
Damien Neil5322bdb2019-04-09 15:57:05 -070071 if valFuncs.isInit != nil {
72 funcs.isInit = func(p pointer) error {
Damien Neil3e42b662019-12-17 11:39:17 -080073 return isInitMap(p.AsValueOf(ft).Elem(), mapi)
Damien Neil5322bdb2019-04-09 15:57:05 -070074 }
75 }
76 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070077}
78
// Encoded sizes of the field tags inside a map entry. Tags for fields 1 and 2
// are at most (2<<3 | 5) = 21 for any wire type, which fits in one varint byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
83
Damien Neil3e42b662019-12-17 11:39:17 -080084func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070085 if mapv.Len() == 0 {
86 return 0
Damien Neile91877d2019-06-27 10:54:42 -070087 }
Damien Neil68b81c32019-08-22 11:41:32 -070088 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080089 iter := mapRange(mapv)
90 for iter.Next() {
91 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
92 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
93 var valSize int
94 value := mapi.conv.valConv.PBValueOf(iter.Value())
95 if mapi.valMessageInfo == nil {
96 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
97 } else {
98 p := pointerOfValue(iter.Value())
99 valSize += mapValTagSize
100 valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
101 }
102 n += tagsize + wire.SizeBytes(keySize+valSize)
103 }
Damien Neil68b81c32019-08-22 11:41:32 -0700104 return n
105}
Damien Neile91877d2019-06-27 10:54:42 -0700106
Damien Neil3e42b662019-12-17 11:39:17 -0800107func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
Damien Neile91877d2019-06-27 10:54:42 -0700108 if wtyp != wire.BytesType {
109 return 0, errUnknown
110 }
111 b, n := wire.ConsumeBytes(b)
112 if n < 0 {
113 return 0, wire.ParseError(n)
114 }
115 var (
116 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800117 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700118 )
Damien Neile91877d2019-06-27 10:54:42 -0700119 for len(b) > 0 {
120 num, wtyp, n := wire.ConsumeTag(b)
121 if n < 0 {
122 return 0, wire.ParseError(n)
123 }
Damien Neilf2427c02019-12-20 09:43:20 -0800124 if num > wire.MaxValidNumber {
125 return 0, errors.New("invalid field number")
126 }
Damien Neile91877d2019-06-27 10:54:42 -0700127 b = b[n:]
128 err := errUnknown
129 switch num {
130 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700131 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700132 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
133 if err != nil {
134 break
135 }
136 key = v
137 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700138 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700139 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
140 if err != nil {
141 break
142 }
143 val = v
144 }
145 if err == errUnknown {
146 n = wire.ConsumeFieldValue(num, wtyp, b)
147 if n < 0 {
148 return 0, wire.ParseError(n)
149 }
150 } else if err != nil {
151 return 0, err
152 }
153 b = b[n:]
154 }
Damien Neil3e42b662019-12-17 11:39:17 -0800155 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neile91877d2019-06-27 10:54:42 -0700156 return n, nil
157}
158
Damien Neil3e42b662019-12-17 11:39:17 -0800159func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
160 if wtyp != wire.BytesType {
161 return 0, errUnknown
162 }
163 b, n := wire.ConsumeBytes(b)
164 if n < 0 {
165 return 0, wire.ParseError(n)
166 }
167 var (
168 key = mapi.keyZero
169 val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
170 )
171 for len(b) > 0 {
172 num, wtyp, n := wire.ConsumeTag(b)
173 if n < 0 {
174 return 0, wire.ParseError(n)
175 }
Damien Neilf2427c02019-12-20 09:43:20 -0800176 if num > wire.MaxValidNumber {
177 return 0, errors.New("invalid field number")
178 }
Damien Neil3e42b662019-12-17 11:39:17 -0800179 b = b[n:]
180 err := errUnknown
181 switch num {
182 case 1:
183 var v pref.Value
184 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
185 if err != nil {
186 break
187 }
188 key = v
189 case 2:
190 if wtyp != wire.BytesType {
191 break
192 }
Damien Neil7e690b52019-12-18 09:35:01 -0800193 var v []byte
194 v, n = wire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800195 if n < 0 {
196 return 0, wire.ParseError(n)
197 }
Damien Neil7e690b52019-12-18 09:35:01 -0800198 _, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800199 }
200 if err == errUnknown {
201 n = wire.ConsumeFieldValue(num, wtyp, b)
202 if n < 0 {
203 return 0, wire.ParseError(n)
204 }
205 } else if err != nil {
206 return 0, err
207 }
208 b = b[n:]
209 }
210 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
211 return n, nil
212}
213
214func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
215 if mapi.valMessageInfo == nil {
216 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
217 val := mapi.conv.valConv.PBValueOf(valrv)
218 size := 0
219 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
220 size += mapi.valFuncs.size(val, mapValTagSize, opts)
221 b = wire.AppendVarint(b, uint64(size))
222 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
223 if err != nil {
224 return nil, err
225 }
226 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
227 } else {
228 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
229 val := pointerOfValue(valrv)
230 valSize := mapi.valMessageInfo.sizePointer(val, opts)
231 size := 0
232 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
233 size += mapValTagSize + wire.SizeBytes(valSize)
234 b = wire.AppendVarint(b, uint64(size))
235 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
236 if err != nil {
237 return nil, err
238 }
239 b = wire.AppendVarint(b, mapi.valWiretag)
240 b = wire.AppendVarint(b, uint64(valSize))
241 return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
242 }
243}
244
245func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700246 if mapv.Len() == 0 {
247 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700248 }
Damien Neil68b81c32019-08-22 11:41:32 -0700249 if opts.Deterministic() {
Damien Neil3e42b662019-12-17 11:39:17 -0800250 return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700251 }
Damien Neil3e42b662019-12-17 11:39:17 -0800252 iter := mapRange(mapv)
253 for iter.Next() {
254 var err error
255 b = wire.AppendVarint(b, wiretag)
256 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
257 if err != nil {
258 return b, err
259 }
260 }
261 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700262}
263
Damien Neil3e42b662019-12-17 11:39:17 -0800264func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
265 keys := mapv.MapKeys()
266 sort.Slice(keys, func(i, j int) bool {
267 switch keys[i].Kind() {
268 case reflect.Bool:
269 return !keys[i].Bool() && keys[j].Bool()
270 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
271 return keys[i].Int() < keys[j].Int()
272 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
273 return keys[i].Uint() < keys[j].Uint()
274 case reflect.Float32, reflect.Float64:
275 return keys[i].Float() < keys[j].Float()
276 case reflect.String:
277 return keys[i].String() < keys[j].String()
278 default:
279 panic("invalid kind: " + keys[i].Kind().String())
280 }
Damien Neil68b81c32019-08-22 11:41:32 -0700281 })
Damien Neil3e42b662019-12-17 11:39:17 -0800282 for _, key := range keys {
283 var err error
284 b = wire.AppendVarint(b, wiretag)
285 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
286 if err != nil {
287 return b, err
288 }
289 }
290 return b, nil
291}
292
293func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
294 if mi := mapi.valMessageInfo; mi != nil {
295 mi.init()
296 if !mi.needsInitCheck {
297 return nil
298 }
299 iter := mapRange(mapv)
300 for iter.Next() {
301 val := pointerOfValue(iter.Value())
302 if err := mi.isInitializedPointer(val); err != nil {
303 return err
304 }
305 }
306 } else {
307 iter := mapRange(mapv)
308 for iter.Next() {
309 val := mapi.conv.valConv.PBValueOf(iter.Value())
310 if err := mapi.valFuncs.isInit(val); err != nil {
311 return err
312 }
313 }
314 }
315 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700316}