blob: 53a860b4120e4b7c29f0caae45ebeb6bc43a8ba3 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 "google.golang.org/protobuf/proto"
14 pref "google.golang.org/protobuf/reflect/protoreflect"
15)
16
17var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
18
Damien Neile91877d2019-06-27 10:54:42 -070019type mapInfo struct {
20 goType reflect.Type
21 keyWiretag uint64
22 valWiretag uint64
23 keyFuncs ifaceCoderFuncs
24 valFuncs ifaceCoderFuncs
25 keyZero interface{}
26 valZero interface{}
27 newVal func() interface{}
28}
29
Damien Neilc37adef2019-04-01 13:49:56 -070030func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
31 // TODO: Consider generating specialized map coders.
32 keyField := fd.MapKey()
33 valField := fd.MapValue()
34 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
35 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
36 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
37 valFuncs := encoderFuncsForValue(valField, ft.Elem())
38
Damien Neile91877d2019-06-27 10:54:42 -070039 mapi := &mapInfo{
40 goType: ft,
41 keyWiretag: keyWiretag,
42 valWiretag: valWiretag,
43 keyFuncs: keyFuncs,
44 valFuncs: valFuncs,
45 keyZero: reflect.Zero(ft.Key()).Interface(),
46 valZero: reflect.Zero(ft.Elem()).Interface(),
47 }
48 switch valField.Kind() {
49 case pref.GroupKind, pref.MessageKind:
50 mapi.newVal = func() interface{} {
51 return reflect.New(ft.Elem().Elem()).Interface()
52 }
53 }
54
Damien Neil5322bdb2019-04-09 15:57:05 -070055 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070056 size: func(p pointer, tagsize int, opts marshalOptions) int {
57 return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
58 },
59 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
60 return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
61 },
Damien Neile91877d2019-06-27 10:54:42 -070062 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
63 return consumeMap(b, p, wtyp, mapi, opts)
64 },
Damien Neilc37adef2019-04-01 13:49:56 -070065 }
Damien Neil5322bdb2019-04-09 15:57:05 -070066 if valFuncs.isInit != nil {
67 funcs.isInit = func(p pointer) error {
68 return isInitMap(p, ft, valFuncs.isInit)
69 }
70 }
71 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070072}
73
// Encoded sizes, in bytes, of the key and value tags inside a map entry.
// A tag is the varint (fieldNumber<<3 | wireType); for field numbers 1 and 2
// that value is at most 23, so both tags always fit in a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
78
Damien Neile91877d2019-06-27 10:54:42 -070079func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
80 mp := p.AsValueOf(mapi.goType)
81 if mp.Elem().IsNil() {
82 mp.Elem().Set(reflect.MakeMap(mapi.goType))
83 }
84 m := mp.Elem()
85
86 if wtyp != wire.BytesType {
87 return 0, errUnknown
88 }
89 b, n := wire.ConsumeBytes(b)
90 if n < 0 {
91 return 0, wire.ParseError(n)
92 }
93 var (
94 key = mapi.keyZero
95 val = mapi.valZero
96 )
97 if mapi.newVal != nil {
98 val = mapi.newVal()
99 }
100 for len(b) > 0 {
101 num, wtyp, n := wire.ConsumeTag(b)
102 if n < 0 {
103 return 0, wire.ParseError(n)
104 }
105 b = b[n:]
106 err := errUnknown
107 switch num {
108 case 1:
109 var v interface{}
110 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
111 if err != nil {
112 break
113 }
114 key = v
115 case 2:
116 var v interface{}
117 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
118 if err != nil {
119 break
120 }
121 val = v
122 }
123 if err == errUnknown {
124 n = wire.ConsumeFieldValue(num, wtyp, b)
125 if n < 0 {
126 return 0, wire.ParseError(n)
127 }
128 } else if err != nil {
129 return 0, err
130 }
131 b = b[n:]
132 }
133 m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
134 return n, nil
135}
136
Damien Neilc37adef2019-04-01 13:49:56 -0700137func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
138 m := p.AsValueOf(goType).Elem()
139 n := 0
140 if m.Len() == 0 {
141 return 0
142 }
143 iter := mapRange(m)
144 for iter.Next() {
145 ki := iter.Key().Interface()
146 vi := iter.Value().Interface()
147 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
148 n += wire.SizeBytes(size) + tagsize
149 }
150 return n
151}
152
153func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
154 m := p.AsValueOf(goType).Elem()
Damien Neilc37adef2019-04-01 13:49:56 -0700155 var err error
156
157 if m.Len() == 0 {
158 return b, nil
159 }
160
161 if opts.Deterministic() {
162 keys := m.MapKeys()
163 sort.Sort(mapKeys(keys))
164 for _, k := range keys {
165 b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700166 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700167 return b, err
168 }
169 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700170 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700171 }
172
173 iter := mapRange(m)
174 for iter.Next() {
175 b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700176 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700177 return b, err
178 }
179 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700180 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700181}
182
183func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
184 ki := key.Interface()
185 vi := value.Interface()
186 b = wire.AppendVarint(b, wiretag)
187 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
188 b = wire.AppendVarint(b, uint64(size))
Damien Neilc37adef2019-04-01 13:49:56 -0700189 b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700190 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700191 return b, err
192 }
193 b, err = valFuncs.marshal(b, vi, valWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700194 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700195 return b, err
196 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700197 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700198}
199
Damien Neil5322bdb2019-04-09 15:57:05 -0700200func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
201 m := p.AsValueOf(goType).Elem()
202 if m.Len() == 0 {
203 return nil
204 }
205 iter := mapRange(m)
206 for iter.Next() {
207 if err := isInit(iter.Value().Interface()); err != nil {
208 return err
209 }
210 }
211 return nil
212}
213
// mapKeys wraps vs in a sort.Interface that orders map keys for deterministic
// marshaling. The comparison is picked from the reflect.Kind of the first key:
// Int32/Int64 compare as signed integers, Uint32/Uint64 as unsigned integers,
// Bool orders false before true, and String compares lexically. Any other
// kind panics. Protobuf map keys are restricted to integral, bool and string
// types; see https://developers.google.com/protocol-buffers/docs/proto#maps.
func mapKeys(vs []reflect.Value) sort.Interface {
	if len(vs) == 0 {
		// Nothing to compare; less is never consulted.
		return mapKeySorter{vs: vs}
	}
	var less func(a, b reflect.Value) bool
	switch kind := vs[0].Kind(); kind {
	case reflect.Int32, reflect.Int64:
		less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
	case reflect.Uint32, reflect.Uint64:
		less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
	case reflect.Bool:
		// false < true
		less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() }
	case reflect.String:
		less = func(a, b reflect.Value) bool { return a.String() < b.String() }
	default:
		panic(fmt.Sprintf("unsupported map key type: %v", kind))
	}
	return mapKeySorter{vs: vs, less: less}
}

// mapKeySorter implements sort.Interface over vs using the less function.
type mapKeySorter struct {
	vs   []reflect.Value
	less func(a, b reflect.Value) bool
}

// Len reports the number of keys being sorted.
func (s mapKeySorter) Len() int {
	return len(s.vs)
}

// Swap exchanges the keys at positions i and j.
func (s mapKeySorter) Swap(i, j int) {
	s.vs[i], s.vs[j] = s.vs[j], s.vs[i]
}

// Less reports whether the key at i sorts before the key at j.
func (s mapKeySorter) Less(i, j int) bool {
	a, b := s.vs[i], s.vs[j]
	return s.less(a, b)
}