internal/impl: change unmarshal func return to unmarshalOutput
The fast-path unmarshal funcs return the number of bytes consumed.
Change these functions to return an unmarshalOutput struct instead, to
make it easier to add to the results. This is groundwork for allowing
the fast-path unmarshaler to indicate when the unmarshaled message is
known to be initialized.
Change-Id: Ia8c44731a88f5be969a55cd98ea26282f412c7ae
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215720
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 4d3718f..84d31ac 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -57,6 +57,10 @@
func (o unmarshalOptions) DiscardUnknown() bool { return o.flags&unmarshalDiscardUnknown != 0 }
func (o unmarshalOptions) Resolver() preg.ExtensionTypeResolver { return o.resolver }
+type unmarshalOutput struct {
+ n int // number of bytes consumed
+}
+
// unmarshal is protoreflect.Methods.Unmarshal.
func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
var p pointer
@@ -77,7 +81,7 @@
// This is a sentinel error which should never be visible to the user.
var errUnknown = errors.New("unknown")
-func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (int, error) {
+func (mi *MessageInfo) unmarshalPointer(b []byte, p pointer, groupTag wire.Number, opts unmarshalOptions) (out unmarshalOutput, err error) {
mi.init()
if flags.ProtoLegacy && mi.isMessageSet {
return unmarshalMessageSet(mi, b, p, opts)
@@ -89,18 +93,19 @@
// TODO: inline 1 and 2 byte variants?
num, wtyp, n := wire.ConsumeTag(b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
if num > wire.MaxValidNumber {
- return 0, errors.New("invalid field number")
+ return out, errors.New("invalid field number")
}
b = b[n:]
if wtyp == wire.EndGroupType {
if num != groupTag {
- return 0, errors.New("mismatching end group marker")
+ return out, errors.New("mismatching end group marker")
}
- return start - len(b), nil
+ out.n = start - len(b)
+ return out, nil
}
var f *coderFieldInfo
@@ -115,7 +120,9 @@
if f.funcs.unmarshal == nil {
break
}
- n, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
+ var o unmarshalOutput
+ o, err = f.funcs.unmarshal(b, p.Apply(f.offset), wtyp, opts)
+ n = o.n
default:
// Possible extension.
if exts == nil && mi.extensionOffset.IsValid() {
@@ -127,15 +134,17 @@
if exts == nil {
break
}
- n, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
+ var o unmarshalOutput
+ o, err = mi.unmarshalExtension(b, num, wtyp, *exts, opts)
+ n = o.n
}
if err != nil {
if err != errUnknown {
- return 0, err
+ return out, err
}
n = wire.ConsumeFieldValue(num, wtyp, b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
if mi.unknownOffset.IsValid() {
u := p.Apply(mi.unknownOffset).Bytes()
@@ -146,12 +155,13 @@
b = b[n:]
}
if groupTag != 0 {
- return 0, errors.New("missing end group marker")
+ return out, errors.New("missing end group marker")
}
- return start, nil
+ out.n = start
+ return out, nil
}
-func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (n int, err error) {
+func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.Type, exts map[int32]ExtensionField, opts unmarshalOptions) (out unmarshalOutput, err error) {
x := exts[int32(num)]
xt := x.Type()
if xt == nil {
@@ -159,14 +169,14 @@
xt, err = opts.Resolver().FindExtensionByNumber(mi.Desc.FullName(), num)
if err != nil {
if err == preg.NotFound {
- return 0, errUnknown
+ return out, errUnknown
}
- return 0, err
+ return out, err
}
}
xi := getExtensionFieldInfo(xt)
if xi.funcs.unmarshal == nil {
- return 0, errUnknown
+ return out, errUnknown
}
ival := x.Value()
if !ival.IsValid() && xi.unmarshalNeedsValue {
@@ -175,11 +185,11 @@
// concrete type.
ival = xt.New()
}
- v, n, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
+ v, out, err := xi.funcs.unmarshal(b, ival, num, wtyp, opts)
if err != nil {
- return 0, err
+ return out, err
}
x.Set(xt, v)
exts[int32(num)] = x
- return n, nil
+ return out, nil
}