internal/impl: change unmarshal func return to unmarshalOutput
The fast-path unmarshal funcs return the number of bytes consumed.
Change these functions to return an unmarshalOutput struct instead, to
make it easier to add more fields to the results. This is groundwork for allowing
the fast-path unmarshaler to indicate when the unmarshaled message is
known to be initialized.
Change-Id: Ia8c44731a88f5be969a55cd98ea26282f412c7ae
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215720
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 3f3957c..f1f0671 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -52,7 +52,7 @@
if funcs.isInit != nil {
needIsInit = true
}
- cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ cf.funcs.unmarshal = func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
var vw reflect.Value // pointer to wrapper type
vi := p.AsValueOf(ft).Elem() // oneof field value of interface kind
if !vi.IsNil() && !vi.Elem().IsNil() && vi.Elem().Elem().Type() == ot {
@@ -60,12 +60,12 @@
} else {
vw = reflect.New(ot)
}
- n, err := funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, opts)
+ out, err := funcs.unmarshal(b, pointerOfValue(vw).Apply(zeroOffset), wtyp, opts)
if err != nil {
- return 0, err
+ return out, err
}
vi.Set(vw)
- return n, nil
+ return out, nil
}
}
getInfo := func(p pointer) (pointer, *oneofFieldInfo) {
@@ -139,13 +139,13 @@
}
return appendMessage(b, m, wiretag, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
fs := p.WeakFields()
m, ok := fs.get(num)
if !ok {
lazyInit()
if messageType == nil {
- return 0, errUnknown
+ return unmarshalOutput{}, errUnknown
}
m = messageType.New().Interface()
fs.set(num, m)
@@ -171,7 +171,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageInfo(b, p, wiretag, mi, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageInfo(b, p, mi, wtyp, opts)
},
}
@@ -191,7 +191,7 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, wiretag, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
@@ -216,21 +216,22 @@
return mi.marshalAppendPointer(b, p.Elem(), opts)
}
-func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
- return 0, errUnknown
+ return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem())))
}
if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil {
- return 0, err
+ return out, err
}
- return n, nil
+ out.n = n
+ return out, nil
}
func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
@@ -243,18 +244,19 @@
return opts.Options().MarshalAppend(b, m)
}
-func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
- return 0, errUnknown
+ return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
if err := opts.Options().Unmarshal(v, m); err != nil {
- return 0, err
+ return out, err
}
- return n, nil
+ out.n = n
+ return out, nil
}
func sizeMessageValue(v pref.Value, tagsize int, opts marshalOptions) int {
@@ -267,10 +269,10 @@
return appendMessage(b, m, wiretag, opts)
}
-func consumeMessageValue(b []byte, v pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
+func consumeMessageValue(b []byte, v pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error) {
m := v.Message().Interface()
- n, err := consumeMessage(b, m, wtyp, opts)
- return v, n, err
+ out, err := consumeMessage(b, m, wtyp, opts)
+ return v, out, err
}
func isInitMessageValue(v pref.Value) error {
@@ -295,10 +297,10 @@
return appendGroup(b, m, wiretag, opts)
}
-func consumeGroupValue(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
+func consumeGroupValue(b []byte, v pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, unmarshalOutput, error) {
m := v.Message().Interface()
- n, err := consumeGroup(b, m, num, wtyp, opts)
- return v, n, err
+ out, err := consumeGroup(b, m, num, wtyp, opts)
+ return v, out, err
}
var coderGroupValue = valueCoderFuncs{
@@ -318,7 +320,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupType(b, p, wiretag, mi, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupType(b, p, mi, num, wtyp, opts)
},
}
@@ -338,7 +340,7 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, wiretag, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
mp := p.AsValueOf(ft).Elem()
if mp.IsNil() {
mp.Set(reflect.New(ft.Elem()))
@@ -364,9 +366,9 @@
return b, err
}
-func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
- return 0, errUnknown
+ return out, errUnknown
}
if p.Elem().IsNil() {
p.SetPointer(pointerOfValue(reflect.New(mi.GoReflectType.Elem())))
@@ -385,15 +387,16 @@
return b, err
}
-func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
- return 0, errUnknown
+ return out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
- return n, opts.Options().Unmarshal(b, m)
+ out.n = n
+ return out, opts.Options().Unmarshal(b, m)
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -405,7 +408,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSliceInfo(b, p, wiretag, mi, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageSliceInfo(b, p, mi, wtyp, opts)
},
}
@@ -423,7 +426,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, wiretag, ft, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeMessageSlice(b, p, ft, wtyp, opts)
},
isInit: func(p pointer) error {
@@ -456,21 +459,22 @@
return b, nil
}
-func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
- return 0, errUnknown
+ return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
m := reflect.New(mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil {
- return 0, err
+ return out, err
}
p.AppendPointerSlice(mp)
- return n, nil
+ out.n = n
+ return out, nil
}
func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
@@ -509,20 +513,21 @@
return b, nil
}
-func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.BytesType {
- return 0, errUnknown
+ return out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil {
- return 0, err
+ return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
- return n, nil
+ out.n = n
+ return out, nil
}
func isInitMessageSlice(p pointer, goType reflect.Type) error {
@@ -565,21 +570,22 @@
return b, nil
}
-func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
+func consumeMessageSliceValue(b []byte, listv pref.Value, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (_ pref.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != wire.BytesType {
- return pref.Value{}, 0, errUnknown
+ return pref.Value{}, out, errUnknown
}
v, n := wire.ConsumeBytes(b)
if n < 0 {
- return pref.Value{}, 0, wire.ParseError(n)
+ return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(v, m.Message().Interface()); err != nil {
- return pref.Value{}, 0, err
+ return pref.Value{}, out, err
}
list.Append(m)
- return listv, n, nil
+ out.n = n
+ return listv, out, nil
}
func isInitMessageSliceValue(listv pref.Value) error {
@@ -626,21 +632,22 @@
return b, nil
}
-func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (pref.Value, int, error) {
+func consumeGroupSliceValue(b []byte, listv pref.Value, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (_ pref.Value, out unmarshalOutput, err error) {
list := listv.List()
if wtyp != wire.StartGroupType {
- return pref.Value{}, 0, errUnknown
+ return pref.Value{}, out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
- return pref.Value{}, 0, wire.ParseError(n)
+ return pref.Value{}, out, wire.ParseError(n)
}
m := list.NewElement()
if err := opts.Options().Unmarshal(b, m.Message().Interface()); err != nil {
- return pref.Value{}, 0, err
+ return pref.Value{}, out, err
}
list.Append(m)
- return listv, n, nil
+ out.n = n
+ return listv, out, nil
}
var coderGroupSliceValue = valueCoderFuncs{
@@ -660,7 +667,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSliceInfo(b, p, wiretag, mi, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupSliceInfo(b, p, num, wtyp, mi, opts)
},
}
@@ -678,7 +685,7 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, wiretag, ft, opts)
},
- unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
return consumeGroupSlice(b, p, num, wtyp, ft, opts)
},
isInit: func(p pointer) error {
@@ -712,20 +719,21 @@
return b, nil
}
-func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (int, error) {
+func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (out unmarshalOutput, err error) {
if wtyp != wire.StartGroupType {
- return 0, errUnknown
+ return out, errUnknown
}
b, n := wire.ConsumeGroup(num, b)
if n < 0 {
- return 0, wire.ParseError(n)
+ return out, wire.ParseError(n)
}
mp := reflect.New(goType.Elem())
if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil {
- return 0, err
+ return out, err
}
p.AppendPointerSlice(pointerOfValue(mp))
- return n, nil
+ out.n = n
+ return out, nil
}
func sizeGroupSliceInfo(p pointer, mi *MessageInfo, tagsize int, opts marshalOptions) int {
@@ -751,18 +759,18 @@
return b, nil
}
-func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (int, error) {
+func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (unmarshalOutput, error) {
if wtyp != wire.StartGroupType {
- return 0, errUnknown
+ return unmarshalOutput{}, errUnknown
}
m := reflect.New(mi.GoReflectType.Elem()).Interface()
mp := pointerOfIface(m)
- n, err := mi.unmarshalPointer(b, mp, num, opts)
+ out, err := mi.unmarshalPointer(b, mp, num, opts)
if err != nil {
- return 0, err
+ return out, err
}
p.AppendPointerSlice(mp)
- return n, nil
+ return out, nil
}
func asMessage(v reflect.Value) pref.ProtoMessage {