internal/impl: store extension values as Values

Change the storage type of ExtensionField from interface{} to
protoreflect.Value.
Replace the codec functions operating on interface{}s with ones
operating on Values.
Values are potentially more efficient, since they can represent
non-pointer types without allocation. This also reduces the number of
types used to represent field values.
Additionally, this change lays the groundwork for changing the
user-visible representation of repeated extension fields from
*[]T to []T. The storage type for extension fields must support
mutation, which is why it is currently *[]T; storing a Value instead
permits that change without introducing yet another view on field
values.

Change-Id: Ida336be14112bb940f655236eb58df21bf312525
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/192218
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 4916704..00d6511 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -5,11 +5,10 @@
package impl
import (
- "fmt"
"reflect"
- "sort"
"google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/mapsort"
pref "google.golang.org/protobuf/reflect/protoreflect"
)
@@ -17,11 +16,10 @@
goType reflect.Type
keyWiretag uint64
valWiretag uint64
- keyFuncs ifaceCoderFuncs
- valFuncs ifaceCoderFuncs
- keyZero interface{}
- valZero interface{}
- newVal func() interface{}
+ keyFuncs valueCoderFuncs
+ valFuncs valueCoderFuncs
+ keyZero pref.Value
+ keyKind pref.Kind
}
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
@@ -32,6 +30,7 @@
valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
keyFuncs := encoderFuncsForValue(keyField, ft.Key())
valFuncs := encoderFuncsForValue(valField, ft.Elem())
+ conv := NewConverter(ft, fd)
mapi := &mapInfo{
goType: ft,
@@ -39,30 +38,32 @@
valWiretag: valWiretag,
keyFuncs: keyFuncs,
valFuncs: valFuncs,
- keyZero: reflect.Zero(ft.Key()).Interface(),
- valZero: reflect.Zero(ft.Elem()).Interface(),
- }
- switch valField.Kind() {
- case pref.GroupKind, pref.MessageKind:
- mapi.newVal = func() interface{} {
- return reflect.New(ft.Elem().Elem()).Interface()
- }
+ keyZero: keyField.Default(),
+ keyKind: keyField.Kind(),
}
funcs = pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
- return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
+ mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+ return sizeMap(mapv, tagsize, mapi, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
+ mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+ return appendMap(b, mapv, wiretag, mapi, opts)
},
unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
- return consumeMap(b, p, wtyp, mapi, opts)
+ mp := p.AsValueOf(ft)
+ if mp.Elem().IsNil() {
+ mp.Elem().Set(reflect.MakeMap(mapi.goType))
+ }
+ mapv := conv.PBValueOf(mp.Elem()).Map()
+ return consumeMap(b, mapv, wtyp, mapi, opts)
},
}
if valFuncs.isInit != nil {
funcs.isInit = func(p pointer) error {
- return isInitMap(p, ft, valFuncs.isInit)
+ mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+ return isInitMap(mapv, mapi)
}
}
return funcs
@@ -73,13 +74,21 @@
mapValTagSize = 1 // field 2, tag size 2.
)
-func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
- mp := p.AsValueOf(mapi.goType)
- if mp.Elem().IsNil() {
- mp.Elem().Set(reflect.MakeMap(mapi.goType))
+func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
+ if mapv.Len() == 0 {
+ return 0
}
- m := mp.Elem()
+ n := 0
+ mapv.Range(func(key pref.MapKey, value pref.Value) bool {
+ n += tagsize + wire.SizeBytes(
+ mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
+ mapi.valFuncs.size(value, mapValTagSize, opts))
+ return true
+ })
+ return n
+}
+func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
if wtyp != wire.BytesType {
return 0, errUnknown
}
@@ -89,11 +98,8 @@
}
var (
key = mapi.keyZero
- val = mapi.valZero
+ val = mapv.NewValue()
)
- if mapi.newVal != nil {
- val = mapi.newVal()
- }
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
@@ -103,14 +109,14 @@
err := errUnknown
switch num {
case 1:
- var v interface{}
+ var v pref.Value
v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
if err != nil {
break
}
key = v
case 2:
- var v interface{}
+ var v pref.Value
v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
if err != nil {
break
@@ -127,119 +133,44 @@
}
b = b[n:]
}
- m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
+ mapv.Set(key.MapKey(), val)
return n, nil
}
-func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
- m := p.AsValueOf(goType).Elem()
- n := 0
- if m.Len() == 0 {
- return 0
+func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+ if mapv.Len() == 0 {
+ return b, nil
}
- iter := mapRange(m)
- for iter.Next() {
- ki := iter.Key().Interface()
- vi := iter.Value().Interface()
- size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
- n += wire.SizeBytes(size) + tagsize
- }
- return n
-}
-
-func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
- m := p.AsValueOf(goType).Elem()
var err error
-
- if m.Len() == 0 {
- return b, nil
- }
-
- if opts.Deterministic() {
- keys := m.MapKeys()
- sort.Sort(mapKeys(keys))
- for _, k := range keys {
- b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
- if err != nil {
- return b, err
- }
- }
- return b, nil
- }
-
- iter := mapRange(m)
- for iter.Next() {
- b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
+ fn := func(key pref.MapKey, value pref.Value) bool {
+ b = wire.AppendVarint(b, wiretag)
+ size := 0
+ size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+ size += mapi.valFuncs.size(value, mapValTagSize, opts)
+ b = wire.AppendVarint(b, uint64(size))
+ b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
if err != nil {
- return b, err
+ return false
}
- }
- return b, nil
-}
-
-func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
- ki := key.Interface()
- vi := value.Interface()
- b = wire.AppendVarint(b, wiretag)
- size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
- b = wire.AppendVarint(b, uint64(size))
- b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
- if err != nil {
- return b, err
- }
- b, err = valFuncs.marshal(b, vi, valWiretag, opts)
- if err != nil {
- return b, err
- }
- return b, nil
-}
-
-func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
- m := p.AsValueOf(goType).Elem()
- if m.Len() == 0 {
- return nil
- }
- iter := mapRange(m)
- for iter.Next() {
- if err := isInit(iter.Value().Interface()); err != nil {
- return err
+ b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
+ if err != nil {
+ return false
}
+ return true
}
- return nil
+ if opts.Deterministic() {
+ mapsort.Range(mapv, mapi.keyKind, fn)
+ } else {
+ mapv.Range(fn)
+ }
+ return b, err
}
-// mapKeys returns a sort.Interface to be used for sorting the map keys.
-// Map fields may have key types of non-float scalars, strings and enums.
-func mapKeys(vs []reflect.Value) sort.Interface {
- s := mapKeySorter{vs: vs}
-
- // Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
- if len(vs) == 0 {
- return s
- }
- switch vs[0].Kind() {
- case reflect.Int32, reflect.Int64:
- s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
- case reflect.Uint32, reflect.Uint64:
- s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
- case reflect.Bool:
- s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
- case reflect.String:
- s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
- default:
- panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
- }
-
- return s
-}
-
-type mapKeySorter struct {
- vs []reflect.Value
- less func(a, b reflect.Value) bool
-}
-
-func (s mapKeySorter) Len() int { return len(s.vs) }
-func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
-func (s mapKeySorter) Less(i, j int) bool {
- return s.less(s.vs[i], s.vs[j])
+func isInitMap(mapv pref.Map, mapi *mapInfo) error {
+ var err error
+ mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
+ err = mapi.valFuncs.isInit(value)
+ return err == nil
+ })
+ return err
}