blob: 38e152af061762878188a6d3820a777436bfa63e [file] [log] [blame]
Damien Neilc37adef2019-04-01 13:49:56 -07001// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10 "sort"
11
12 "google.golang.org/protobuf/internal/encoding/wire"
13 "google.golang.org/protobuf/internal/errors"
14 "google.golang.org/protobuf/proto"
15 pref "google.golang.org/protobuf/reflect/protoreflect"
16)
17
18var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
19
20func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
21 // TODO: Consider generating specialized map coders.
22 keyField := fd.MapKey()
23 valField := fd.MapValue()
24 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
25 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
26 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
27 valFuncs := encoderFuncsForValue(valField, ft.Elem())
28
29 return pointerCoderFuncs{
30 size: func(p pointer, tagsize int, opts marshalOptions) int {
31 return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
32 },
33 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
34 return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
35 },
36 }
37}
38
// Encoded sizes, in bytes, of the field tags for a map entry's key (field
// number 1) and value (field number 2). A tag is the varint of
// (fieldNumber<<3 | wireType); for any field number below 16 this is at most
// 127 and therefore always fits in a single byte, so both sizes are 1.
const (
	mapKeyTagSize = 1 // field 1, tag size 1.
	mapValTagSize = 1 // field 2, tag size 1.
)
43
44func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
45 m := p.AsValueOf(goType).Elem()
46 n := 0
47 if m.Len() == 0 {
48 return 0
49 }
50 iter := mapRange(m)
51 for iter.Next() {
52 ki := iter.Key().Interface()
53 vi := iter.Value().Interface()
54 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
55 n += wire.SizeBytes(size) + tagsize
56 }
57 return n
58}
59
60func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
61 m := p.AsValueOf(goType).Elem()
62 var nerr errors.NonFatal
63 var err error
64
65 if m.Len() == 0 {
66 return b, nil
67 }
68
69 if opts.Deterministic() {
70 keys := m.MapKeys()
71 sort.Sort(mapKeys(keys))
72 for _, k := range keys {
73 b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
74 if !nerr.Merge(err) {
75 return b, err
76 }
77 }
78 return b, nerr.E
79 }
80
81 iter := mapRange(m)
82 for iter.Next() {
83 b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
84 if !nerr.Merge(err) {
85 return b, err
86 }
87 }
88 return b, nerr.E
89}
90
91func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
92 ki := key.Interface()
93 vi := value.Interface()
94 b = wire.AppendVarint(b, wiretag)
95 size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
96 b = wire.AppendVarint(b, uint64(size))
97 var nerr errors.NonFatal
98 b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
99 if !nerr.Merge(err) {
100 return b, err
101 }
102 b, err = valFuncs.marshal(b, vi, valWiretag, opts)
103 if !nerr.Merge(err) {
104 return b, err
105 }
106 return b, nerr.E
107}
108
// mapKeys returns a sort.Interface that orders vs by ascending key value, for
// use when map entries must be emitted deterministically. Sorting the returned
// value reorders vs in place.
//
// Protobuf map keys may be any integral scalar, bool, or string. Floating
// point, bytes, enum, and message types are not permitted as keys (see
// https://developers.google.com/protocol-buffers/docs/proto#maps). In Go these
// keys are int32, int64, uint32, uint64, bool, or string; every other Go
// integer kind is accepted as well, since it compares the same way.
//
// All elements of vs are assumed to share the kind of vs[0]. mapKeys panics if
// that kind is not a supported key kind. An empty vs is always accepted.
func mapKeys(vs []reflect.Value) sort.Interface {
	s := mapKeySorter{vs: vs}

	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
	if len(vs) == 0 {
		// Less is never called for an empty slice, so no comparator is needed.
		return s
	}
	switch vs[0].Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
	case reflect.Bool:
		s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
	case reflect.String:
		s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
	default:
		panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
	}
	return s
}

// mapKeySorter implements sort.Interface over a slice of map keys, comparing
// elements with the kind-specific less function chosen by mapKeys.
type mapKeySorter struct {
	vs   []reflect.Value
	less func(a, b reflect.Value) bool
}

func (s mapKeySorter) Len() int      { return len(s.vs) }
func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
func (s mapKeySorter) Less(i, j int) bool {
	return s.less(s.vs[i], s.vs[j])
}