internal/impl: add fast-path marshal implementation
This is a port of the v1 table marshaler, with some substantial
cleanup and refactoring.
Benchstat results from the protobuf reference benchmark data comparing the
v1 package with v2, with AllowPartial:true set for the new package. This
is not an apples-to-apples comparison, since v1 doesn't have a way to
disable required field checks. Required field checks in the v2 package
currently go through reflection, which performs terribly; my initial
experimentation indicates that fast-path required field checks will
not add much cost, so these results are incomplete but not
wholly inaccurate.
name old time/op new time/op delta
/dataset.google_message3_1.pb/Marshal-12 219ms ± 1% 232ms ± 1% +5.85% (p=0.004 n=6+5)
/dataset.google_message2.pb/Marshal-12 261µs ± 3% 248µs ± 1% -5.14% (p=0.002 n=6+6)
/dataset.google_message1_proto2.pb/Marshal-12 681ns ± 2% 637ns ± 3% -6.53% (p=0.002 n=6+6)
/dataset.google_message1_proto3.pb/Marshal-12 1.10µs ± 8% 0.99µs ± 3% -9.63% (p=0.002 n=6+6)
/dataset.google_message3_3.pb/Marshal-12 44.2ms ± 3% 35.2ms ± 1% -20.28% (p=0.004 n=6+5)
/dataset.google_message4.pb/Marshal-12 91.4ms ± 2% 94.9ms ± 2% +3.78% (p=0.002 n=6+6)
/dataset.google_message3_2.pb/Marshal-12 78.7ms ± 6% 80.8ms ± 4% ~ (p=0.310 n=6+6)
/dataset.google_message3_4.pb/Marshal-12 10.6ms ± 3% 10.6ms ± 8% ~ (p=0.662 n=5+6)
/dataset.google_message3_5.pb/Marshal-12 675ms ± 4% 510ms ± 2% -24.40% (p=0.002 n=6+6)
/dataset.google_message3_1.pb/Marshal 219ms ± 1% 236ms ± 7% +8.06% (p=0.004 n=5+6)
/dataset.google_message2.pb/Marshal 257µs ± 1% 250µs ± 3% ~ (p=0.052 n=5+6)
/dataset.google_message1_proto2.pb/Marshal 685ns ± 1% 628ns ± 1% -8.41% (p=0.008 n=5+5)
/dataset.google_message1_proto3.pb/Marshal 1.08µs ± 1% 0.98µs ± 2% -9.31% (p=0.004 n=5+6)
/dataset.google_message3_3.pb/Marshal 43.7ms ± 1% 35.1ms ± 1% -19.76% (p=0.002 n=6+6)
/dataset.google_message4.pb/Marshal 93.4ms ± 4% 94.9ms ± 2% ~ (p=0.180 n=6+6)
/dataset.google_message3_2.pb/Marshal 105ms ± 2% 98ms ± 7% -6.81% (p=0.009 n=5+6)
/dataset.google_message3_4.pb/Marshal 16.3ms ± 6% 15.7ms ± 3% -3.44% (p=0.041 n=6+6)
/dataset.google_message3_5.pb/Marshal 676ms ± 4% 504ms ± 2% -25.50% (p=0.004 n=6+5)
Change-Id: I72cc4597117f4cf5d236ef505777d49dd4a5f75d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171020
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/encode_map.go b/internal/impl/encode_map.go
new file mode 100644
index 0000000..38e152a
--- /dev/null
+++ b/internal/impl/encode_map.go
@@ -0,0 +1,143 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "fmt"
+ "reflect"
+ "sort"
+
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/proto"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+)
+
+// protoMessageType is the reflect.Type of the proto.Message interface itself,
+// recovered as the element type of a *proto.Message.
+var protoMessageType = reflect.TypeOf(new(proto.Message)).Elem()
+
+// encoderFuncsForMap builds the size and marshal functions for the map field
+// described by fd, whose Go type is ft. Each entry is written as a nested
+// element whose key is tagged as field 1 and whose value is tagged as field 2;
+// the per-type coders for the key and value come from encoderFuncsForValue.
+func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
+	// TODO: Consider generating specialized map coders.
+	kdesc, vdesc := fd.MapKey(), fd.MapValue()
+	ktag := wire.EncodeTag(1, wireTypes[kdesc.Kind()])
+	vtag := wire.EncodeTag(2, wireTypes[vdesc.Kind()])
+	kcoder := encoderFuncsForValue(kdesc, ft.Key())
+	vcoder := encoderFuncsForValue(vdesc, ft.Elem())
+
+	funcs.size = func(p pointer, tagsize int, opts marshalOptions) int {
+		return sizeMap(p, tagsize, ft, kcoder, vcoder, opts)
+	}
+	funcs.marshal = func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
+		return appendMap(b, p, wiretag, ktag, vtag, ft, kcoder, vcoder, opts)
+	}
+	return funcs
+}
+
+// Encoded tag sizes, in bytes, of the key (field 1) and value (field 2)
+// fields inside a map entry. Both field numbers are small enough that their
+// tag varint fits in a single byte.
+const (
+	mapKeyTagSize = 1 // field 1, tag size 1.
+	mapValTagSize = 1 // field 2, tag size 1.
+)
+
+// sizeMap reports how many bytes the map field pointed to by p will occupy
+// when encoded. Each entry costs tagsize bytes for its field tag plus the
+// length-prefixed (wire.SizeBytes) body made of the key and value encodings.
+// An empty map contributes nothing.
+func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
+	mapv := p.AsValueOf(goType).Elem()
+	if mapv.Len() == 0 {
+		return 0
+	}
+	total := 0
+	for it := mapRange(mapv); it.Next(); {
+		k, v := it.Key().Interface(), it.Value().Interface()
+		body := keyFuncs.size(k, mapKeyTagSize, opts) + valFuncs.size(v, mapValTagSize, opts)
+		total += tagsize + wire.SizeBytes(body)
+	}
+	return total
+}
+
+// appendMap appends every entry of the map field pointed to by p to b, each
+// one written by appendMapElement under wiretag. When opts.Deterministic()
+// is set, entries are emitted in sorted key order; otherwise they follow the
+// order produced by mapRange. An error for which nerr.Merge reports false is
+// returned immediately; otherwise encoding continues and nerr.E is returned
+// after the last entry.
+func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
+	mapv := p.AsValueOf(goType).Elem()
+	if mapv.Len() == 0 {
+		return b, nil
+	}
+
+	var (
+		nerr errors.NonFatal
+		err  error
+	)
+	if !opts.Deterministic() {
+		// Unordered path: walk the map directly without collecting keys.
+		for it := mapRange(mapv); it.Next(); {
+			b, err = appendMapElement(b, it.Key(), it.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
+			if !nerr.Merge(err) {
+				return b, err
+			}
+		}
+		return b, nerr.E
+	}
+
+	// Deterministic output needs a stable order, so sort the keys first.
+	keys := mapv.MapKeys()
+	sort.Sort(mapKeys(keys))
+	for i := range keys {
+		b, err = appendMapElement(b, keys[i], mapv.MapIndex(keys[i]), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
+		if !nerr.Merge(err) {
+			return b, err
+		}
+	}
+	return b, nerr.E
+}
+
+// appendMapElement appends one map entry to b. It writes the entry's field
+// tag (wiretag), then the entry body length as a varint, then the key
+// (tagged keyWiretag) and the value (tagged valWiretag). Errors are merged
+// into an errors.NonFatal: when Merge reports false the error is returned
+// immediately, otherwise encoding continues and nerr.E is returned.
+func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
+	k, v := key.Interface(), value.Interface()
+	b = wire.AppendVarint(b, wiretag)
+	bodyLen := keyFuncs.size(k, mapKeyTagSize, opts) + valFuncs.size(v, mapValTagSize, opts)
+	b = wire.AppendVarint(b, uint64(bodyLen))
+
+	var nerr errors.NonFatal
+	var err error
+	if b, err = keyFuncs.marshal(b, k, keyWiretag, opts); !nerr.Merge(err) {
+		return b, err
+	}
+	if b, err = valFuncs.marshal(b, v, valWiretag, opts); !nerr.Merge(err) {
+		return b, err
+	}
+	return b, nerr.E
+}
+
+// mapKeys returns a sort.Interface that orders vs, the keys of a map field,
+// in ascending order.
+//
+// Protobuf map keys may be any integral type, bool, or string. Floating-point
+// types, bytes, enums and messages are not valid key types. Only the kinds
+// Int32, Int64, Uint32, Uint64, Bool and String are handled here. The
+// comparison is chosen from the kind of vs[0], so all elements are expected
+// to share one kind. Any other kind panics.
+func mapKeys(vs []reflect.Value) sort.Interface {
+	s := mapKeySorter{vs: vs}
+
+	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
+	if len(vs) == 0 {
+		// Nothing to compare; Less is never invoked on an empty sorter.
+		return s
+	}
+	switch vs[0].Kind() {
+	case reflect.Int32, reflect.Int64:
+		s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
+	case reflect.Uint32, reflect.Uint64:
+		s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
+	case reflect.Bool:
+		s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
+	case reflect.String:
+		s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
+	default:
+		panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
+	}
+
+	return s
+}
+
+// mapKeySorter implements sort.Interface over vs, swapping elements in place
+// and ordering them with less.
+type mapKeySorter struct {
+	vs   []reflect.Value
+	less func(a, b reflect.Value) bool
+}
+
+func (s mapKeySorter) Len() int      { return len(s.vs) }
+func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
+func (s mapKeySorter) Less(i, j int) bool {
+	return s.less(s.vs[i], s.vs[j])
+}