internal/impl: add fast-path unmarshal
Benchmarks run with:
go test ./benchmarks/ -bench=Wire -benchtime=500ms -benchmem -count=8
Fast-path vs. parent commit:
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 1.35µs ± 2% 0.45µs ± 4% -67.01% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 1.07µs ± 1% 0.31µs ± 1% -71.04% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 691µs ± 2% 188µs ± 2% -72.78% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 60.0 ± 0% 25.0 ± 0% -58.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 42.0 ± 0% 7.0 ± 0% -83.33% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 28.6k ± 0% 8.5k ± 0% -70.34% (p=0.000 n=8+8)
Fast-path vs. the v1 implementation (benchmarks run with -v1):
name old time/op new time/op delta
Wire/Unmarshal/google_message1_proto2-12 702ns ± 1% 445ns ± 4% -36.58% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 604ns ± 1% 311ns ± 1% -48.54% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 179µs ± 3% 188µs ± 2% +5.30% (p=0.000 n=7+8)
name old allocs/op new allocs/op delta
Wire/Unmarshal/google_message1_proto2-12 26.0 ± 0% 25.0 ± 0% -3.85% (p=0.000 n=8+8)
Wire/Unmarshal/google_message1_proto3-12 8.00 ± 0% 7.00 ± 0% -12.50% (p=0.000 n=8+8)
Wire/Unmarshal/google_message2-12 8.49k ± 0% 8.49k ± 0% -0.01% (p=0.000 n=8+8)
Change-Id: I6247ac3fd66a63d9acb902cbd192094ee3d151c3
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185147
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_field.go b/internal/impl/codec_field.go
index 15124c0..09389b0 100644
--- a/internal/impl/codec_field.go
+++ b/internal/impl/codec_field.go
@@ -5,7 +5,6 @@
package impl
import (
- "fmt"
"reflect"
"unicode/utf8"
@@ -19,61 +18,59 @@
func (errInvalidUTF8) Error() string { return "string field contains invalid UTF-8" }
func (errInvalidUTF8) InvalidUTF8() bool { return true }
-func makeOneofFieldCoder(fs reflect.StructField, od pref.OneofDescriptor, structFields map[pref.FieldNumber]reflect.StructField, otypes map[pref.FieldNumber]reflect.Type) pointerCoderFuncs {
- type oneofFieldInfo struct {
- wiretag uint64
- tagsize int
- funcs pointerCoderFuncs
- }
-
- oneofFieldInfos := make(map[reflect.Type]oneofFieldInfo)
- for i, fields := 0, od.Fields(); i < fields.Len(); i++ {
- fd := fields.Get(i)
- ot := otypes[fd.Number()]
- wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
- oneofFieldInfos[ot] = oneofFieldInfo{
- wiretag: wiretag,
- tagsize: wire.SizeVarint(wiretag),
- funcs: fieldCoder(fd, ot.Field(0).Type),
- }
- }
+func makeOneofFieldCoder(si structInfo, fd pref.FieldDescriptor) pointerCoderFuncs {
+ ot := si.oneofWrappersByNumber[fd.Number()]
+ funcs := fieldCoder(fd, ot.Field(0).Type)
+ fs := si.oneofsByName[fd.ContainingOneof().Name()]
ft := fs.Type
- getInfo := func(p pointer) (pointer, oneofFieldInfo) {
+ wiretag := wire.EncodeTag(fd.Number(), wireTypes[fd.Kind()])
+ tagsize := wire.SizeVarint(wiretag)
+ getInfo := func(p pointer) (pointer, bool) {
v := p.AsValueOf(ft).Elem()
if v.IsNil() {
- return pointer{}, oneofFieldInfo{}
+ return pointer{}, false
}
v = v.Elem() // interface -> *struct
- telem := v.Elem().Type()
- info, ok := oneofFieldInfos[telem]
- if !ok {
- panic(fmt.Errorf("invalid oneof type %v", telem))
+ if v.Elem().Type() != ot {
+ return pointer{}, false
}
- return pointerOfValue(v).Apply(zeroOffset), info
+ return pointerOfValue(v).Apply(zeroOffset), true
}
- return pointerCoderFuncs{
+ pcf := pointerCoderFuncs{
size: func(p pointer, _ int, opts marshalOptions) int {
- v, info := getInfo(p)
- if info.funcs.size == nil {
+ v, ok := getInfo(p)
+ if !ok {
return 0
}
- return info.funcs.size(v, info.tagsize, opts)
+ return funcs.size(v, tagsize, opts)
},
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
- v, info := getInfo(p)
- if info.funcs.marshal == nil {
+ v, ok := getInfo(p)
+ if !ok {
return b, nil
}
- return info.funcs.marshal(b, v, info.wiretag, opts)
+ return funcs.marshal(b, v, wiretag, opts)
},
- isInit: func(p pointer) error {
- v, info := getInfo(p)
- if info.funcs.isInit == nil {
- return nil
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ v := reflect.New(ot)
+ n, err := funcs.unmarshal(b, pointerOfValue(v).Apply(zeroOffset), wtyp, opts)
+ if err != nil {
+ return 0, err
}
- return info.funcs.isInit(v)
+ p.AsValueOf(ft).Elem().Set(v)
+ return n, nil
},
}
+ if funcs.isInit != nil {
+ pcf.isInit = func(p pointer) error {
+ v, ok := getInfo(p)
+ if !ok {
+ return nil
+ }
+ return funcs.isInit(v)
+ }
+ }
+ return pcf
}
func makeMessageFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
@@ -85,6 +82,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageInfo(b, p, wiretag, fi, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeMessageInfo(b, p, fi, wtyp, opts)
+ },
isInit: func(p pointer) error {
return fi.isInitializedPointer(p.Elem())
},
@@ -99,6 +99,13 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendMessage(b, m, wiretag, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ mp := p.AsValueOf(ft).Elem()
+ if mp.IsNil() {
+ mp.Set(reflect.New(ft.Elem()))
+ }
+ return consumeMessage(b, asMessage(mp), wtyp, opts)
+ },
isInit: func(p pointer) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.IsInitialized(m)
@@ -117,6 +124,23 @@
return mi.marshalAppendPointer(b, p.Elem(), opts)
}
+func consumeMessageInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if p.Elem().IsNil() {
+ p.SetPointer(pointerOfValue(reflect.New(mi.GoType.Elem())))
+ }
+ if _, err := mi.unmarshalPointer(v, p.Elem(), 0, opts); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
func sizeMessage(m proto.Message, tagsize int, _ marshalOptions) int {
return wire.SizeBytes(proto.Size(m)) + tagsize
}
@@ -127,6 +151,20 @@
return opts.Options().MarshalAppend(b, m)
}
+func consumeMessage(b []byte, m proto.Message, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if err := opts.Options().Unmarshal(v, m); err != nil {
+ return 0, err
+ }
+ return n, nil
+}
+
func sizeMessageIface(ival interface{}, tagsize int, opts marshalOptions) int {
m := Export{}.MessageOf(ival).Interface()
return sizeMessage(m, tagsize, opts)
@@ -137,18 +175,26 @@
return appendMessage(b, m, wiretag, opts)
}
+func consumeMessageIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
+ m := Export{}.MessageOf(ival).Interface()
+ n, err := consumeMessage(b, m, wtyp, opts)
+ return ival, n, err
+}
+
func isInitMessageIface(ival interface{}) error {
m := Export{}.MessageOf(ival).Interface()
return proto.IsInitialized(m)
}
var coderMessageIface = ifaceCoderFuncs{
- size: sizeMessageIface,
- marshal: appendMessageIface,
- isInit: isInitMessageIface,
+ size: sizeMessageIface,
+ marshal: appendMessageIface,
+ unmarshal: consumeMessageIface,
+ isInit: isInitMessageIface,
}
func makeGroupFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
+ num := fd.Number()
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
@@ -157,6 +203,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupType(b, p, wiretag, fi, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeGroupType(b, p, fi, num, wtyp, opts)
+ },
isInit: func(p pointer) error {
return fi.isInitializedPointer(p.Elem())
},
@@ -171,6 +220,13 @@
m := asMessage(p.AsValueOf(ft).Elem())
return appendGroup(b, m, wiretag, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ mp := p.AsValueOf(ft).Elem()
+ if mp.IsNil() {
+ mp.Set(reflect.New(ft.Elem()))
+ }
+ return consumeGroup(b, asMessage(mp), num, wtyp, opts)
+ },
isInit: func(p pointer) error {
m := asMessage(p.AsValueOf(ft).Elem())
return proto.IsInitialized(m)
@@ -190,6 +246,16 @@
return b, err
}
+func consumeGroupType(b []byte, p pointer, mi *MessageInfo, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.StartGroupType {
+ return 0, errUnknown
+ }
+ if p.Elem().IsNil() {
+ p.SetPointer(pointerOfValue(reflect.New(mi.GoType.Elem())))
+ }
+ return mi.unmarshalPointer(b, p.Elem(), num, opts)
+}
+
func sizeGroup(m proto.Message, tagsize int, _ marshalOptions) int {
return 2*tagsize + proto.Size(m)
}
@@ -201,30 +267,47 @@
return b, err
}
-func sizeGroupIface(ival interface{}, tagsize int, opts marshalOptions) int {
- m := Export{}.MessageOf(ival).Interface()
- return sizeGroup(m, tagsize, opts)
+func consumeGroup(b []byte, m proto.Message, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.StartGroupType {
+ return 0, errUnknown
+ }
+ b, n := wire.ConsumeGroup(num, b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ return n, opts.Options().Unmarshal(b, m)
}
-func appendGroupIface(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error) {
- m := Export{}.MessageOf(ival).Interface()
- return appendGroup(b, m, wiretag, opts)
-}
-
-var coderGroupIface = ifaceCoderFuncs{
- size: sizeGroupIface,
- marshal: appendGroupIface,
- isInit: isInitMessageIface,
+func makeGroupValueCoder(fd pref.FieldDescriptor, ft reflect.Type) ifaceCoderFuncs {
+ return ifaceCoderFuncs{
+ size: func(ival interface{}, tagsize int, opts marshalOptions) int {
+ m := Export{}.MessageOf(ival).Interface()
+ return sizeGroup(m, tagsize, opts)
+ },
+ marshal: func(b []byte, ival interface{}, wiretag uint64, opts marshalOptions) ([]byte, error) {
+ m := Export{}.MessageOf(ival).Interface()
+ return appendGroup(b, m, wiretag, opts)
+ },
+ unmarshal: func(b []byte, ival interface{}, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
+ m := Export{}.MessageOf(ival).Interface()
+ n, err := consumeGroup(b, m, num, wtyp, opts)
+ return ival, n, err
+ },
+ isInit: isInitMessageIface,
+ }
}
func makeMessageSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
+ size: func(p pointer, tagsize int, opts marshalOptions) int {
+ return sizeMessageSliceInfo(p, fi, tagsize, opts)
+ },
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSliceInfo(b, p, wiretag, fi, opts)
},
- size: func(p pointer, tagsize int, opts marshalOptions) int {
- return sizeMessageSliceInfo(p, fi, tagsize, opts)
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeMessageSliceInfo(b, p, fi, wtyp, opts)
},
isInit: func(p pointer) error {
return isInitMessageSliceInfo(p, fi)
@@ -238,6 +321,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendMessageSlice(b, p, wiretag, ft, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeMessageSlice(b, p, ft, wtyp, opts)
+ },
isInit: func(p pointer) error {
return isInitMessageSlice(p, ft)
},
@@ -268,6 +354,23 @@
return b, nil
}
+func consumeMessageSliceInfo(b []byte, p pointer, mi *MessageInfo, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ m := reflect.New(mi.GoType.Elem()).Interface()
+ mp := pointerOfIface(m)
+ if _, err := mi.unmarshalPointer(v, mp, 0, opts); err != nil {
+ return 0, err
+ }
+ p.AppendPointerSlice(mp)
+ return n, nil
+}
+
func isInitMessageSliceInfo(p pointer, mi *MessageInfo) error {
s := p.PointerSlice()
for _, v := range s {
@@ -282,7 +385,7 @@
s := p.PointerSlice()
n := 0
for _, v := range s {
- m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
+ m := asMessage(v.AsValueOf(goType.Elem()))
n += wire.SizeBytes(proto.Size(m)) + tagsize
}
return n
@@ -292,7 +395,7 @@
s := p.PointerSlice()
var err error
for _, v := range s {
- m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
+ m := asMessage(v.AsValueOf(goType.Elem()))
b = wire.AppendVarint(b, wiretag)
siz := proto.Size(m)
b = wire.AppendVarint(b, uint64(siz))
@@ -304,10 +407,26 @@
return b, nil
}
+func consumeMessageSlice(b []byte, p pointer, goType reflect.Type, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ mp := reflect.New(goType.Elem())
+ if err := opts.Options().Unmarshal(v, asMessage(mp)); err != nil {
+ return 0, err
+ }
+ p.AppendPointerSlice(pointerOfValue(mp))
+ return n, nil
+}
+
func isInitMessageSlice(p pointer, goType reflect.Type) error {
s := p.PointerSlice()
for _, v := range s {
- m := Export{}.MessageOf(v.AsValueOf(goType.Elem()).Interface()).Interface()
+ m := asMessage(v.AsValueOf(goType.Elem()))
if err := proto.IsInitialized(m); err != nil {
return err
}
@@ -327,18 +446,26 @@
return appendMessageSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
}
+func consumeMessageSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
+ p := pointerOfIface(ival)
+ n, err := consumeMessageSlice(b, p, reflect.TypeOf(ival).Elem().Elem(), wtyp, opts)
+ return ival, n, err
+}
+
func isInitMessageSliceIface(ival interface{}) error {
p := pointerOfIface(ival)
return isInitMessageSlice(p, reflect.TypeOf(ival).Elem().Elem())
}
var coderMessageSliceIface = ifaceCoderFuncs{
- size: sizeMessageSliceIface,
- marshal: appendMessageSliceIface,
- isInit: isInitMessageSliceIface,
+ size: sizeMessageSliceIface,
+ marshal: appendMessageSliceIface,
+ unmarshal: consumeMessageSliceIface,
+ isInit: isInitMessageSliceIface,
}
func makeGroupSliceFieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
+ num := fd.Number()
if fi, ok := getMessageInfo(ft); ok {
return pointerCoderFuncs{
size: func(p pointer, tagsize int, opts marshalOptions) int {
@@ -347,6 +474,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSliceInfo(b, p, wiretag, fi, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeGroupSliceInfo(b, p, num, wtyp, fi, opts)
+ },
isInit: func(p pointer) error {
return isInitMessageSliceInfo(p, fi)
},
@@ -359,6 +489,9 @@
marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
return appendGroupSlice(b, p, wiretag, ft, opts)
},
+ unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
+ return consumeGroupSlice(b, p, num, wtyp, ft, opts)
+ },
isInit: func(p pointer) error {
return isInitMessageSlice(p, ft)
},
@@ -369,7 +502,7 @@
s := p.PointerSlice()
n := 0
for _, v := range s {
- m := Export{}.MessageOf(v.AsValueOf(messageType.Elem()).Interface()).Interface()
+ m := asMessage(v.AsValueOf(messageType.Elem()))
n += 2*tagsize + proto.Size(m)
}
return n
@@ -379,7 +512,7 @@
s := p.PointerSlice()
var err error
for _, v := range s {
- m := Export{}.MessageOf(v.AsValueOf(messageType.Elem()).Interface()).Interface()
+ m := asMessage(v.AsValueOf(messageType.Elem()))
b = wire.AppendVarint(b, wiretag) // start group
b, err = opts.Options().MarshalAppend(b, m)
if err != nil {
@@ -390,6 +523,22 @@
return b, nil
}
+func consumeGroupSlice(b []byte, p pointer, num wire.Number, wtyp wire.Type, goType reflect.Type, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.StartGroupType {
+ return 0, errUnknown
+ }
+ b, n := wire.ConsumeGroup(num, b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ mp := reflect.New(goType.Elem())
+ if err := opts.Options().Unmarshal(b, asMessage(mp)); err != nil {
+ return 0, err
+ }
+ p.AppendPointerSlice(pointerOfValue(mp))
+ return n, nil
+}
+
func sizeGroupSliceInfo(p pointer, mi *MessageInfo, tagsize int, opts marshalOptions) int {
s := p.PointerSlice()
n := 0
@@ -413,6 +562,20 @@
return b, nil
}
+func consumeGroupSliceInfo(b []byte, p pointer, num wire.Number, wtyp wire.Type, mi *MessageInfo, opts unmarshalOptions) (int, error) {
+ if wtyp != wire.StartGroupType {
+ return 0, errUnknown
+ }
+ m := reflect.New(mi.GoType.Elem()).Interface()
+ mp := pointerOfIface(m)
+ n, err := mi.unmarshalPointer(b, mp, num, opts)
+ if err != nil {
+ return 0, err
+ }
+ p.AppendPointerSlice(mp)
+ return n, nil
+}
+
func sizeGroupSliceIface(ival interface{}, tagsize int, opts marshalOptions) int {
p := pointerOfIface(ival)
return sizeGroupSlice(p, reflect.TypeOf(ival).Elem().Elem(), tagsize, opts)
@@ -423,10 +586,17 @@
return appendGroupSlice(b, p, wiretag, reflect.TypeOf(ival).Elem().Elem(), opts)
}
+func consumeGroupSliceIface(b []byte, ival interface{}, num wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
+ p := pointerOfIface(ival)
+ n, err := consumeGroupSlice(b, p, num, wtyp, reflect.TypeOf(ival).Elem().Elem(), opts)
+ return ival, n, err
+}
+
var coderGroupSliceIface = ifaceCoderFuncs{
- size: sizeGroupSliceIface,
- marshal: appendGroupSliceIface,
- isInit: isInitMessageSliceIface,
+ size: sizeGroupSliceIface,
+ marshal: appendGroupSliceIface,
+ unmarshal: consumeGroupSliceIface,
+ isInit: isInitMessageSliceIface,
}
// Enums
@@ -443,9 +613,23 @@
return b, nil
}
+func consumeEnumIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
+ if wtyp != wire.VarintType {
+ return nil, 0, errUnknown
+ }
+ v, n := wire.ConsumeVarint(b)
+ if n < 0 {
+ return nil, 0, wire.ParseError(n)
+ }
+ rv := reflect.New(reflect.TypeOf(ival)).Elem()
+ rv.SetInt(int64(v))
+ return rv.Interface(), n, nil
+}
+
var coderEnumIface = ifaceCoderFuncs{
- size: sizeEnumIface,
- marshal: appendEnumIface,
+ size: sizeEnumIface,
+ marshal: appendEnumIface,
+ unmarshal: consumeEnumIface,
}
func sizeEnumSliceIface(ival interface{}, tagsize int, opts marshalOptions) (size int) {
@@ -471,9 +655,47 @@
return b, nil
}
+func consumeEnumSliceIface(b []byte, ival interface{}, _ wire.Number, wtyp wire.Type, opts unmarshalOptions) (interface{}, int, error) {
+ n, err := consumeEnumSliceReflect(b, reflect.ValueOf(ival), wtyp, opts)
+ return ival, n, err
+}
+
+func consumeEnumSliceReflect(b []byte, s reflect.Value, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ s = s.Elem() // *[]E -> []E
+ if wtyp == wire.BytesType {
+ b, n = wire.ConsumeBytes(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ for len(b) > 0 {
+ v, n := wire.ConsumeVarint(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ rv := reflect.New(s.Type().Elem()).Elem()
+ rv.SetInt(int64(v))
+ s.Set(reflect.Append(s, rv))
+ b = b[n:]
+ }
+ return n, nil
+ }
+ if wtyp != wire.VarintType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeVarint(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ rv := reflect.New(s.Type().Elem()).Elem()
+ rv.SetInt(int64(v))
+ s.Set(reflect.Append(s, rv))
+ return n, nil
+}
+
var coderEnumSliceIface = ifaceCoderFuncs{
- size: sizeEnumSliceIface,
- marshal: appendEnumSliceIface,
+ size: sizeEnumSliceIface,
+ marshal: appendEnumSliceIface,
+ unmarshal: consumeEnumSliceIface,
}
// Strings with UTF8 validation.
@@ -488,9 +710,25 @@
return b, nil
}
+func consumeStringValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *p.String() = v
+ return n, nil
+}
+
var coderStringValidateUTF8 = pointerCoderFuncs{
- size: sizeString,
- marshal: appendStringValidateUTF8,
+ size: sizeString,
+ marshal: appendStringValidateUTF8,
+ unmarshal: consumeStringValidateUTF8,
}
func appendStringNoZeroValidateUTF8(b []byte, p pointer, wiretag uint64, _ marshalOptions) ([]byte, error) {
@@ -507,8 +745,9 @@
}
var coderStringNoZeroValidateUTF8 = pointerCoderFuncs{
- size: sizeStringNoZero,
- marshal: appendStringNoZeroValidateUTF8,
+ size: sizeStringNoZero,
+ marshal: appendStringNoZeroValidateUTF8,
+ unmarshal: consumeStringValidateUTF8,
}
func sizeStringSliceValidateUTF8(p pointer, tagsize int, _ marshalOptions) (size int) {
@@ -526,15 +765,32 @@
b = wire.AppendVarint(b, wiretag)
b = wire.AppendString(b, v)
if !utf8.ValidString(v) {
- err = errInvalidUTF8{}
+ return b, errInvalidUTF8{}
}
}
return b, err
}
+func consumeStringSliceValidateUTF8(b []byte, p pointer, wtyp wire.Type, _ unmarshalOptions) (n int, err error) {
+ if wtyp != wire.BytesType {
+ return 0, errUnknown
+ }
+ sp := p.StringSlice()
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return 0, errInvalidUTF8{}
+ }
+ *sp = append(*sp, v)
+ return n, nil
+}
+
var coderStringSliceValidateUTF8 = pointerCoderFuncs{
- size: sizeStringSliceValidateUTF8,
- marshal: appendStringSliceValidateUTF8,
+ size: sizeStringSliceValidateUTF8,
+ marshal: appendStringSliceValidateUTF8,
+ unmarshal: consumeStringSliceValidateUTF8,
}
func sizeStringIfaceValidateUTF8(ival interface{}, tagsize int, _ marshalOptions) int {
@@ -552,9 +808,24 @@
return b, nil
}
+func consumeStringIfaceValidateUTF8(b []byte, _ interface{}, _ wire.Number, wtyp wire.Type, _ unmarshalOptions) (interface{}, int, error) {
+ if wtyp != wire.BytesType {
+ return nil, 0, errUnknown
+ }
+ v, n := wire.ConsumeString(b)
+ if n < 0 {
+ return nil, 0, wire.ParseError(n)
+ }
+ if !utf8.ValidString(v) {
+ return nil, 0, errInvalidUTF8{}
+ }
+ return v, n, nil
+}
+
var coderStringIfaceValidateUTF8 = ifaceCoderFuncs{
- size: sizeStringIfaceValidateUTF8,
- marshal: appendStringIfaceValidateUTF8,
+ size: sizeStringIfaceValidateUTF8,
+ marshal: appendStringIfaceValidateUTF8,
+ unmarshal: consumeStringIfaceValidateUTF8,
}
func asMessage(v reflect.Value) pref.ProtoMessage {