blob: 5a02c34ca590f1a93edb353d998855de7c9a8a94 [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 "google.golang.org/protobuf/proto"
14 pref "google.golang.org/protobuf/reflect/protoreflect"
15)
16
17var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
18
19func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
20 // TODO: Consider generating specialized map coders.
21 keyField := fd.MapKey()
22 valField := fd.MapValue()
23 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
24 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
25 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
26 valFuncs := encoderFuncsForValue(valField, ft.Elem())
27
Damien Neil5322bdb2019-04-09 15:57:05 -070028 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070029 size: func(p pointer, tagsize int, opts marshalOptions) int {
30 return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
31 },
32 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
33 return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
34 },
35 }
Damien Neil5322bdb2019-04-09 15:57:05 -070036 if valFuncs.isInit != nil {
37 funcs.isInit = func(p pointer) error {
38 return isInitMap(p, ft, valFuncs.isInit)
39 }
40 }
41 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070042}
43
// Sizes, in bytes, of the tags of the two fields inside an encoded map entry.
// The key is field number 1 and the value is field number 2; both field
// numbers are small enough that the tag fits in a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
48
49func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
50 m := p.AsValueOf(goType).Elem()
51 n := 0
52 if m.Len() == 0 {
53 return 0
54 }
55 iter := mapRange(m)
56 for iter.Next() {
57 ki := iter.Key().Interface()
58 vi := iter.Value().Interface()
59 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
60 n += wire.SizeBytes(size) + tagsize
61 }
62 return n
63}
64
65func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
66 m := p.AsValueOf(goType).Elem()
Damien Neilc37adef2019-04-01 13:49:56 -070067 var err error
68
69 if m.Len() == 0 {
70 return b, nil
71 }
72
73 if opts.Deterministic() {
74 keys := m.MapKeys()
75 sort.Sort(mapKeys(keys))
76 for _, k := range keys {
77 b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -070078 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -070079 return b, err
80 }
81 }
Damien Neil8c86fc52019-06-19 09:28:29 -070082 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -070083 }
84
85 iter := mapRange(m)
86 for iter.Next() {
87 b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -070088 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -070089 return b, err
90 }
91 }
Damien Neil8c86fc52019-06-19 09:28:29 -070092 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -070093}
94
95func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
96 ki := key.Interface()
97 vi := value.Interface()
98 b = wire.AppendVarint(b, wiretag)
99 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
100 b = wire.AppendVarint(b, uint64(size))
Damien Neilc37adef2019-04-01 13:49:56 -0700101 b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700102 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700103 return b, err
104 }
105 b, err = valFuncs.marshal(b, vi, valWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700106 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700107 return b, err
108 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700109 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700110}
111
Damien Neil5322bdb2019-04-09 15:57:05 -0700112func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
113 m := p.AsValueOf(goType).Elem()
114 if m.Len() == 0 {
115 return nil
116 }
117 iter := mapRange(m)
118 for iter.Next() {
119 if err := isInit(iter.Value().Interface()); err != nil {
120 return err
121 }
122 }
123 return nil
124}
125
// mapKeys wraps vs in a sort.Interface that sorts the map keys in place.
//
// The comparison is selected from the kind of the first key: Int32/Int64
// compare as signed integers, Uint32/Uint64 as unsigned integers, Bool with
// false ordered before true, and String lexically. Any other kind panics.
// For an empty slice no comparison is installed, which is safe because
// sorting an empty slice never calls Less.
func mapKeys(vs []reflect.Value) sort.Interface {
	sorter := mapKeySorter{vs: vs}

	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
	if len(vs) > 0 {
		switch kind := vs[0].Kind(); kind {
		case reflect.Int32, reflect.Int64:
			sorter.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
		case reflect.Uint32, reflect.Uint64:
			sorter.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
		case reflect.Bool:
			// false < true
			sorter.less = func(a, b reflect.Value) bool { return b.Bool() && !a.Bool() }
		case reflect.String:
			sorter.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
		default:
			panic(fmt.Sprintf("unsupported map key type: %v", kind))
		}
	}

	return sorter
}

// mapKeySorter implements sort.Interface over a slice of map keys, ordering
// them with the less function chosen by mapKeys.
type mapKeySorter struct {
	vs   []reflect.Value
	less func(a, b reflect.Value) bool
}

// Len reports the number of keys.
func (s mapKeySorter) Len() int {
	return len(s.vs)
}

// Swap exchanges the keys at positions i and j.
func (s mapKeySorter) Swap(i, j int) {
	tmp := s.vs[i]
	s.vs[i] = s.vs[j]
	s.vs[j] = tmp
}

// Less reports whether the key at position i orders before the key at j.
func (s mapKeySorter) Less(i, j int) bool {
	a, b := s.vs[i], s.vs[j]
	return s.less(a, b)
}