blob: 00d65113890733448eb3ea2d182114d0a6508ea9 [file] [log] [blame]
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
Damien Neilc37adef2019-04-01 13:49:56 -07008 "reflect"
Damien Neilc37adef2019-04-01 13:49:56 -07009
10 "google.golang.org/protobuf/internal/encoding/wire"
Damien Neil68b81c32019-08-22 11:41:32 -070011 "google.golang.org/protobuf/internal/mapsort"
Damien Neilc37adef2019-04-01 13:49:56 -070012 pref "google.golang.org/protobuf/reflect/protoreflect"
13)
14
Damien Neile91877d2019-06-27 10:54:42 -070015type mapInfo struct {
16 goType reflect.Type
17 keyWiretag uint64
18 valWiretag uint64
Damien Neil68b81c32019-08-22 11:41:32 -070019 keyFuncs valueCoderFuncs
20 valFuncs valueCoderFuncs
21 keyZero pref.Value
22 keyKind pref.Kind
Damien Neile91877d2019-06-27 10:54:42 -070023}
24
Damien Neilc37adef2019-04-01 13:49:56 -070025func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
26 // TODO: Consider generating specialized map coders.
27 keyField := fd.MapKey()
28 valField := fd.MapValue()
29 keyWiretag := wire.EncodeTag(1, wireTypes[keyField.Kind()])
30 valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
31 keyFuncs := encoderFuncsForValue(keyField, ft.Key())
32 valFuncs := encoderFuncsForValue(valField, ft.Elem())
Damien Neil68b81c32019-08-22 11:41:32 -070033 conv := NewConverter(ft, fd)
Damien Neilc37adef2019-04-01 13:49:56 -070034
Damien Neile91877d2019-06-27 10:54:42 -070035 mapi := &mapInfo{
36 goType: ft,
37 keyWiretag: keyWiretag,
38 valWiretag: valWiretag,
39 keyFuncs: keyFuncs,
40 valFuncs: valFuncs,
Damien Neil68b81c32019-08-22 11:41:32 -070041 keyZero: keyField.Default(),
42 keyKind: keyField.Kind(),
Damien Neile91877d2019-06-27 10:54:42 -070043 }
44
Damien Neil5322bdb2019-04-09 15:57:05 -070045 funcs = pointerCoderFuncs{
Damien Neilc37adef2019-04-01 13:49:56 -070046 size: func(p pointer, tagsize int, opts marshalOptions) int {
Damien Neil68b81c32019-08-22 11:41:32 -070047 mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
48 return sizeMap(mapv, tagsize, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070049 },
50 marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070051 mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
52 return appendMap(b, mapv, wiretag, mapi, opts)
Damien Neilc37adef2019-04-01 13:49:56 -070053 },
Damien Neile91877d2019-06-27 10:54:42 -070054 unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
Damien Neil68b81c32019-08-22 11:41:32 -070055 mp := p.AsValueOf(ft)
56 if mp.Elem().IsNil() {
57 mp.Elem().Set(reflect.MakeMap(mapi.goType))
58 }
59 mapv := conv.PBValueOf(mp.Elem()).Map()
60 return consumeMap(b, mapv, wtyp, mapi, opts)
Damien Neile91877d2019-06-27 10:54:42 -070061 },
Damien Neilc37adef2019-04-01 13:49:56 -070062 }
Damien Neil5322bdb2019-04-09 15:57:05 -070063 if valFuncs.isInit != nil {
64 funcs.isInit = func(p pointer) error {
Damien Neil68b81c32019-08-22 11:41:32 -070065 mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
66 return isInitMap(mapv, mapi)
Damien Neil5322bdb2019-04-09 15:57:05 -070067 }
68 }
69 return funcs
Damien Neilc37adef2019-04-01 13:49:56 -070070}
71
// Encoded tag sizes for the two fields of a map entry. The map coders tag the
// key as field 1 and the value as field 2; each of those tags is accounted
// for as a single byte.
const (
	// mapKeyTagSize is the size in bytes of the key tag (field 1).
	mapKeyTagSize = 1
	// mapValTagSize is the size in bytes of the value tag (field 2).
	mapValTagSize = 1
)
76
Damien Neil68b81c32019-08-22 11:41:32 -070077func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
78 if mapv.Len() == 0 {
79 return 0
Damien Neile91877d2019-06-27 10:54:42 -070080 }
Damien Neil68b81c32019-08-22 11:41:32 -070081 n := 0
82 mapv.Range(func(key pref.MapKey, value pref.Value) bool {
83 n += tagsize + wire.SizeBytes(
84 mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
85 mapi.valFuncs.size(value, mapValTagSize, opts))
86 return true
87 })
88 return n
89}
Damien Neile91877d2019-06-27 10:54:42 -070090
Damien Neil68b81c32019-08-22 11:41:32 -070091func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
Damien Neile91877d2019-06-27 10:54:42 -070092 if wtyp != wire.BytesType {
93 return 0, errUnknown
94 }
95 b, n := wire.ConsumeBytes(b)
96 if n < 0 {
97 return 0, wire.ParseError(n)
98 }
99 var (
100 key = mapi.keyZero
Damien Neil68b81c32019-08-22 11:41:32 -0700101 val = mapv.NewValue()
Damien Neile91877d2019-06-27 10:54:42 -0700102 )
Damien Neile91877d2019-06-27 10:54:42 -0700103 for len(b) > 0 {
104 num, wtyp, n := wire.ConsumeTag(b)
105 if n < 0 {
106 return 0, wire.ParseError(n)
107 }
108 b = b[n:]
109 err := errUnknown
110 switch num {
111 case 1:
Damien Neil68b81c32019-08-22 11:41:32 -0700112 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700113 v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
114 if err != nil {
115 break
116 }
117 key = v
118 case 2:
Damien Neil68b81c32019-08-22 11:41:32 -0700119 var v pref.Value
Damien Neile91877d2019-06-27 10:54:42 -0700120 v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
121 if err != nil {
122 break
123 }
124 val = v
125 }
126 if err == errUnknown {
127 n = wire.ConsumeFieldValue(num, wtyp, b)
128 if n < 0 {
129 return 0, wire.ParseError(n)
130 }
131 } else if err != nil {
132 return 0, err
133 }
134 b = b[n:]
135 }
Damien Neil68b81c32019-08-22 11:41:32 -0700136 mapv.Set(key.MapKey(), val)
Damien Neile91877d2019-06-27 10:54:42 -0700137 return n, nil
138}
139
Damien Neil68b81c32019-08-22 11:41:32 -0700140func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
141 if mapv.Len() == 0 {
142 return b, nil
Damien Neilc37adef2019-04-01 13:49:56 -0700143 }
Damien Neilc37adef2019-04-01 13:49:56 -0700144 var err error
Damien Neil68b81c32019-08-22 11:41:32 -0700145 fn := func(key pref.MapKey, value pref.Value) bool {
146 b = wire.AppendVarint(b, wiretag)
147 size := 0
148 size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
149 size += mapi.valFuncs.size(value, mapValTagSize, opts)
150 b = wire.AppendVarint(b, uint64(size))
151 b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
Damien Neil8c86fc52019-06-19 09:28:29 -0700152 if err != nil {
Damien Neil68b81c32019-08-22 11:41:32 -0700153 return false
Damien Neilc37adef2019-04-01 13:49:56 -0700154 }
Damien Neil68b81c32019-08-22 11:41:32 -0700155 b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
156 if err != nil {
157 return false
Damien Neil5322bdb2019-04-09 15:57:05 -0700158 }
Damien Neil68b81c32019-08-22 11:41:32 -0700159 return true
Damien Neil5322bdb2019-04-09 15:57:05 -0700160 }
Damien Neil68b81c32019-08-22 11:41:32 -0700161 if opts.Deterministic() {
162 mapsort.Range(mapv, mapi.keyKind, fn)
163 } else {
164 mapv.Range(fn)
165 }
166 return b, err
Damien Neil5322bdb2019-04-09 15:57:05 -0700167}
168
Damien Neil68b81c32019-08-22 11:41:32 -0700169func isInitMap(mapv pref.Map, mapi *mapInfo) error {
170 var err error
171 mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
172 err = mapi.valFuncs.isInit(value)
173 return err == nil
174 })
175 return err
Damien Neilc37adef2019-04-01 13:49:56 -0700176}