internal/impl: pass *coderFieldInfo into fast-path functions
Refactor the fast-path size, marshal, unmarshal, and isInit functions to
take the *coderFieldInfo for the field as input.
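This replaces a number of closures capturing field-specific information
with functions taking that information as an explicit parameter.
As part of this, mapInfo no longer stores the map value's *MessageInfo:
encoderFuncsForMap now returns it to the caller, and the map coders read
it from the field's coderFieldInfo (f.mi) instead.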
Change-Id: I8cb39701265edb7b673f6f04a0152d5f4dbb4d5d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218937
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index c8c0925..b319583 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -14,18 +14,17 @@
)
type mapInfo struct {
- goType reflect.Type
- keyWiretag uint64
- valWiretag uint64
- keyFuncs valueCoderFuncs
- valFuncs valueCoderFuncs
- keyZero pref.Value
- keyKind pref.Kind
- valMessageInfo *MessageInfo
- conv *mapConverter
+ goType reflect.Type
+ keyWiretag uint64
+ valWiretag uint64
+ keyFuncs valueCoderFuncs
+ valFuncs valueCoderFuncs
+ keyZero pref.Value
+ keyKind pref.Kind
+ conv *mapConverter
}
-func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
+func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
// TODO: Consider generating specialized map coders.
keyField := fd.MapKey()
valField := fd.MapValue()
@@ -46,34 +45,34 @@
conv: conv,
}
if valField.Kind() == pref.MessageKind {
- mapi.valMessageInfo = getMessageInfo(ft.Elem())
+ valueMessage = getMessageInfo(ft.Elem())
}
funcs = pointerCoderFuncs{
- size: func(p pointer, tagsize int, opts marshalOptions) int {
- return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
+ size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
+ return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
},
- marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
+ marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+ return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft)
if mp.Elem().IsNil() {
mp.Elem().Set(reflect.MakeMap(mapi.goType))
}
- if mapi.valMessageInfo == nil {
- return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
+ if f.mi == nil {
+ return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
} else {
- return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
+ return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
}
},
}
if valFuncs.isInit != nil {
- funcs.isInit = func(p pointer) error {
- return isInitMap(p.AsValueOf(ft).Elem(), mapi)
+ funcs.isInit = func(p pointer, f *coderFieldInfo) error {
+ return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
}
}
- return funcs
+ return valueMessage, funcs
}
const (
@@ -81,7 +80,7 @@
mapValTagSize = 1 // field 2, tag size 2.
)
-func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
+func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
if mapv.Len() == 0 {
return 0
}
@@ -92,19 +91,19 @@
keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
var valSize int
value := mapi.conv.valConv.PBValueOf(iter.Value())
- if mapi.valMessageInfo == nil {
+ if f.mi == nil {
valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
} else {
p := pointerOfValue(iter.Value())
valSize += mapValTagSize
- valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
+ valSize += wire.SizeBytes(f.mi.sizePointer(p, opts))
}
- n += tagsize + wire.SizeBytes(keySize+valSize)
+ n += f.tagsize + wire.SizeBytes(keySize+valSize)
}
return n
}
-func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return out, errUnknown
}
@@ -161,7 +160,7 @@
return out, nil
}
-func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
return out, errUnknown
}
@@ -171,7 +170,7 @@
}
var (
key = mapi.keyZero
- val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
+ val = reflect.New(f.mi.GoReflectType.Elem())
)
for len(b) > 0 {
num, wtyp, n := wire.ConsumeTag(b)
@@ -203,7 +202,7 @@
return out, wire.ParseError(n)
}
var o unmarshalOutput
- o, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
+ o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
if o.initialized {
// Consider this map item initialized so long as we see
// an initialized value.
@@ -225,8 +224,8 @@
return out, nil
}
-func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
- if mapi.valMessageInfo == nil {
+func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+ if f.mi == nil {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := mapi.conv.valConv.PBValueOf(valrv)
size := 0
@@ -241,7 +240,7 @@
} else {
key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
val := pointerOfValue(valrv)
- valSize := mapi.valMessageInfo.sizePointer(val, opts)
+ valSize := f.mi.sizePointer(val, opts)
size := 0
size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
size += mapValTagSize + wire.SizeBytes(valSize)
@@ -252,22 +251,22 @@
}
b = wire.AppendVarint(b, mapi.valWiretag)
b = wire.AppendVarint(b, uint64(valSize))
- return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
+ return f.mi.marshalAppendPointer(b, val, opts)
}
}
-func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
if mapv.Len() == 0 {
return b, nil
}
if opts.Deterministic() {
- return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
+ return appendMapDeterministic(b, mapv, mapi, f, opts)
}
iter := mapRange(mapv)
for iter.Next() {
var err error
- b = wire.AppendVarint(b, wiretag)
- b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
+ b = wire.AppendVarint(b, f.wiretag)
+ b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
if err != nil {
return b, err
}
@@ -275,7 +274,7 @@
return b, nil
}
-func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
keys := mapv.MapKeys()
sort.Slice(keys, func(i, j int) bool {
switch keys[i].Kind() {
@@ -295,8 +294,8 @@
})
for _, key := range keys {
var err error
- b = wire.AppendVarint(b, wiretag)
- b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
+ b = wire.AppendVarint(b, f.wiretag)
+ b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
if err != nil {
return b, err
}
@@ -304,8 +303,8 @@
return b, nil
}
-func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
- if mi := mapi.valMessageInfo; mi != nil {
+func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
+ if mi := f.mi; mi != nil {
mi.init()
if !mi.needsInitCheck {
return nil