blob: 5f7d9e28bfbb45638cbdfa4a646b952e232871fe [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilc37adef2019-04-01 13:49:56 -07008 "reflect"
Damien Neil3e42b662019-12-17 11:39:17 -08009 "sort"
Damien Neilc37adef2019-04-01 13:49:56 -070010
11 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070012 pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
Damien Neile91877d2019-06-27 10:54:42 -070015type mapInfo struct {
Damien Neil3e42b662019-12-17 11:39:17 -080016 goType reflect.Type
17 keyWiretag uint64
18 valWiretag uint64
19 keyFuncs valueCoderFuncs
20 valFuncs valueCoderFuncs
21 keyZero pref.Value
22 keyKind pref.Kind
23 valMessageInfo *MessageInfo
24 conv *mapConverter
Damien Neile91877d2019-06-27 10:54:42 -070025}
26
Damien Neilc37adef2019-04-01 13:49:56 -070027func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
28 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
31 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
Damien Neil4b3a82f2019-09-04 19:07:00 -070033 keyFuncs := encoderFuncsForValue(keyField)
34 valFuncs := encoderFuncsForValue(valField)
Damien Neil3e42b662019-12-17 11:39:17 -080035 conv := newMapConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070036
Damien Neile91877d2019-06-27 10:54:42 -070037 mapi := &mapInfo{
38 goType: ft,
39 keyWiretag: keyWiretag,
40 valWiretag: valWiretag,
41 keyFuncs: keyFuncs,
42 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070043 keyZero: keyField.Default(),
44 keyKind: keyField.Kind(),
Damien Neil3e42b662019-12-17 11:39:17 -080045 conv: conv,
46 }
47 if valField.Kind() == pref.MessageKind {
48 mapi.valMessageInfo = getMessageInfo(ft.Elem())
Damien Neile91877d2019-06-27 10:54:42 -070049 }
50
Damien Neil5322bdb2019-04-09 15:57:05 -070051 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070052 size: func(p pointer, tagsize int, opts marshalOptions) int {
Damien Neil3e42b662019-12-17 11:39:17 -080053 return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070054 },
55 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
Damien Neil3e42b662019-12-17 11:39:17 -080056 return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070057 },
Damien Neile91877d2019-06-27 10:54:42 -070058 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070059 mp := p.AsValueOf(ft)
60 if mp.Elem().IsNil() {
61 mp.Elem().Set(reflect.MakeMap(mapi.goType))
62 }
Damien Neil3e42b662019-12-17 11:39:17 -080063 if mapi.valMessageInfo == nil {
64 return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
65 } else {
66 return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
67 }
Damien Neile91877d2019-06-27 10:54:42 -070068 },
Damien Neilc37adef2019-04-01 13:49:56 -070069 }
Damien Neil5322bdb2019-04-09 15:57:05 -070070 if valFuncs.isInit != nil {
71 funcs.isInit = func(p pointer) error {
Damien Neil3e42b662019-12-17 11:39:17 -080072 return isInitMap(p.AsValueOf(ft).Elem(), mapi)
Damien Neil5322bdb2019-04-09 15:57:05 -070073 }
74 }
75 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070076}
77
// Each map entry is encoded as a small message whose key is field 1 and whose
// value is field 2. Both field numbers produce tags below 128 for every wire
// type, so each tag occupies exactly one byte on the wire.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
82
Damien Neil3e42b662019-12-17 11:39:17 -080083func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070084 if mapv.Len() == 0 {
85 return 0
Damien Neile91877d2019-06-27 10:54:42 -070086 }
Damien Neil68b81c32019-08-22 11:41:32 -070087 n := 0
Damien Neil3e42b662019-12-17 11:39:17 -080088 iter := mapRange(mapv)
89 for iter.Next() {
90 key := mapi.conv.keyConv.PBValueOf(iter.Key()).MapKey()
91 keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
92 var valSize int
93 value := mapi.conv.valConv.PBValueOf(iter.Value())
94 if mapi.valMessageInfo == nil {
95 valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
96 } else {
97 p := pointerOfValue(iter.Value())
98 valSize += mapValTagSize
99 valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
100 }
101 n += tagsize + wire.SizeBytes(keySize+valSize)
102 }
Damien Neil68b81c32019-08-22 11:41:32 -0700103 return n
104}
Damien Neile91877d2019-06-27 10:54:42 -0700105
Damien Neil3e42b662019-12-17 11:39:17 -0800106func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
Damien Neile91877d2019-06-27 10:54:42 -0700107 if wtyp != wire.BytesType {
108 return 0, errUnknown
109 }
110 b, n := wire.ConsumeBytes(b)
111 if n < 0 {
112 return 0, wire.ParseError(n)
113 }
114 var (
115 key = mapi.keyZero
Damien Neil3e42b662019-12-17 11:39:17 -0800116 val = mapi.conv.valConv.New()
Damien Neile91877d2019-06-27 10:54:42 -0700117 )
Damien Neile91877d2019-06-27 10:54:42 -0700118 for len(b) > 0 {
119 num, wtyp, n := wire.ConsumeTag(b)
120 if n < 0 {
121 return 0, wire.ParseError(n)
122 }
123 b = b[n:]
124 err := errUnknown
125 switch num {
126 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700127 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700128 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
129 if err != nil {
130 break
131 }
132 key = v
133 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700134 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700135 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
136 if err != nil {
137 break
138 }
139 val = v
140 }
141 if err == errUnknown {
142 n = wire.ConsumeFieldValue(num, wtyp, b)
143 if n < 0 {
144 return 0, wire.ParseError(n)
145 }
146 } else if err != nil {
147 return 0, err
148 }
149 b = b[n:]
150 }
Damien Neil3e42b662019-12-17 11:39:17 -0800151 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), mapi.conv.valConv.GoValueOf(val))
Damien Neile91877d2019-06-27 10:54:42 -0700152 return n, nil
153}
154
Damien Neil3e42b662019-12-17 11:39:17 -0800155func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
156 if wtyp != wire.BytesType {
157 return 0, errUnknown
158 }
159 b, n := wire.ConsumeBytes(b)
160 if n < 0 {
161 return 0, wire.ParseError(n)
162 }
163 var (
164 key = mapi.keyZero
165 val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
166 )
167 for len(b) > 0 {
168 num, wtyp, n := wire.ConsumeTag(b)
169 if n < 0 {
170 return 0, wire.ParseError(n)
171 }
172 b = b[n:]
173 err := errUnknown
174 switch num {
175 case 1:
176 var v pref.Value
177 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
178 if err != nil {
179 break
180 }
181 key = v
182 case 2:
183 if wtyp != wire.BytesType {
184 break
185 }
186 v, n := wire.ConsumeBytes(b)
187 if n < 0 {
188 return 0, wire.ParseError(n)
189 }
190 n, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
191 }
192 if err == errUnknown {
193 n = wire.ConsumeFieldValue(num, wtyp, b)
194 if n < 0 {
195 return 0, wire.ParseError(n)
196 }
197 } else if err != nil {
198 return 0, err
199 }
200 b = b[n:]
201 }
202 mapv.SetMapIndex(mapi.conv.keyConv.GoValueOf(key), val)
203 return n, nil
204}
205
206func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
207 if mapi.valMessageInfo == nil {
208 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
209 val := mapi.conv.valConv.PBValueOf(valrv)
210 size := 0
211 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
212 size += mapi.valFuncs.size(val, mapValTagSize, opts)
213 b = wire.AppendVarint(b, uint64(size))
214 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
215 if err != nil {
216 return nil, err
217 }
218 return mapi.valFuncs.marshal(b, val, mapi.valWiretag, opts)
219 } else {
220 key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
221 val := pointerOfValue(valrv)
222 valSize := mapi.valMessageInfo.sizePointer(val, opts)
223 size := 0
224 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
225 size += mapValTagSize + wire.SizeBytes(valSize)
226 b = wire.AppendVarint(b, uint64(size))
227 b, err := mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
228 if err != nil {
229 return nil, err
230 }
231 b = wire.AppendVarint(b, mapi.valWiretag)
232 b = wire.AppendVarint(b, uint64(valSize))
233 return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
234 }
235}
236
237func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -0700238 if mapv.Len() == 0 {
239 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700240 }
Damien Neil68b81c32019-08-22 11:41:32 -0700241 if opts.Deterministic() {
Damien Neil3e42b662019-12-17 11:39:17 -0800242 return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
Damien Neil68b81c32019-08-22 11:41:32 -0700243 }
Damien Neil3e42b662019-12-17 11:39:17 -0800244 iter := mapRange(mapv)
245 for iter.Next() {
246 var err error
247 b = wire.AppendVarint(b, wiretag)
248 b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
249 if err != nil {
250 return b, err
251 }
252 }
253 return b, nil
Damien Neil5322bdb2019-04-09 15:57:05 -0700254}
255
Damien Neil3e42b662019-12-17 11:39:17 -0800256func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
257 keys := mapv.MapKeys()
258 sort.Slice(keys, func(i, j int) bool {
259 switch keys[i].Kind() {
260 case reflect.Bool:
261 return !keys[i].Bool() && keys[j].Bool()
262 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
263 return keys[i].Int() < keys[j].Int()
264 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
265 return keys[i].Uint() < keys[j].Uint()
266 case reflect.Float32, reflect.Float64:
267 return keys[i].Float() < keys[j].Float()
268 case reflect.String:
269 return keys[i].String() < keys[j].String()
270 default:
271 panic("invalid kind: " + keys[i].Kind().String())
272 }
Damien Neil68b81c32019-08-22 11:41:32 -0700273 })
Damien Neil3e42b662019-12-17 11:39:17 -0800274 for _, key := range keys {
275 var err error
276 b = wire.AppendVarint(b, wiretag)
277 b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
278 if err != nil {
279 return b, err
280 }
281 }
282 return b, nil
283}
284
285func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
286 if mi := mapi.valMessageInfo; mi != nil {
287 mi.init()
288 if !mi.needsInitCheck {
289 return nil
290 }
291 iter := mapRange(mapv)
292 for iter.Next() {
293 val := pointerOfValue(iter.Value())
294 if err := mi.isInitializedPointer(val); err != nil {
295 return err
296 }
297 }
298 } else {
299 iter := mapRange(mapv)
300 for iter.Next() {
301 val := mapi.conv.valConv.PBValueOf(iter.Value())
302 if err := mapi.valFuncs.isInit(val); err != nil {
303 return err
304 }
305 }
306 }
307 return nil
Damien Neilc37adef2019-04-01 13:49:56 -0700308}