blob: b3195839cbbbf868ebc410db9935d496433fe038 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilf2427c02019-12-20 09:43:20 -08008 "errors"
Damien Neilc37adef2019-04-01 13:49:56 -07009 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -080010 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070011
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
Damien Neile91877d2019-06-27 10:54:42 -070016type mapInfo struct {
Damien Neil316febd2020-02-09 12:26:50 -080017 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs valueCoderFuncs
21 valFuncs valueCoderFuncs
22 keyZero pref.Value
23 keyKind pref.Kind
24 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070025}
26
Damien Neil316febd2020-02-09 12:26:50 -080027func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
Damien Neilc37adef2019-04-01 13:49:56 -070028 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
31 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070033 keyFuncs := encoderFuncsForValue(keyField)
34 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080035 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070036
Damien Neile91877d2019-06-27 10:54:42 -070037 mapi := &mapInfo{
38 goType: ft,
39 keyWiretag: keyWiretag,
40 valWiretag: valWiretag,
41 keyFuncs: keyFuncs,
42 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070043 keyZero: keyField.Default(),
44 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080045 conv: conv,
46 }
47 if valField.Kind() == pref.MessageKind {
Damien Neil316febd2020-02-09 12:26:50 -080048 valueMessage = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070049 }
50
Damien Neil5322bdb2019-04-09 15:57:05 -070051 funcs = pointerCoderFuncs{
Damien Neil316febd2020-02-09 12:26:50 -080052 size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
53 return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070054 },
Damien Neil316febd2020-02-09 12:26:50 -080055 marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
56 return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070057 },
Damien Neil316febd2020-02-09 12:26:50 -080058 unmarshal: func(b []byte, p pointer, wtyp wire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070059 mp := p.AsValueOf(ft)
60 if mp.Elem().IsNil() {
61 mp.Elem().Set(reflect.MakeMap(mapi.goType))
62 }
Damien Neil316febd2020-02-09 12:26:50 -080063 if f.mi == nil {
64 return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -080065 } else {
Damien Neil316febd2020-02-09 12:26:50 -080066 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -080067 }
Damien Neile91877d2019-06-27 10:54:42 -070068 },
Damien Neilc37adef2019-04-01 13:49:56 -070069 }
Damien Neil5322bdb2019-04-09 15:57:05 -070070 if valFuncs.isInit != nil {
Damien Neil316febd2020-02-09 12:26:50 -080071 funcs.isInit = func(p pointer, f *coderFieldInfo) error {
72 return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
Damien Neil5322bdb2019-04-09 15:57:05 -070073 }
74 }
Damien Neil316febd2020-02-09 12:26:50 -080075 return valueMessage, funcs
Damien Neilc37adef2019-04-01 13:49:56 -070076}
77
// Encoded sizes, in bytes, of the tags for a map entry's key (field 1) and
// value (field 2). A tag is the varint (field<<3 | wiretype); for field
// numbers 1 and 2 it is always below 0x80, so it takes a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
82
Damien Neil316febd2020-02-09 12:26:50 -080083func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070084 if mapv.Len() == 0 {
85 return 0
Damien Neile91877d2019-06-27 10:54:42 -070086 }
Damien Neil68b81c32019-08-22 11:41:32 -070087 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080088 iter := mapRange(mapv)
89 for iter.Next() {
90 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
91 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
92 var valSize int
93 value := mapi.conv.valConv.PBValueOf(iter.Value())
Damien Neil316febd2020-02-09 12:26:50 -080094 if f.mi == nil {
Damien Neil3e42b662019-12-17 11:39:17 -080095 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
96 } else {
97 p := pointerOfValue(iter.Value())
98 valSize += mapValTagSize
Damien Neil316febd2020-02-09 12:26:50 -080099 valSize += wire.SizeBytes(f.mi.sizePointer(p, opts))
Damien Neil3e42b662019-12-17 11:39:17 -0800100 }
Damien Neil316febd2020-02-09 12:26:50 -0800101 n += f.tagsize + wire.SizeBytes(keySize+valSize)
Damien Neil3e42b662019-12-17 11:39:17 -0800102 }
Damien Neil68b81c32019-08-22 11:41:32 -0700103 return n
104}
Damien Neile91877d2019-06-27 10:54:42 -0700105
Damien Neil316febd2020-02-09 12:26:50 -0800106func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
Damien Neile91877d2019-06-27 10:54:42 -0700107 if wtyp != wire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800108 return out, errUnknown
Damien Neile91877d2019-06-27 10:54:42 -0700109 }
110 b, n := wire.ConsumeBytes(b)
111 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800112 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700113 }
114 var (
115 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800116 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700117 )
Damien Neile91877d2019-06-27 10:54:42 -0700118 for len(b) > 0 {
119 num, wtyp, n := wire.ConsumeTag(b)
120 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800121 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700122 }
Damien Neilf2427c02019-12-20 09:43:20 -0800123 if num > wire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800124 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800125 }
Damien Neile91877d2019-06-27 10:54:42 -0700126 b = b[n:]
127 err := errUnknown
128 switch num {
129 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700130 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800131 var o unmarshalOutput
132 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700133 if err != nil {
134 break
135 }
136 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800137 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700138 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700139 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800140 var o unmarshalOutput
141 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700142 if err != nil {
143 break
144 }
145 val = v
Damien Neilf0831e82020-01-21 14:25:12 -0800146 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700147 }
148 if err == errUnknown {
149 n = wire.ConsumeFieldValue(num, wtyp, b)
150 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800151 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700152 }
153 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800154 return out, err
Damien Neile91877d2019-06-27 10:54:42 -0700155 }
156 b = b[n:]
157 }
Damien Neil3e42b662019-12-17 11:39:17 -0800158 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neilf0831e82020-01-21 14:25:12 -0800159 out.n = n
160 return out, nil
Damien Neile91877d2019-06-27 10:54:42 -0700161}
162
Damien Neil316febd2020-02-09 12:26:50 -0800163func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
Damien Neil3e42b662019-12-17 11:39:17 -0800164 if wtyp != wire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800165 return out, errUnknown
Damien Neil3e42b662019-12-17 11:39:17 -0800166 }
167 b, n := wire.ConsumeBytes(b)
168 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800169 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800170 }
171 var (
172 key = mapi.keyZero
Damien Neil316febd2020-02-09 12:26:50 -0800173 val = reflect.New(f.mi.GoReflectType.Elem())
Damien Neil3e42b662019-12-17 11:39:17 -0800174 )
175 for len(b) > 0 {
176 num, wtyp, n := wire.ConsumeTag(b)
177 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800178 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800179 }
Damien Neilf2427c02019-12-20 09:43:20 -0800180 if num > wire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800181 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800182 }
Damien Neil3e42b662019-12-17 11:39:17 -0800183 b = b[n:]
184 err := errUnknown
185 switch num {
186 case 1:
187 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800188 var o unmarshalOutput
189 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800190 if err != nil {
191 break
192 }
193 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800194 n = o.n
Damien Neil3e42b662019-12-17 11:39:17 -0800195 case 2:
196 if wtyp != wire.BytesType {
197 break
198 }
Damien Neil7e690b52019-12-18 09:35:01 -0800199 var v []byte
200 v, n = wire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800201 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800202 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800203 }
Damien Neilc600d6c2020-01-21 15:00:33 -0800204 var o unmarshalOutput
Damien Neil316febd2020-02-09 12:26:50 -0800205 o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
Damien Neilc600d6c2020-01-21 15:00:33 -0800206 if o.initialized {
207 // Consider this map item initialized so long as we see
208 // an initialized value.
209 out.initialized = true
210 }
Damien Neil3e42b662019-12-17 11:39:17 -0800211 }
212 if err == errUnknown {
213 n = wire.ConsumeFieldValue(num, wtyp, b)
214 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800215 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800216 }
217 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800218 return out, err
Damien Neil3e42b662019-12-17 11:39:17 -0800219 }
220 b = b[n:]
221 }
222 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
Damien Neilf0831e82020-01-21 14:25:12 -0800223 out.n = n
224 return out, nil
Damien Neil3e42b662019-12-17 11:39:17 -0800225}
226
Damien Neil316febd2020-02-09 12:26:50 -0800227func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
228 if f.mi == nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800229 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
230 val := mapi.conv.valConv.PBValueOf(valrv)
231 size := 0
232 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
233 size += mapi.valFuncs.size(val, mapValTagSize, opts)
234 b = wire.AppendVarint(b, uint64(size))
235 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
236 if err != nil {
237 return nil, err
238 }
239 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
240 } else {
241 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
242 val := pointerOfValue(valrv)
Damien Neil316febd2020-02-09 12:26:50 -0800243 valSize := f.mi.sizePointer(val, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800244 size := 0
245 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
246 size += mapValTagSize + wire.SizeBytes(valSize)
247 b = wire.AppendVarint(b, uint64(size))
248 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
249 if err != nil {
250 return nil, err
251 }
252 b = wire.AppendVarint(b, mapi.valWiretag)
253 b = wire.AppendVarint(b, uint64(valSize))
Damien Neil316febd2020-02-09 12:26:50 -0800254 return f.mi.marshalAppendPointer(b, val, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800255 }
256}
257
Damien Neil316febd2020-02-09 12:26:50 -0800258func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700259 if mapv.Len() == 0 {
260 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700261 }
Damien Neil68b81c32019-08-22 11:41:32 -0700262 if opts.Deterministic() {
Damien Neil316febd2020-02-09 12:26:50 -0800263 return appendMapDeterministic(b, mapv, mapi, f, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700264 }
Damien Neil3e42b662019-12-17 11:39:17 -0800265 iter := mapRange(mapv)
266 for iter.Next() {
267 var err error
Damien Neil316febd2020-02-09 12:26:50 -0800268 b = wire.AppendVarint(b, f.wiretag)
269 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800270 if err != nil {
271 return b, err
272 }
273 }
274 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700275}
276
Damien Neil316febd2020-02-09 12:26:50 -0800277func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -0800278 keys := mapv.MapKeys()
279 sort.Slice(keys, func(i, j int) bool {
280 switch keys[i].Kind() {
281 case reflect.Bool:
282 return !keys[i].Bool() && keys[j].Bool()
283 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
284 return keys[i].Int() < keys[j].Int()
285 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
286 return keys[i].Uint() < keys[j].Uint()
287 case reflect.Float32, reflect.Float64:
288 return keys[i].Float() < keys[j].Float()
289 case reflect.String:
290 return keys[i].String() < keys[j].String()
291 default:
292 panic("invalid kind: " + keys[i].Kind().String())
293 }
Damien Neil68b81c32019-08-22 11:41:32 -0700294 })
Damien Neil3e42b662019-12-17 11:39:17 -0800295 for _, key := range keys {
296 var err error
Damien Neil316febd2020-02-09 12:26:50 -0800297 b = wire.AppendVarint(b, f.wiretag)
298 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800299 if err != nil {
300 return b, err
301 }
302 }
303 return b, nil
304}
305
Damien Neil316febd2020-02-09 12:26:50 -0800306func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
307 if mi := f.mi; mi != nil {
Damien Neil3e42b662019-12-17 11:39:17 -0800308 mi.init()
309 if !mi.needsInitCheck {
310 return nil
311 }
312 iter := mapRange(mapv)
313 for iter.Next() {
314 val := pointerOfValue(iter.Value())
315 if err := mi.isInitializedPointer(val); err != nil {
316 return err
317 }
318 }
319 } else {
320 iter := mapRange(mapv)
321 for iter.Next() {
322 val := mapi.conv.valConv.PBValueOf(iter.Value())
323 if err := mapi.valFuncs.isInit(val); err != nil {
324 return err
325 }
326 }
327 }
328 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700329}