blob: 94a1bc49a6d028371e13d02afa378a4a9e228932 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilc37adef2019-04-01 13:49:56 -07008 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -08009 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070010
11 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070012 pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
Damien Neile91877d2019-06-27 10:54:42 -070015type mapInfo struct {
Damien Neil3e42b662019-12-17 11:39:17 -080016 goType reflect.Type
17 keyWiretag uint64
18 valWiretag uint64
19 keyFuncs valueCoderFuncs
20 valFuncs valueCoderFuncs
21 keyZero pref.Value
22 keyKind pref.Kind
23 valMessageInfo *MessageInfo
24 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070025}
26
Damien Neilc37adef2019-04-01 13:49:56 -070027func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
28 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
31 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070033 keyFuncs := encoderFuncsForValue(keyField)
34 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080035 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070036
Damien Neile91877d2019-06-27 10:54:42 -070037 mapi := &mapInfo{
38 goType: ft,
39 keyWiretag: keyWiretag,
40 valWiretag: valWiretag,
41 keyFuncs: keyFuncs,
42 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070043 keyZero: keyField.Default(),
44 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080045 conv: conv,
46 }
47 if valField.Kind() == pref.MessageKind {
48 mapi.valMessageInfo = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070049 }
50
Damien Neil5322bdb2019-04-09 15:57:05 -070051 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070052 size: func(p pointer, tagsize int, opts marshalOptions) int {
Damien Neil3e42b662019-12-17 11:39:17 -080053 return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070054 },
55 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -080056 return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070057 },
Damien Neile91877d2019-06-27 10:54:42 -070058 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070059 mp := p.AsValueOf(ft)
60 if mp.Elem().IsNil() {
61 mp.Elem().Set(reflect.MakeMap(mapi.goType))
62 }
Damien Neil3e42b662019-12-17 11:39:17 -080063 if mapi.valMessageInfo == nil {
64 return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
65 } else {
66 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
67 }
Damien Neile91877d2019-06-27 10:54:42 -070068 },
Damien Neilc37adef2019-04-01 13:49:56 -070069 }
Damien Neil5322bdb2019-04-09 15:57:05 -070070 if valFuncs.isInit != nil {
71 funcs.isInit = func(p pointer) error {
Damien Neil3e42b662019-12-17 11:39:17 -080072 return isInitMap(p.AsValueOf(ft).Elem(), mapi)
Damien Neil5322bdb2019-04-09 15:57:05 -070073 }
74 }
75 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070076}
77
// Encoded sizes of the tags inside a map entry. The key is field 1 and the
// value is field 2. A tag is the varint (num<<3)|wiretype, and protobuf wire
// types are at most 5, so both tags are below 0x80 and occupy one byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
82
Damien Neil3e42b662019-12-17 11:39:17 -080083func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070084 if mapv.Len() == 0 {
85 return 0
Damien Neile91877d2019-06-27 10:54:42 -070086 }
Damien Neil68b81c32019-08-22 11:41:32 -070087 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080088 iter := mapRange(mapv)
89 for iter.Next() {
90 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
91 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
92 var valSize int
93 value := mapi.conv.valConv.PBValueOf(iter.Value())
94 if mapi.valMessageInfo == nil {
95 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
96 } else {
97 p := pointerOfValue(iter.Value())
98 valSize += mapValTagSize
99 valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
100 }
101 n += tagsize + wire.SizeBytes(keySize+valSize)
102 }
Damien Neil68b81c32019-08-22 11:41:32 -0700103 return n
104}
Damien Neile91877d2019-06-27 10:54:42 -0700105
Damien Neil3e42b662019-12-17 11:39:17 -0800106func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
Damien Neile91877d2019-06-27 10:54:42 -0700107 if wtyp != wire.BytesType {
108 return 0, errUnknown
109 }
110 b, n := wire.ConsumeBytes(b)
111 if n < 0 {
112 return 0, wire.ParseError(n)
113 }
114 var (
115 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800116 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700117 )
Damien Neile91877d2019-06-27 10:54:42 -0700118 for len(b) > 0 {
119 num, wtyp, n := wire.ConsumeTag(b)
120 if n < 0 {
121 return 0, wire.ParseError(n)
122 }
123 b = b[n:]
124 err := errUnknown
125 switch num {
126 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700127 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700128 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
129 if err != nil {
130 break
131 }
132 key = v
133 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700134 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700135 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
136 if err != nil {
137 break
138 }
139 val = v
140 }
141 if err == errUnknown {
142 n = wire.ConsumeFieldValue(num, wtyp, b)
143 if n < 0 {
144 return 0, wire.ParseError(n)
145 }
146 } else if err != nil {
147 return 0, err
148 }
149 b = b[n:]
150 }
Damien Neil3e42b662019-12-17 11:39:17 -0800151 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neile91877d2019-06-27 10:54:42 -0700152 return n, nil
153}
154
Damien Neil3e42b662019-12-17 11:39:17 -0800155func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
156 if wtyp != wire.BytesType {
157 return 0, errUnknown
158 }
159 b, n := wire.ConsumeBytes(b)
160 if n < 0 {
161 return 0, wire.ParseError(n)
162 }
163 var (
164 key = mapi.keyZero
165 val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
166 )
167 for len(b) > 0 {
168 num, wtyp, n := wire.ConsumeTag(b)
169 if n < 0 {
170 return 0, wire.ParseError(n)
171 }
172 b = b[n:]
173 err := errUnknown
174 switch num {
175 case 1:
176 var v pref.Value
177 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
178 if err != nil {
179 break
180 }
181 key = v
182 case 2:
183 if wtyp != wire.BytesType {
184 break
185 }
Damien Neil7e690b52019-12-18 09:35:01 -0800186 var v []byte
187 v, n = wire.ConsumeBytes(b)
Damien Neil3e42b662019-12-17 11:39:17 -0800188 if n < 0 {
189 return 0, wire.ParseError(n)
190 }
Damien Neil7e690b52019-12-18 09:35:01 -0800191 _, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
Damien Neil3e42b662019-12-17 11:39:17 -0800192 }
193 if err == errUnknown {
194 n = wire.ConsumeFieldValue(num, wtyp, b)
195 if n < 0 {
196 return 0, wire.ParseError(n)
197 }
198 } else if err != nil {
199 return 0, err
200 }
201 b = b[n:]
202 }
203 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
204 return n, nil
205}
206
207func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
208 if mapi.valMessageInfo == nil {
209 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
210 val := mapi.conv.valConv.PBValueOf(valrv)
211 size := 0
212 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
213 size += mapi.valFuncs.size(val, mapValTagSize, opts)
214 b = wire.AppendVarint(b, uint64(size))
215 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
216 if err != nil {
217 return nil, err
218 }
219 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
220 } else {
221 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
222 val := pointerOfValue(valrv)
223 valSize := mapi.valMessageInfo.sizePointer(val, opts)
224 size := 0
225 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
226 size += mapValTagSize + wire.SizeBytes(valSize)
227 b = wire.AppendVarint(b, uint64(size))
228 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
229 if err != nil {
230 return nil, err
231 }
232 b = wire.AppendVarint(b, mapi.valWiretag)
233 b = wire.AppendVarint(b, uint64(valSize))
234 return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
235 }
236}
237
238func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700239 if mapv.Len() == 0 {
240 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700241 }
Damien Neil68b81c32019-08-22 11:41:32 -0700242 if opts.Deterministic() {
Damien Neil3e42b662019-12-17 11:39:17 -0800243 return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700244 }
Damien Neil3e42b662019-12-17 11:39:17 -0800245 iter := mapRange(mapv)
246 for iter.Next() {
247 var err error
248 b = wire.AppendVarint(b, wiretag)
249 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
250 if err != nil {
251 return b, err
252 }
253 }
254 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700255}
256
Damien Neil3e42b662019-12-17 11:39:17 -0800257func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
258 keys := mapv.MapKeys()
259 sort.Slice(keys, func(i, j int) bool {
260 switch keys[i].Kind() {
261 case reflect.Bool:
262 return !keys[i].Bool() && keys[j].Bool()
263 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
264 return keys[i].Int() < keys[j].Int()
265 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
266 return keys[i].Uint() < keys[j].Uint()
267 case reflect.Float32, reflect.Float64:
268 return keys[i].Float() < keys[j].Float()
269 case reflect.String:
270 return keys[i].String() < keys[j].String()
271 default:
272 panic("invalid kind: " + keys[i].Kind().String())
273 }
Damien Neil68b81c32019-08-22 11:41:32 -0700274 })
Damien Neil3e42b662019-12-17 11:39:17 -0800275 for _, key := range keys {
276 var err error
277 b = wire.AppendVarint(b, wiretag)
278 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
279 if err != nil {
280 return b, err
281 }
282 }
283 return b, nil
284}
285
286func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
287 if mi := mapi.valMessageInfo; mi != nil {
288 mi.init()
289 if !mi.needsInitCheck {
290 return nil
291 }
292 iter := mapRange(mapv)
293 for iter.Next() {
294 val := pointerOfValue(iter.Value())
295 if err := mi.isInitializedPointer(val); err != nil {
296 return err
297 }
298 }
299 } else {
300 iter := mapRange(mapv)
301 for iter.Next() {
302 val := mapi.conv.valConv.PBValueOf(iter.Value())
303 if err := mapi.valFuncs.isInit(val); err != nil {
304 return err
305 }
306 }
307 }
308 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700309}