blob: b69ee1aa10435cde33e23703903e5748f806c71f [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilf2427c02019-12-20 09:43:20 -08008 "errors"
Damien Neilc37adef2019-04-01 13:49:56 -07009 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -080010 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070011
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
Damien Neile91877d2019-06-27 10:54:42 -070016type mapInfo struct {
Damien Neil3e42b662019-12-17 11:39:17 -080017 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs valueCoderFuncs
21 valFuncs valueCoderFuncs
22 keyZero pref.Value
23 keyKind pref.Kind
24 valMessageInfo *MessageInfo
25 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070026}
27
Damien Neilc37adef2019-04-01 13:49:56 -070028func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
29 // TODO: Consider generating specialized map coders.
30 keyField := fd.MapKey()
31 valField := fd.MapValue()
32 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
33 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070034 keyFuncs := encoderFuncsForValue(keyField)
35 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080036 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070037
Damien Neile91877d2019-06-27 10:54:42 -070038 mapi := &mapInfo{
39 goType: ft,
40 keyWiretag: keyWiretag,
41 valWiretag: valWiretag,
42 keyFuncs: keyFuncs,
43 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070044 keyZero: keyField.Default(),
45 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080046 conv: conv,
47 }
48 if valField.Kind() == pref.MessageKind {
49 mapi.valMessageInfo = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070050 }
51
Damien Neil5322bdb2019-04-09 15:57:05 -070052 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070053 size: func(p pointer, tagsize int, opts marshalOptions) int {
Damien Neil3e42b662019-12-17 11:39:17 -080054 return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070055 },
56 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -080057 return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070058 },
Damien Neilf0831e82020-01-21 14:25:12 -080059 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070060 mp := p.AsValueOf(ft)
61 if mp.Elem().IsNil() {
62 mp.Elem().Set(reflect.MakeMap(mapi.goType))
63 }
Damien Neil3e42b662019-12-17 11:39:17 -080064 if mapi.valMessageInfo == nil {
65 return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
66 } else {
67 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
68 }
Damien Neile91877d2019-06-27 10:54:42 -070069 },
Damien Neilc37adef2019-04-01 13:49:56 -070070 }
Damien Neil5322bdb2019-04-09 15:57:05 -070071 if valFuncs.isInit != nil {
72 funcs.isInit = func(p pointer) error {
Damien Neil3e42b662019-12-17 11:39:17 -080073 return isInitMap(p.AsValueOf(ft).Elem(), mapi)
Damien Neil5322bdb2019-04-09 15:57:05 -070074 }
75 }
76 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070077}
78
// Encoded sizes of the tags for the key (field 1) and value (field 2) of a
// map-entry message. Both field numbers are small enough that the tag fits in
// a single varint byte for every wire type.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
83
Damien Neil3e42b662019-12-17 11:39:17 -080084func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070085 if mapv.Len() == 0 {
86 return 0
Damien Neile91877d2019-06-27 10:54:42 -070087 }
Damien Neil68b81c32019-08-22 11:41:32 -070088 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080089 iter := mapRange(mapv)
90 for iter.Next() {
91 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
92 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
93 var valSize int
94 value := mapi.conv.valConv.PBValueOf(iter.Value())
95 if mapi.valMessageInfo == nil {
96 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
97 } else {
98 p := pointerOfValue(iter.Value())
99 valSize += mapValTagSize
100 valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
101 }
102 n += tagsize + wire.SizeBytes(keySize+valSize)
103 }
Damien Neil68b81c32019-08-22 11:41:32 -0700104 return n
105}
Damien Neile91877d2019-06-27 10:54:42 -0700106
Damien Neilf0831e82020-01-21 14:25:12 -0800107func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
Damien Neile91877d2019-06-27 10:54:42 -0700108 if wtyp != wire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800109 return out, errUnknown
Damien Neile91877d2019-06-27 10:54:42 -0700110 }
111 b, n := wire.ConsumeBytes(b)
112 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800113 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700114 }
115 var (
116 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800117 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700118 )
Damien Neile91877d2019-06-27 10:54:42 -0700119 for len(b) > 0 {
120 num, wtyp, n := wire.ConsumeTag(b)
121 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800122 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700123 }
Damien Neilf2427c02019-12-20 09:43:20 -0800124 if num > wire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800125 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800126 }
Damien Neile91877d2019-06-27 10:54:42 -0700127 b = b[n:]
128 err := errUnknown
129 switch num {
130 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700131 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800132 var o unmarshalOutput
133 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700134 if err != nil {
135 break
136 }
137 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800138 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700139 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700140 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800141 var o unmarshalOutput
142 v, o, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
Damien Neile91877d2019-06-27 10:54:42 -0700143 if err != nil {
144 break
145 }
146 val = v
Damien Neilf0831e82020-01-21 14:25:12 -0800147 n = o.n
Damien Neile91877d2019-06-27 10:54:42 -0700148 }
149 if err == errUnknown {
150 n = wire.ConsumeFieldValue(num, wtyp, b)
151 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800152 return out, wire.ParseError(n)
Damien Neile91877d2019-06-27 10:54:42 -0700153 }
154 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800155 return out, err
Damien Neile91877d2019-06-27 10:54:42 -0700156 }
157 b = b[n:]
158 }
Damien Neil3e42b662019-12-17 11:39:17 -0800159 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neilf0831e82020-01-21 14:25:12 -0800160 out.n = n
161 return out, nil
Damien Neile91877d2019-06-27 10:54:42 -0700162}
163
Damien Neilf0831e82020-01-21 14:25:12 -0800164func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
Damien Neil3e42b662019-12-17 11:39:17 -0800165 if wtyp != wire.BytesType {
Damien Neilf0831e82020-01-21 14:25:12 -0800166 return out, errUnknown
Damien Neil3e42b662019-12-17 11:39:17 -0800167 }
168 b, n := wire.ConsumeBytes(b)
169 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800170 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800171 }
172 var (
173 key = mapi.keyZero
174 val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
175 )
176 for len(b) > 0 {
177 num, wtyp, n := wire.ConsumeTag(b)
178 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800179 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800180 }
Damien Neilf2427c02019-12-20 09:43:20 -0800181 if num > wire.MaxValidNumber {
Damien Neilf0831e82020-01-21 14:25:12 -0800182 return out, errors.New("invalid field number")
Damien Neilf2427c02019-12-20 09:43:20 -0800183 }
Damien Neil3e42b662019-12-17 11:39:17 -0800184 b = b[n:]
185 err := errUnknown
186 switch num {
187 case 1:
188 var v pref.Value
Damien Neilf0831e82020-01-21 14:25:12 -0800189 var o unmarshalOutput
190 v, o, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800191 if err != nil {
192 break
193 }
194 key = v
Damien Neilf0831e82020-01-21 14:25:12 -0800195 n = o.n
Damien Neil3e42b662019-12-17 11:39:17 -0800196 case 2:
197 if wtyp != wire.BytesType {
198 break
199 }
Damien Neil7e690b52019-12-18 09:35:01 -0800200 var v []byte
201 v, n = wire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800202 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800203 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800204 }
Damien Neil7e690b52019-12-18 09:35:01 -0800205 _, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800206 }
207 if err == errUnknown {
208 n = wire.ConsumeFieldValue(num, wtyp, b)
209 if n < 0 {
Damien Neilf0831e82020-01-21 14:25:12 -0800210 return out, wire.ParseError(n)
Damien Neil3e42b662019-12-17 11:39:17 -0800211 }
212 } else if err != nil {
Damien Neilf0831e82020-01-21 14:25:12 -0800213 return out, err
Damien Neil3e42b662019-12-17 11:39:17 -0800214 }
215 b = b[n:]
216 }
217 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
Damien Neilf0831e82020-01-21 14:25:12 -0800218 out.n = n
219 return out, nil
Damien Neil3e42b662019-12-17 11:39:17 -0800220}
221
222func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
223 if mapi.valMessageInfo == nil {
224 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
225 val := mapi.conv.valConv.PBValueOf(valrv)
226 size := 0
227 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
228 size += mapi.valFuncs.size(val, mapValTagSize, opts)
229 b = wire.AppendVarint(b, uint64(size))
230 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
231 if err != nil {
232 return nil, err
233 }
234 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
235 } else {
236 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
237 val := pointerOfValue(valrv)
238 valSize := mapi.valMessageInfo.sizePointer(val, opts)
239 size := 0
240 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
241 size += mapValTagSize + wire.SizeBytes(valSize)
242 b = wire.AppendVarint(b, uint64(size))
243 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
244 if err != nil {
245 return nil, err
246 }
247 b = wire.AppendVarint(b, mapi.valWiretag)
248 b = wire.AppendVarint(b, uint64(valSize))
249 return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
250 }
251}
252
253func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700254 if mapv.Len() == 0 {
255 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700256 }
Damien Neil68b81c32019-08-22 11:41:32 -0700257 if opts.Deterministic() {
Damien Neil3e42b662019-12-17 11:39:17 -0800258 return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700259 }
Damien Neil3e42b662019-12-17 11:39:17 -0800260 iter := mapRange(mapv)
261 for iter.Next() {
262 var err error
263 b = wire.AppendVarint(b, wiretag)
264 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
265 if err != nil {
266 return b, err
267 }
268 }
269 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700270}
271
Damien Neil3e42b662019-12-17 11:39:17 -0800272func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
273 keys := mapv.MapKeys()
274 sort.Slice(keys, func(i, j int) bool {
275 switch keys[i].Kind() {
276 case reflect.Bool:
277 return !keys[i].Bool() && keys[j].Bool()
278 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
279 return keys[i].Int() < keys[j].Int()
280 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
281 return keys[i].Uint() < keys[j].Uint()
282 case reflect.Float32, reflect.Float64:
283 return keys[i].Float() < keys[j].Float()
284 case reflect.String:
285 return keys[i].String() < keys[j].String()
286 default:
287 panic("invalid kind: " + keys[i].Kind().String())
288 }
Damien Neil68b81c32019-08-22 11:41:32 -0700289 })
Damien Neil3e42b662019-12-17 11:39:17 -0800290 for _, key := range keys {
291 var err error
292 b = wire.AppendVarint(b, wiretag)
293 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
294 if err != nil {
295 return b, err
296 }
297 }
298 return b, nil
299}
300
301func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
302 if mi := mapi.valMessageInfo; mi != nil {
303 mi.init()
304 if !mi.needsInitCheck {
305 return nil
306 }
307 iter := mapRange(mapv)
308 for iter.Next() {
309 val := pointerOfValue(iter.Value())
310 if err := mi.isInitializedPointer(val); err != nil {
311 return err
312 }
313 }
314 } else {
315 iter := mapRange(mapv)
316 for iter.Next() {
317 val := mapi.conv.valConv.PBValueOf(iter.Value())
318 if err := mapi.valFuncs.isInit(val); err != nil {
319 return err
320 }
321 }
322 }
323 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700324}