blob: d140b27dcc9747d37be9d28cfb7164c48f6a03bc [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neilc37adef2019-04-01 13:49:56 -070013 "google.golang.org/protobuf/proto"
14 pref "google.golang.org/protobuf/reflect/protoreflect"
15)
16
17var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
18
19func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
20 // TODO: Consider generating specialized map coders.
21 keyField := fd.MapKey()
22 valField := fd.MapValue()
23 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
24 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
25 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
26 valFuncs := encoderFuncsForValue(valField, ft.Elem())
27
28 return pointerCoderFuncs{
29 size: func(p pointer, tagsize int, opts marshalOptions) int {
30 return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
31 },
32 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
33 return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
34 },
35 }
36}
37
// Encoded sizes, in bytes, of the wire tags used inside a map entry.
// A tag is the varint (field_number<<3 | wire_type); for field numbers 1 and 2
// that value is below 0x80, so each tag occupies a single byte.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
42
43func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
44 m := p.AsValueOf(goType).Elem()
45 n := 0
46 if m.Len() == 0 {
47 return 0
48 }
49 iter := mapRange(m)
50 for iter.Next() {
51 ki := iter.Key().Interface()
52 vi := iter.Value().Interface()
53 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
54 n += wire.SizeBytes(size) + tagsize
55 }
56 return n
57}
58
59func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
60 m := p.AsValueOf(goType).Elem()
Damien Neilc37adef2019-04-01 13:49:56 -070061 var err error
62
63 if m.Len() == 0 {
64 return b, nil
65 }
66
67 if opts.Deterministic() {
68 keys := m.MapKeys()
69 sort.Sort(mapKeys(keys))
70 for _, k := range keys {
71 b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -070072 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -070073 return b, err
74 }
75 }
Damien Neil8c86fc52019-06-19 09:28:29 -070076 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -070077 }
78
79 iter := mapRange(m)
80 for iter.Next() {
81 b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -070082 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -070083 return b, err
84 }
85 }
Damien Neil8c86fc52019-06-19 09:28:29 -070086 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -070087}
88
89func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
90 ki := key.Interface()
91 vi := value.Interface()
92 b = wire.AppendVarint(b, wiretag)
93 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
94 b = wire.AppendVarint(b, uint64(size))
Damien Neilc37adef2019-04-01 13:49:56 -070095 b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -070096 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -070097 return b, err
98 }
99 b, err = valFuncs.marshal(b, vi, valWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700100 if err != nil {
Damien Neilc37adef2019-04-01 13:49:56 -0700101 return b, err
102 }
Damien Neil8c86fc52019-06-19 09:28:29 -0700103 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700104}
105
// mapKeys wraps vs in a sort.Interface that orders map keys ascending.
//
// The comparison is selected from the reflect.Kind of the first key:
// int32/int64 compare as signed integers, uint32/uint64 as unsigned integers,
// bool orders false before true, and string compares lexically by bytes.
// An empty slice is returned without a comparison, since sorting it never
// needs one. Any other key kind panics.
func mapKeys(vs []reflect.Value) sort.Interface {
	sorter := mapKeySorter{vs: vs}
	if len(vs) == 0 {
		return sorter
	}

	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
	switch kind := vs[0].Kind(); kind {
	case reflect.Int32, reflect.Int64:
		sorter.less = func(x, y reflect.Value) bool {
			return x.Int() < y.Int()
		}
	case reflect.Uint32, reflect.Uint64:
		sorter.less = func(x, y reflect.Value) bool {
			return x.Uint() < y.Uint()
		}
	case reflect.Bool:
		// false < true
		sorter.less = func(x, y reflect.Value) bool {
			return !x.Bool() && y.Bool()
		}
	case reflect.String:
		sorter.less = func(x, y reflect.Value) bool {
			return x.String() < y.String()
		}
	default:
		panic(fmt.Sprintf("unsupported map key type: %v", kind))
	}
	return sorter
}

// mapKeySorter implements sort.Interface over a slice of map keys using a
// comparison function chosen for the keys' kind.
type mapKeySorter struct {
	vs   []reflect.Value
	less func(a, b reflect.Value) bool
}

// Len returns the number of keys being sorted.
func (s mapKeySorter) Len() int {
	return len(s.vs)
}

// Swap exchanges the keys at positions i and j.
func (s mapKeySorter) Swap(i, j int) {
	tmp := s.vs[i]
	s.vs[i] = s.vs[j]
	s.vs[j] = tmp
}

// Less reports whether the key at position i orders before the key at j.
func (s mapKeySorter) Less(i, j int) bool {
	left, right := s.vs[i], s.vs[j]
	return s.less(left, right)
}