internal/impl: add fast-path unmarshal
Benchmarks run with:
go test ./benchmarks/ -bench=Wire -benchtime=500ms -benchmem -count=8
Fast-path vs. parent commit:
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 1.35µs ± 2% 0.45µs ± 4% -67.01% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 1.07µs ± 1% 0.31µs ± 1% -71.04% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 691µs ± 2% 188µs ± 2% -72.78% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 60.0 ± 0% 25.0 ± 0% -58.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 42.0 ± 0% 7.0 ± 0% -83.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 28.6k ± 0% 8.5k ± 0% -70.34% (p=0.000 n=8+8)
Fast-path vs. -v1:
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 702ns ± 1% 445ns ± 4% -36.58% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 604ns ± 1% 311ns ± 1% -48.54% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 179µs ± 3% 188µs ± 2% +5.30% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 26.0 ± 0% 25.0 ± 0% -3.85% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 8.00 ± 0% 7.00 ± 0% -12.50% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 8.49k ± 0% 8.49k ± 0% -0.01% (p=0.000 n=8+8)
Change-Id: I6247ac3fd66a63d9acb902cbd192094ee3d151c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185147
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 5a02c34..53a860b 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -16,6 +16,17 @@
var protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
+// mapInfo holds what the map coders need for one map field: the field's Go
+// map type, the wire tags and coders for the entry's key and value, and the
+// values a decoded entry starts from before its key/value fields are read.
+type mapInfo struct {
+	// goType is the Go map type of the field.
+	goType reflect.Type
+
+	// Wire tags for the key and value fields of a map entry.
+	keyWiretag, valWiretag uint64
+
+	// Coders used for the entry's key (field 1) and value (field 2).
+	keyFuncs, valFuncs ifaceCoderFuncs
+
+	// Zero values of the map's key and element types; used when an entry
+	// omits its key or value.
+	keyZero, valZero interface{}
+
+	// newVal, when non-nil, allocates a fresh value for each decoded entry in
+	// place of valZero. It is set for group- and message-typed values.
+	newVal func() interface{}
+}
+
func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
// TODO: Consider generating specialized map coders.
keyField := fd.MapKey()
@@ -25,6 +36,22 @@
keyFuncs := encoderFuncsForValue(keyField, ft.Key())
valFuncs := encoderFuncsForValue(valField, ft.Elem())
+ mapi := &mapInfo{
+ goType: ft,
+ keyWiretag: keyWiretag,
+ valWiretag: valWiretag,
+ keyFuncs: keyFuncs,
+ valFuncs: valFuncs,
+ keyZero: reflect.Zero(ft.Key()).Interface(),
+ valZero: reflect.Zero(ft.Elem()).Interface(),
+ }
+ switch valField.Kind() {
+ case pref.GroupKind, pref.MessageKind:
+ mapi.newVal = func() interface{} {
+ return reflect.New(ft.Elem().Elem()).Interface()
+ }
+ }
+
funcs = pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
@@ -32,6 +59,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeMap(b, p, wtyp, mapi, opts)
+ },
}
if valFuncs.isInit != nil {
funcs.isInit = func(p pointer) error {
@@ -46,6 +76,64 @@
mapValTagSize = 1 // field 2, tag size 2.
)
+// consumeMap decodes one map entry from b and stores it in the map field at p.
+//
+// The entry must be length-delimited (wtyp == wire.BytesType); otherwise
+// errUnknown is returned and the message is left untouched.
+//
+// Inside the entry, field 1 is decoded with mapi.keyFuncs and field 2 with
+// mapi.valFuncs. Any other field number, and any field for which the decoder
+// reports errUnknown, is skipped. A key or value absent from the entry
+// defaults to mapi.keyZero / mapi.valZero, or to a fresh mapi.newVal() when
+// newVal is set. If field 1 or 2 appears more than once, the value returned by
+// the last decode is used.
+//
+// On success it returns the number of bytes of b occupied by the entry, as
+// reported by wire.ConsumeBytes. The map itself is only allocated (if nil)
+// after the entry has been fully parsed, so rejected or malformed input does
+// not leave an empty map behind.
+func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
+	// Validate before touching the message: a field rejected here is treated
+	// as unknown by the caller and must not have side effects on the map.
+	if wtyp != wire.BytesType {
+		return 0, errUnknown
+	}
+	entry, entryLen := wire.ConsumeBytes(b)
+	if entryLen < 0 {
+		return 0, wire.ParseError(entryLen)
+	}
+
+	key := mapi.keyZero
+	val := mapi.valZero
+	if mapi.newVal != nil {
+		val = mapi.newVal()
+	}
+	for len(entry) > 0 {
+		num, fieldType, tagLen := wire.ConsumeTag(entry)
+		if tagLen < 0 {
+			return 0, wire.ParseError(tagLen)
+		}
+		entry = entry[tagLen:]
+
+		var (
+			v   interface{}
+			n   int
+			err = errUnknown // default for field numbers other than 1 and 2
+		)
+		switch num {
+		case 1:
+			if v, n, err = mapi.keyFuncs.unmarshal(entry, key, num, fieldType, opts); err == nil {
+				key = v
+			}
+		case 2:
+			if v, n, err = mapi.valFuncs.unmarshal(entry, val, num, fieldType, opts); err == nil {
+				val = v
+			}
+		}
+		if err == errUnknown {
+			// Skip fields this entry does not understand.
+			n = wire.ConsumeFieldValue(num, fieldType, entry)
+			if n < 0 {
+				return 0, wire.ParseError(n)
+			}
+		} else if err != nil {
+			return 0, err
+		}
+		entry = entry[n:]
+	}
+
+	m := p.AsValueOf(mapi.goType).Elem()
+	if m.IsNil() {
+		m.Set(reflect.MakeMap(mapi.goType))
+	}
+	m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
+	return entryLen, nil
+}
+
func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
m := p.AsValueOf(goType).Elem()
n := 0