blob: 4916704c80c0a73a91060ef6779603bd34adee79 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 pref "google.golang.org/protobuf/reflect/protoreflect"
14)
15
Damien Neile91877d2019-06-27 10:54:42 -070016type mapInfo struct {
17 goType reflect.Type
18 keyWiretag uint64
19 valWiretag uint64
20 keyFuncs ifaceCoderFuncs
21 valFuncs ifaceCoderFuncs
22 keyZero interface{}
23 valZero interface{}
24 newVal func() interface{}
25}
26
Damien Neilc37adef2019-04-01 13:49:56 -070027func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
28 // TODO: Consider generating specialized map coders.
29 keyField := fd.MapKey()
30 valField := fd.MapValue()
31 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
32 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
33 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
34 valFuncs := encoderFuncsForValue(valField, ft.Elem())
35
Damien Neile91877d2019-06-27 10:54:42 -070036 mapi := &mapInfo{
37 goType: ft,
38 keyWiretag: keyWiretag,
39 valWiretag: valWiretag,
40 keyFuncs: keyFuncs,
41 valFuncs: valFuncs,
42 keyZero: reflect.Zero(ft.Key()).Interface(),
43 valZero: reflect.Zero(ft.Elem()).Interface(),
44 }
45 switch valField.Kind() {
46 case pref.GroupKind, pref.MessageKind:
47 mapi.newVal = func() interface{} {
48 return reflect.New(ft.Elem().Elem()).Interface()
49 }
50 }
51
Damien Neil5322bdb2019-04-09 15:57:05 -070052 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070053 size: func(p pointer, tagsize int, opts marshalOptions) int {
54 return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
55 },
56 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
57 return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
58 },
Damien Neile91877d2019-06-27 10:54:42 -070059 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
60 return consumeMap(b, p, wtyp, mapi, opts)
61 },
Damien Neilc37adef2019-04-01 13:49:56 -070062 }
Damien Neil5322bdb2019-04-09 15:57:05 -070063 if valFuncs.isInit != nil {
64 funcs.isInit = func(p pointer) error {
65 return isInitMap(p, ft, valFuncs.isInit)
66 }
67 }
68 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070069}
70
// Encoded sizes of the key and value tags inside a map entry. Field numbers
// 1 and 2 combined with any wire type produce tag varints below 0x80, so each
// tag occupies a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
75
Damien Neile91877d2019-06-27 10:54:42 -070076func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
77 mp := p.AsValueOf(mapi.goType)
78 if mp.Elem().IsNil() {
79 mp.Elem().Set(reflect.MakeMap(mapi.goType))
80 }
81 m := mp.Elem()
82
83 if wtyp != wire.BytesType {
84 return 0, errUnknown
85 }
86 b, n := wire.ConsumeBytes(b)
87 if n < 0 {
88 return 0, wire.ParseError(n)
89 }
90 var (
91 key = mapi.keyZero
92 val = mapi.valZero
93 )
94 if mapi.newVal != nil {
95 val = mapi.newVal()
96 }
97 for len(b) > 0 {
98 num, wtyp, n := wire.ConsumeTag(b)
99 if n < 0 {
100 return 0, wire.ParseError(n)
101 }
102 b = b[n:]
103 err := errUnknown
104 switch num {
105 case 1:
106 var v interface{}
107 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
108 if err != nil {
109 break
110 }
111 key = v
112 case 2:
113 var v interface{}
114 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
115 if err != nil {
116 break
117 }
118 val = v
119 }
120 if err == errUnknown {
121 n = wire.ConsumeFieldValue(num, wtyp, b)
122 if n < 0 {
123 return 0, wire.ParseError(n)
124 }
125 } else if err != nil {
126 return 0, err
127 }
128 b = b[n:]
129 }
130 m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
131 return n, nil
132}
133
Damien Neilc37adef2019-04-01 13:49:56 -0700134func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
135 m := p.AsValueOf(goType).Elem()
136 n := 0
137 if m.Len() == 0 {
138 return 0
139 }
140 iter := mapRange(m)
141 for iter.Next() {
142 ki := iter.Key().Interface()
143 vi := iter.Value().Interface()
144 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
145 n += wire.SizeBytes(size) + tagsize
146 }
147 return n
148}
149
150func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
151 m := p.AsValueOf(goType).Elem()
Damien Neilc37adef2019-04-01 13:49:56 -0700152 var err error
153
154 if m.Len() == 0 {
155 return b, nil
156 }
157
158 if opts.Deterministic() {
159 keys := m.MapKeys()
160 sort.Sort(mapKeys(keys))
161 for _, k := range keys {
162 b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700163 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700164 return b, err
165 }
166 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700167 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700168 }
169
170 iter := mapRange(m)
171 for iter.Next() {
172 b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700173 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700174 return b, err
175 }
176 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700177 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700178}
179
180func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
181 ki := key.Interface()
182 vi := value.Interface()
183 b = wire.AppendVarint(b, wiretag)
184 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
185 b = wire.AppendVarint(b, uint64(size))
Damien Neilc37adef2019-04-01 13:49:56 -0700186 b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700187 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700188 return b, err
189 }
190 b, err = valFuncs.marshal(b, vi, valWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700191 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700192 return b, err
193 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700194 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700195}
196
Damien Neil5322bdb2019-04-09 15:57:05 -0700197func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
198 m := p.AsValueOf(goType).Elem()
199 if m.Len() == 0 {
200 return nil
201 }
202 iter := mapRange(m)
203 for iter.Next() {
204 if err := isInit(iter.Value().Interface()); err != nil {
205 return err
206 }
207 }
208 return nil
209}
210
// mapKeys returns a sort.Interface that orders the map keys vs ascending. It
// is used to emit map entries in a deterministic order.
//
// Protobuf map keys may be any integral or string scalar type, meaning any
// scalar except floating-point types and bytes. Enums are not valid map key
// types. The Go kinds handled here are Int32, Int64, Uint32, Uint64, Bool
// (false sorts before true) and String. The kind of vs[0] selects the
// comparison for every element. Any other kind panics.
func mapKeys(vs []reflect.Value) sort.Interface {
	s := mapKeySorter{vs: vs}

	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
	if len(vs) == 0 {
		return s
	}
	switch vs[0].Kind() {
	case reflect.Int32, reflect.Int64:
		s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
	case reflect.Uint32, reflect.Uint64:
		s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
	case reflect.Bool:
		s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
	case reflect.String:
		s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
	default:
		panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
	}

	return s
}

// mapKeySorter implements sort.Interface over a slice of reflect.Values,
// comparing elements with the kind-specific less function chosen by mapKeys.
type mapKeySorter struct {
	vs   []reflect.Value
	less func(a, b reflect.Value) bool
}

func (s mapKeySorter) Len() int      { return len(s.vs) }
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
func (s mapKeySorter) Less(i, j int) bool {
	return s.less(s.vs[i], s.vs[j])
}