blob: 35a67c25bfcef5784b22c6f4bb794cf6882b0a7f [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilf2427c02019-12-20 09:43:20 -08008 "errors"
Damien Neilc37adef2019-04-01 13:49:56 -07009 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -080010 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070011
Joe Tsaicd108d02020-02-14 18:08:02 -080012 "google.golang.org/protobuf/encoding/protowire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
Damien Neile91877d2019-06-27 10:54:42 -070016type mapInfo struct {
Damien Neil316febd2020-02-09 12:26:50 -080017 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs valueCoderFuncs
21 valFuncs valueCoderFuncs
22 keyZero pref.Value
23 keyKind pref.Kind
24 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070025}
26
Damien Neil316febd2020-02-09 12:26:50 -080027func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
Damien Neilc37adef2019-04-01 13:49:56 -070028 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
Joe Tsaicd108d02020-02-14 18:08:02 -080031 keyWiretag := protowire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := protowire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070033 keyFuncs := encoderFuncsForValue(keyField)
34 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080035 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070036
Damien Neile91877d2019-06-27 10:54:42 -070037 mapi := &mapInfo{
38 goType: ft,
39 keyWiretag: keyWiretag,
40 valWiretag: valWiretag,
41 keyFuncs: keyFuncs,
42 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070043 keyZero: keyField.Default(),
44 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080045 conv: conv,
46 }
47 if valField.Kind() == pref.MessageKind {
Damien Neil316febd2020-02-09 12:26:50 -080048 valueMessage = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070049 }
50
Damien Neil5322bdb2019-04-09 15:57:05 -070051 funcs = pointerCoderFuncs{
Damien Neil316febd2020-02-09 12:26:50 -080052 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
53 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070054 },
Damien Neil316febd2020-02-09 12:26:50 -080055 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
56 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070057 },
Joe Tsaicd108d02020-02-14 18:08:02 -080058 unmarshal: func(b []byte, p pointer, wtyp protowire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070059 mp := p.AsValueOf(ft)
60 if mp.Elem().IsNil() {
61 mp.Elem().Set(reflect.MakeMap(mapi.goType))
62 }
Damien Neil316febd2020-02-09 12:26:50 -080063 if f.mi == nil {
64 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -080065 } else {
Damien Neil316febd2020-02-09 12:26:50 -080066 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -080067 }
Damien Neile91877d2019-06-27 10:54:42 -070068 },
Damien Neilc37adef2019-04-01 13:49:56 -070069 }
Damien Neile8e88752020-02-11 11:25:16 -080070 switch valField.Kind() {
71 case pref.MessageKind:
72 funcs.merge = mergeMapOfMessage
73 case pref.BytesKind:
74 funcs.merge = mergeMapOfBytes
75 default:
76 funcs.merge = mergeMap
77 }
Damien Neil5322bdb2019-04-09 15:57:05 -070078 if valFuncs.isInit != nil {
Damien Neil316febd2020-02-09 12:26:50 -080079 funcs.isInit = func(p pointer, f *coderFieldInfo) error {
80 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
Damien Neil5322bdb2019-04-09 15:57:05 -070081 }
82 }
Damien Neil316febd2020-02-09 12:26:50 -080083 return valueMessage, funcs
Damien Neilc37adef2019-04-01 13:49:56 -070084}
85
// Encoded sizes, in bytes, of the tags for a map entry's key (field 1) and
// value (field 2). A tag is (field<<3 | wiretype) with a 3-bit wire type, so
// for both field numbers the varint-encoded tag always fits in one byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
90
Damien Neil316febd2020-02-09 12:26:50 -080091func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070092 if mapv.Len() == 0 {
93 return 0
Damien Neile91877d2019-06-27 10:54:42 -070094 }
Damien Neil68b81c32019-08-22 11:41:32 -070095 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080096 iter := mapRange(mapv)
97 for iter.Next() {
98 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
99 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
100 var valSize int
101 value := mapi.conv.valConv.PBValueOf(iter.Value())
Damien Neil316febd2020-02-09 12:26:50 -0800102 if f.mi == nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800103 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
104 } else {
105 p := pointerOfValue(iter.Value())
106 valSize += mapValTagSize
Joe Tsaicd108d02020-02-14 18:08:02 -0800107 valSize += protowire.SizeBytes(f.mi.sizePointer(p, opts))
Damien Neil3e42b662019-12-17 11:39:17 -0800108 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800109 n += f.tagsize + protowire.SizeBytes(keySize+valSize)
Damien Neil3e42b662019-12-17 11:39:17 -0800110 }
Damien Neil68b81c32019-08-22 11:41:32 -0700111 return n
112}
Damien Neile91877d2019-06-27 10:54:42 -0700113
Joe Tsaicd108d02020-02-14 18:08:02 -0800114func consumeMap(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
115 if wtyp != protowire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800116 return out, errUnknown
Damien Neile91877d2019-06-27 10:54:42 -0700117 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800118 b, n := protowire.ConsumeBytes(b)
Damien Neile91877d2019-06-27 10:54:42 -0700119 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800120 return out, protowire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700121 }
122 var (
123 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800124 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700125 )
Damien Neile91877d2019-06-27 10:54:42 -0700126 for len(b) > 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800127 num, wtyp, n := protowire.ConsumeTag(b)
Damien Neile91877d2019-06-27 10:54:42 -0700128 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800129 return out, protowire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700130 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800131 if num > protowire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800132 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800133 }
Damien Neile91877d2019-06-27 10:54:42 -0700134 b = b[n:]
135 err := errUnknown
136 switch num {
137 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700138 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800139 var o unmarshalOutput
140 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700141 if err != nil {
142 break
143 }
144 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800145 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700146 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700147 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800148 var o unmarshalOutput
149 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700150 if err != nil {
151 break
152 }
153 val = v
Damien Neilf0831e82020-01-21 14:25:12 -0800154 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700155 }
156 if err == errUnknown {
Joe Tsaicd108d02020-02-14 18:08:02 -0800157 n = protowire.ConsumeFieldValue(num, wtyp, b)
Damien Neile91877d2019-06-27 10:54:42 -0700158 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800159 return out, protowire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700160 }
161 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800162 return out, err
Damien Neile91877d2019-06-27 10:54:42 -0700163 }
164 b = b[n:]
165 }
Damien Neil3e42b662019-12-17 11:39:17 -0800166 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neilf0831e82020-01-21 14:25:12 -0800167 out.n = n
168 return out, nil
Damien Neile91877d2019-06-27 10:54:42 -0700169}
170
Joe Tsaicd108d02020-02-14 18:08:02 -0800171func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp protowire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
172 if wtyp != protowire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800173 return out, errUnknown
Damien Neil3e42b662019-12-17 11:39:17 -0800174 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800175 b, n := protowire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800176 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800177 return out, protowire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800178 }
179 var (
180 key = mapi.keyZero
Damien Neil316febd2020-02-09 12:26:50 -0800181 val = reflect.New(f.mi.GoReflectType.Elem())
Damien Neil3e42b662019-12-17 11:39:17 -0800182 )
183 for len(b) > 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800184 num, wtyp, n := protowire.ConsumeTag(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800185 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800186 return out, protowire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800187 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800188 if num > protowire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800189 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800190 }
Damien Neil3e42b662019-12-17 11:39:17 -0800191 b = b[n:]
192 err := errUnknown
193 switch num {
194 case 1:
195 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800196 var o unmarshalOutput
197 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800198 if err != nil {
199 break
200 }
201 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800202 n = o.n
Damien Neil3e42b662019-12-17 11:39:17 -0800203 case 2:
Joe Tsaicd108d02020-02-14 18:08:02 -0800204 if wtyp != protowire.BytesType {
Damien Neil3e42b662019-12-17 11:39:17 -0800205 break
206 }
Damien Neil7e690b52019-12-18 09:35:01 -0800207 var v []byte
Joe Tsaicd108d02020-02-14 18:08:02 -0800208 v, n = protowire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800209 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800210 return out, protowire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800211 }
Damien Neilc600d6c2020-01-21 15:00:33 -0800212 var o unmarshalOutput
Damien Neil316febd2020-02-09 12:26:50 -0800213 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
Damien Neilc600d6c2020-01-21 15:00:33 -0800214 if o.initialized {
215 // Consider this map item initialized so long as we see
216 // an initialized value.
217 out.initialized = true
218 }
Damien Neil3e42b662019-12-17 11:39:17 -0800219 }
220 if err == errUnknown {
Joe Tsaicd108d02020-02-14 18:08:02 -0800221 n = protowire.ConsumeFieldValue(num, wtyp, b)
Damien Neil3e42b662019-12-17 11:39:17 -0800222 if n < 0 {
Joe Tsaicd108d02020-02-14 18:08:02 -0800223 return out, protowire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800224 }
225 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800226 return out, err
Damien Neil3e42b662019-12-17 11:39:17 -0800227 }
228 b = b[n:]
229 }
230 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
Damien Neilf0831e82020-01-21 14:25:12 -0800231 out.n = n
232 return out, nil
Damien Neil3e42b662019-12-17 11:39:17 -0800233}
234
Damien Neil316febd2020-02-09 12:26:50 -0800235func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
236 if f.mi == nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800237 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
238 val := mapi.conv.valConv.PBValueOf(valrv)
239 size := 0
240 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
241 size += mapi.valFuncs.size(val, mapValTagSize, opts)
Joe Tsaicd108d02020-02-14 18:08:02 -0800242 b = protowire.AppendVarint(b, uint64(size))
Damien Neil3e42b662019-12-17 11:39:17 -0800243 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
244 if err != nil {
245 return nil, err
246 }
247 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
248 } else {
249 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
250 val := pointerOfValue(valrv)
Damien Neil316febd2020-02-09 12:26:50 -0800251 valSize := f.mi.sizePointer(val, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800252 size := 0
253 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
Joe Tsaicd108d02020-02-14 18:08:02 -0800254 size += mapValTagSize + protowire.SizeBytes(valSize)
255 b = protowire.AppendVarint(b, uint64(size))
Damien Neil3e42b662019-12-17 11:39:17 -0800256 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
257 if err != nil {
258 return nil, err
259 }
Joe Tsaicd108d02020-02-14 18:08:02 -0800260 b = protowire.AppendVarint(b, mapi.valWiretag)
261 b = protowire.AppendVarint(b, uint64(valSize))
Damien Neil316febd2020-02-09 12:26:50 -0800262 return f.mi.marshalAppendPointer(b, val, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800263 }
264}
265
Damien Neil316febd2020-02-09 12:26:50 -0800266func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700267 if mapv.Len() == 0 {
268 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700269 }
Damien Neil68b81c32019-08-22 11:41:32 -0700270 if opts.Deterministic() {
Damien Neil316febd2020-02-09 12:26:50 -0800271 return appendMapDeterministic(b, mapv, mapi, f, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700272 }
Damien Neil3e42b662019-12-17 11:39:17 -0800273 iter := mapRange(mapv)
274 for iter.Next() {
275 var err error
Joe Tsaicd108d02020-02-14 18:08:02 -0800276 b = protowire.AppendVarint(b, f.wiretag)
Damien Neil316febd2020-02-09 12:26:50 -0800277 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800278 if err != nil {
279 return b, err
280 }
281 }
282 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700283}
284
Damien Neil316febd2020-02-09 12:26:50 -0800285func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -0800286 keys := mapv.MapKeys()
287 sort.Slice(keys, func(i, j int) bool {
288 switch keys[i].Kind() {
289 case reflect.Bool:
290 return !keys[i].Bool() && keys[j].Bool()
291 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
292 return keys[i].Int() < keys[j].Int()
293 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
294 return keys[i].Uint() < keys[j].Uint()
295 case reflect.Float32, reflect.Float64:
296 return keys[i].Float() < keys[j].Float()
297 case reflect.String:
298 return keys[i].String() < keys[j].String()
299 default:
300 panic("invalid kind: " + keys[i].Kind().String())
301 }
Damien Neil68b81c32019-08-22 11:41:32 -0700302 })
Damien Neil3e42b662019-12-17 11:39:17 -0800303 for _, key := range keys {
304 var err error
Joe Tsaicd108d02020-02-14 18:08:02 -0800305 b = protowire.AppendVarint(b, f.wiretag)
Damien Neil316febd2020-02-09 12:26:50 -0800306 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800307 if err != nil {
308 return b, err
309 }
310 }
311 return b, nil
312}
313
Damien Neil316febd2020-02-09 12:26:50 -0800314func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
315 if mi := f.mi; mi != nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800316 mi.init()
317 if !mi.needsInitCheck {
318 return nil
319 }
320 iter := mapRange(mapv)
321 for iter.Next() {
322 val := pointerOfValue(iter.Value())
Joe Tsaif26a9e72020-02-20 10:05:37 -0800323 if err := mi.checkInitializedPointer(val); err != nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800324 return err
325 }
326 }
327 } else {
328 iter := mapRange(mapv)
329 for iter.Next() {
330 val := mapi.conv.valConv.PBValueOf(iter.Value())
331 if err := mapi.valFuncs.isInit(val); err != nil {
332 return err
333 }
334 }
335 }
336 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700337}
Damien Neile8e88752020-02-11 11:25:16 -0800338
339func mergeMap(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
340 dstm := dst.AsValueOf(f.ft).Elem()
341 srcm := src.AsValueOf(f.ft).Elem()
342 if srcm.Len() == 0 {
343 return
344 }
345 if dstm.IsNil() {
346 dstm.Set(reflect.MakeMap(f.ft))
347 }
348 iter := mapRange(srcm)
349 for iter.Next() {
350 dstm.SetMapIndex(iter.Key(), iter.Value())
351 }
352}
353
354func mergeMapOfBytes(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
355 dstm := dst.AsValueOf(f.ft).Elem()
356 srcm := src.AsValueOf(f.ft).Elem()
357 if srcm.Len() == 0 {
358 return
359 }
360 if dstm.IsNil() {
361 dstm.Set(reflect.MakeMap(f.ft))
362 }
363 iter := mapRange(srcm)
364 for iter.Next() {
365 dstm.SetMapIndex(iter.Key(), reflect.ValueOf(append(emptyBuf[:], iter.Value().Bytes()...)))
366 }
367}
368
369func mergeMapOfMessage(dst, src pointer, f *coderFieldInfo, opts mergeOptions) {
370 dstm := dst.AsValueOf(f.ft).Elem()
371 srcm := src.AsValueOf(f.ft).Elem()
372 if srcm.Len() == 0 {
373 return
374 }
375 if dstm.IsNil() {
376 dstm.Set(reflect.MakeMap(f.ft))
377 }
378 iter := mapRange(srcm)
379 for iter.Next() {
380 val := reflect.New(f.ft.Elem().Elem())
381 if f.mi != nil {
382 f.mi.mergePointer(pointerOfValue(val), pointerOfValue(iter.Value()), opts)
383 } else {
384 opts.Merge(asMessage(val), asMessage(iter.Value()))
385 }
386 dstm.SetMapIndex(iter.Key(), val)
387 }
388}