internal/impl: add fast-path unmarshal
Benchmarks run with:
go test ./benchmarks/ -bench=Wire -benchtime=500ms -benchmem -count=8
Fast-path vs. parent commit:
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 1.35µs ± 2% 0.45µs ± 4% -67.01% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 1.07µs ± 1% 0.31µs ± 1% -71.04% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 691µs ± 2% 188µs ± 2% -72.78% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 60.0 ± 0% 25.0 ± 0% -58.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 42.0 ± 0% 7.0 ± 0% -83.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 28.6k ± 0% 8.5k ± 0% -70.34% (p=0.000 n=8+8)
Fast-path vs. -v1:
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 702ns ± 1% 445ns ± 4% -36.58% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 604ns ± 1% 311ns ± 1% -48.54% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 179µs ± 3% 188µs ± 2% +5.30% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 26.0 ± 0% 25.0 ± 0% -3.85% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 8.00 ± 0% 7.00 ± 0% -12.50% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 8.49k ± 0% 8.49k ± 0% -0.01% (p=0.000 n=8+8)
Change-Id: I6247ac3fd66a63d9acb902cbd192094ee3d151c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185147
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
new file mode 100644
index 0000000..5fdffe3
--- /dev/null
+++ b/internal/impl/decode.go
@@ -0,0 +1,162 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+ "google.golang.org/protobuf/internal/encoding/wire"
+ "google.golang.org/protobuf/internal/errors"
+ "google.golang.org/protobuf/proto"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+ preg "google.golang.org/protobuf/reflect/protoregistry"
+ piface "google.golang.org/protobuf/runtime/protoiface"
+)
+
+// unmarshalOptions is a more efficient representation of UnmarshalOptions:
+// boolean options are packed into a bit set (see unmarshalOptionFlags).
+// Values are built from piface.UnmarshalOptions by newUnmarshalOptions and
+// converted back to proto.UnmarshalOptions by the Options method.
+//
+// We don't preserve the AllowPartial flag, because fast-path (un)marshal
+// operations always allow partial messages.
+type unmarshalOptions struct {
+ // flags holds the packed boolean options; query it via DiscardUnknown.
+ flags unmarshalOptionFlags
+ // resolver is consulted by unmarshalExtension to look up extension types
+ // by message name and field number.
+ resolver preg.ExtensionTypeResolver
+}
+
+// unmarshalOptionFlags is a bit set of boolean unmarshal options stored in
+// unmarshalOptions.flags.
+type unmarshalOptionFlags uint8
+
+const (
+ // unmarshalDiscardUnknown is set by newUnmarshalOptions when the caller's
+ // DiscardUnknown option is true.
+ unmarshalDiscardUnknown unmarshalOptionFlags = 1 << iota
+)
+
+func newUnmarshalOptions(opts piface.UnmarshalOptions) unmarshalOptions {
+ o := unmarshalOptions{
+ resolver: opts.Resolver,
+ }
+ if opts.DiscardUnknown {
+ o.flags |= unmarshalDiscardUnknown
+ }
+ return o
+}
+
+func (o unmarshalOptions) Options() proto.UnmarshalOptions {
+ return proto.UnmarshalOptions{
+ AllowPartial: true,
+ DiscardUnknown: o.DiscardUnknown(),
+ Resolver: o.Resolver(),
+ }
+}
+
+// DiscardUnknown reports whether the DiscardUnknown option was set.
+func (o unmarshalOptions) DiscardUnknown() bool {
+	return o.flags&unmarshalDiscardUnknown == unmarshalDiscardUnknown
+}
+
+// Resolver returns the resolver used to look up extension types.
+func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver {
+	return o.resolver
+}
+
+// unmarshal is protoreflect.Methods.Unmarshal. It decodes b into m with the
+// fast-path decoder (as a top-level message, not a group) and reports only
+// the error; the number of bytes consumed is not needed here.
+func (mi *MessageInfo) unmarshal(b []byte, m pref.ProtoMessage, opts piface.UnmarshalOptions) error {
+	o := newUnmarshalOptions(opts)
+	if _, err := mi.unmarshalPointer(b, pointerOfIface(m), 0, o); err != nil {
+		return err
+	}
+	return nil
+}
+
+// errUnknown is returned during unmarshaling to indicate a parse error that
+// should result in a field being placed in the unknown fields section (for
+// example, when the wire type doesn't match) as opposed to the entire
+// unmarshal operation failing (for example, when a field extends past the
+// available input). It is also used when no unmarshal function or extension
+// type is available for a field.
+//
+// unmarshalPointer compares against it with == and responds by skipping the
+// field's value and treating the field as unknown; every other error aborts
+// the unmarshal.
+//
+// This is a sentinel error which should never be visible to the user.
+var errUnknown = errors.New("unknown")
+
+// unmarshalPointer decodes the wire-format message in b into the message at p,
+// whose layout is described by mi, and returns the number of bytes consumed.
+//
+// When groupTag is non-zero the message is being decoded as a group: decoding
+// stops right after the END_GROUP tag whose field number is groupTag, and the
+// returned count includes that tag. Running out of input before that tag is an
+// error. When groupTag is zero, all of b is consumed.
+//
+// A field that cannot be decoded in place (its decoder returns errUnknown, it
+// has no unmarshal function, or it is an unresolvable extension) is skipped.
+// Unless opts.DiscardUnknown() is set, its raw tag and value are appended to
+// the message's unknown fields when the message has an unknown-fields slot.
+// Any other error aborts decoding.
+func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (int, error) {
+	mi.init()
+	var exts *map[int32]ExtensionField
+	start := len(b)
+	for len(b) > 0 {
+		// Parse the tag (field number and wire type).
+		// TODO: inline 1 and 2 byte variants?
+		num, wtyp, n := wire.ConsumeTag(b)
+		if n < 0 {
+			return 0, wire.ParseError(n)
+		}
+		b = b[n:]
+
+		// Recognize the terminating END_GROUP tag before looking up the field
+		// number: the group's own message may declare a field with the same
+		// number as the group, and that field must not swallow the terminator.
+		if wtyp == wire.EndGroupType && num == groupTag {
+			return start - len(b), nil
+		}
+
+		var f *coderFieldInfo
+		if int(num) < len(mi.denseCoderFields) {
+			f = mi.denseCoderFields[num]
+		} else {
+			f = mi.coderFields[num]
+		}
+		err := errUnknown
+		switch {
+		case f != nil:
+			if f.funcs.unmarshal == nil {
+				break
+			}
+			n, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
+		default:
+			// Possible extension. For messages with an extension field, the
+			// extension map is allocated lazily the first time a field number
+			// without a known field is seen.
+			if exts == nil && mi.extensionOffset.IsValid() {
+				exts = p.Apply(mi.extensionOffset).Extensions()
+				if *exts == nil {
+					*exts = make(map[int32]ExtensionField)
+				}
+			}
+			if exts == nil {
+				break
+			}
+			n, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
+		}
+		if err != nil {
+			if err != errUnknown {
+				return 0, err
+			}
+			// Skip over the field's value; keep its raw bytes in the unknown
+			// fields section unless the caller asked to discard them.
+			n = wire.ConsumeFieldValue(num, wtyp, b)
+			if n < 0 {
+				return 0, wire.ParseError(n)
+			}
+			if !opts.DiscardUnknown() && mi.unknownOffset.IsValid() {
+				u := p.Apply(mi.unknownOffset).Bytes()
+				*u = wire.AppendTag(*u, num, wtyp)
+				*u = append(*u, b[:n]...)
+			}
+		}
+		b = b[n:]
+	}
+	if groupTag != 0 {
+		return 0, errors.New("missing end group marker")
+	}
+	return start, nil
+}
+
+func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (n int, err error) {
+ x := exts[int32(num)]
+ xt := x.GetType()
+ if xt == nil {
+ var err error
+ xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num)
+ if err != nil {
+ if err == preg.NotFound {
+ return 0, errUnknown
+ }
+ return 0, err
+ }
+ x.SetType(xt)
+ }
+ xi := mi.extensionFieldInfo(xt)
+ if xi.funcs.unmarshal == nil {
+ return 0, errUnknown
+ }
+ ival := x.GetValue()
+ if ival == nil && xi.unmarshalNeedsValue {
+ // Create a new message, list, or map value to fill in.
+ // For enums, create a prototype value to let the unmarshal func know the
+ // concrete type.
+ ival = xt.InterfaceOf(xt.New())
+ }
+ v, n, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
+ if err != nil {
+ return 0, err
+ }
+ x.SetEagerValue(v)
+ exts[int32(num)] = x
+ return n, nil
+}