goprotobuf: Change Size implementation to use the same code structure as Marshal (encode).
This is much faster (2x-4x), and makes zero allocations.
R=r
CC=golang-dev
https://codereview.appspot.com/14430057
diff --git a/proto/all_test.go b/proto/all_test.go
index 9ae811e..50a537b 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1747,78 +1747,122 @@
Unmarshal(data, pb)
}
-func benchmarkMsg(bytes bool) *GoTest {
+// Benchmarks
+
+func testMsg() *GoTest {
pb := initGoTest(true)
- if bytes {
- buf := make([]byte, 4000)
- for i := range buf {
- buf[i] = byte(i)
- }
- pb.F_BytesDefaulted = buf
- } else {
- const N = 1000 // Internally the library starts much smaller.
- pb.F_Int32Repeated = make([]int32, N)
- pb.F_DoubleRepeated = make([]float64, N)
- for i := 0; i < N; i++ {
- pb.F_Int32Repeated[i] = int32(i)
- pb.F_DoubleRepeated[i] = float64(i)
- }
+ const N = 1000 // Internally the library starts much smaller.
+ pb.F_Int32Repeated = make([]int32, N)
+ pb.F_DoubleRepeated = make([]float64, N)
+ for i := 0; i < N; i++ {
+ pb.F_Int32Repeated[i] = int32(i)
+ pb.F_DoubleRepeated[i] = float64(i)
}
return pb
}
-func BenchmarkMarshal(b *testing.B) {
- pb := benchmarkMsg(false)
+func bytesMsg() *GoTest {
+ pb := initGoTest(true)
+ buf := make([]byte, 4000)
+ for i := range buf {
+ buf[i] = byte(i)
+ }
+ pb.F_BytesDefaulted = buf
+ return pb
+}
+
+func benchmarkMarshal(b *testing.B, pb Message, marshal func(Message) ([]byte, error)) {
+ d, _ := marshal(pb)
+ b.SetBytes(int64(len(d)))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ marshal(pb)
+ }
+}
+
+func benchmarkBufferMarshal(b *testing.B, pb Message) {
p := NewBuffer(nil)
+ benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) {
+ p.Reset()
+ err := p.Marshal(pb0)
+ return p.Bytes(), err
+ })
+}
+
+func benchmarkSize(b *testing.B, pb Message) {
+ benchmarkMarshal(b, pb, func(pb0 Message) ([]byte, error) {
+ Size(pb)
+ return nil, nil
+ })
+}
+
+func newOf(pb Message) Message {
+ in := reflect.ValueOf(pb)
+ if in.IsNil() {
+ return pb
+ }
+ return reflect.New(in.Type().Elem()).Interface().(Message)
+}
+
+func benchmarkUnmarshal(b *testing.B, pb Message, unmarshal func([]byte, Message) error) {
+ d, _ := Marshal(pb)
+ b.SetBytes(int64(len(d)))
+ pbd := newOf(pb)
b.ResetTimer()
for i := 0; i < b.N; i++ {
- p.Reset()
- p.Marshal(pb)
+ unmarshal(d, pbd)
}
- b.SetBytes(int64(len(p.Bytes())))
+}
+
+func benchmarkBufferUnmarshal(b *testing.B, pb Message) {
+ p := NewBuffer(nil)
+ benchmarkUnmarshal(b, pb, func(d []byte, pb0 Message) error {
+ p.SetBuf(d)
+ return p.Unmarshal(pb0)
+ })
+}
+
+// Benchmark{Marshal,BufferMarshal,Size,Unmarshal,BufferUnmarshal}{,Bytes}
+
+func BenchmarkMarshal(b *testing.B) {
+ benchmarkMarshal(b, testMsg(), Marshal)
+}
+
+func BenchmarkBufferMarshal(b *testing.B) {
+ benchmarkBufferMarshal(b, testMsg())
+}
+
+func BenchmarkSize(b *testing.B) {
+ benchmarkSize(b, testMsg())
}
func BenchmarkUnmarshal(b *testing.B) {
- pb := benchmarkMsg(false)
- p := NewBuffer(nil)
- p.Marshal(pb)
- b.SetBytes(int64(len(p.Bytes())))
- p2 := NewBuffer(nil)
- pbd := new(GoTest)
+ benchmarkUnmarshal(b, testMsg(), Unmarshal)
+}
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- p2.SetBuf(p.Bytes())
- p2.Unmarshal(pbd)
- }
+func BenchmarkBufferUnmarshal(b *testing.B) {
+ benchmarkBufferUnmarshal(b, testMsg())
}
func BenchmarkMarshalBytes(b *testing.B) {
- pb := benchmarkMsg(true)
- p := NewBuffer(nil)
+ benchmarkMarshal(b, bytesMsg(), Marshal)
+}
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- p.Reset()
- p.Marshal(pb)
- }
- b.SetBytes(int64(len(p.Bytes())))
+func BenchmarkBufferMarshalBytes(b *testing.B) {
+ benchmarkBufferMarshal(b, bytesMsg())
+}
+
+func BenchmarkSizeBytes(b *testing.B) {
+ benchmarkSize(b, bytesMsg())
}
func BenchmarkUnmarshalBytes(b *testing.B) {
- pb := benchmarkMsg(true)
- p := NewBuffer(nil)
- p.Marshal(pb)
- b.SetBytes(int64(len(p.Bytes())))
- p2 := NewBuffer(nil)
- pbd := new(GoTest)
+ benchmarkUnmarshal(b, bytesMsg(), Unmarshal)
+}
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- p2.SetBuf(p.Bytes())
- p2.Unmarshal(pbd)
- }
+func BenchmarkBufferUnmarshalBytes(b *testing.B) {
+ benchmarkBufferUnmarshal(b, bytesMsg())
}
func BenchmarkUnmarshalUnrecognizedFields(b *testing.B) {
diff --git a/proto/encode.go b/proto/encode.go
index 2a249c8..d757110 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -105,6 +105,17 @@
return nil
}
+func sizeVarint(x uint64) (n int) {
+ for {
+ n++
+ x >>= 7
+ if x == 0 {
+ break
+ }
+ }
+ return n
+}
+
// EncodeFixed64 writes a 64-bit integer to the Buffer.
// This is the format for the
// fixed64, sfixed64, and double protocol buffer types.
@@ -121,6 +132,10 @@
return nil
}
+func sizeFixed64(x uint64) int {
+ return 8
+}
+
// EncodeFixed32 writes a 32-bit integer to the Buffer.
// This is the format for the
// fixed32, sfixed32, and float protocol buffer types.
@@ -133,6 +148,10 @@
return nil
}
+func sizeFixed32(x uint64) int {
+ return 4
+}
+
// EncodeZigzag64 writes a zigzag-encoded 64-bit integer
// to the Buffer.
// This is the format used for the sint64 protocol buffer type.
@@ -141,6 +160,10 @@
return p.EncodeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
+func sizeZigzag64(x uint64) int {
+ return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63))))
+}
+
// EncodeZigzag32 writes a zigzag-encoded 32-bit integer
// to the Buffer.
// This is the format used for the sint32 protocol buffer type.
@@ -149,6 +172,10 @@
return p.EncodeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31))))
}
+func sizeZigzag32(x uint64) int {
+ return sizeVarint(uint64((uint32(x) << 1) ^ uint32((int32(x) >> 31))))
+}
+
// EncodeRawBytes writes a count-delimited byte buffer to the Buffer.
// This is the format used for the bytes protocol buffer
// type and for embedded messages.
@@ -158,6 +185,11 @@
return nil
}
+func sizeRawBytes(b []byte) int {
+ return sizeVarint(uint64(len(b))) +
+ len(b)
+}
+
// EncodeStringBytes writes an encoded string to the Buffer.
// This is the format used for the proto2 string type.
func (p *Buffer) EncodeStringBytes(s string) error {
@@ -166,6 +198,11 @@
return nil
}
+func sizeStringBytes(s string) int {
+ return sizeVarint(uint64(len(s))) +
+ len(s)
+}
+
// Marshaler is the interface representing objects that can marshal themselves.
type Marshaler interface {
Marshal() ([]byte, error)
@@ -216,6 +253,30 @@
return err
}
+// Size returns the encoded size of a protocol buffer.
+func Size(pb Message) (n int) {
+ // Can the object marshal itself? If so, Size is slow.
+ // TODO: add Size to Marshaler, or add a Sizer interface.
+ if m, ok := pb.(Marshaler); ok {
+ b, _ := m.Marshal()
+ return len(b)
+ }
+
+ t, base, err := getbase(pb)
+ if structPointer_IsNil(base) {
+ return 0
+ }
+ if err == nil {
+ n = size_struct(t.Elem(), GetProperties(t.Elem()), base)
+ }
+
+ if collectStats {
+ stats.Size++
+ }
+
+ return
+}
+
// Individual type encoders.
// Encode a bool.
@@ -233,6 +294,14 @@
return nil
}
+func size_bool(p *Properties, base structPointer) int {
+ v := *structPointer_Bool(base, p.field)
+ if v == nil {
+ return 0
+ }
+ return len(p.tagcode) + 1 // each bool takes exactly one byte
+}
+
// Encode an int32.
func (o *Buffer) enc_int32(p *Properties, base structPointer) error {
v := structPointer_Word32(base, p.field)
@@ -245,6 +314,17 @@
return nil
}
+func size_int32(p *Properties, base structPointer) (n int) {
+ v := structPointer_Word32(base, p.field)
+ if word32_IsNil(v) {
+ return 0
+ }
+ x := word32_Get(v)
+ n += len(p.tagcode)
+ n += p.valSize(uint64(x))
+ return
+}
+
// Encode an int64.
func (o *Buffer) enc_int64(p *Properties, base structPointer) error {
v := structPointer_Word64(base, p.field)
@@ -257,6 +337,17 @@
return nil
}
+func size_int64(p *Properties, base structPointer) (n int) {
+ v := structPointer_Word64(base, p.field)
+ if word64_IsNil(v) {
+ return 0
+ }
+ x := word64_Get(v)
+ n += len(p.tagcode)
+ n += p.valSize(x)
+ return
+}
+
// Encode a string.
func (o *Buffer) enc_string(p *Properties, base structPointer) error {
v := *structPointer_String(base, p.field)
@@ -269,6 +360,17 @@
return nil
}
+func size_string(p *Properties, base structPointer) (n int) {
+ v := *structPointer_String(base, p.field)
+ if v == nil {
+ return 0
+ }
+ x := *v
+ n += len(p.tagcode)
+ n += sizeStringBytes(x)
+ return
+}
+
// All protocol buffer fields are nillable, but be careful.
func isNil(v reflect.Value) bool {
switch v.Kind() {
@@ -317,6 +419,27 @@
return state.err
}
+func size_struct_message(p *Properties, base structPointer) int {
+ structp := structPointer_GetStructPointer(base, p.field)
+ if structPointer_IsNil(structp) {
+ return 0
+ }
+
+ // Can the object marshal itself?
+ if p.isMarshaler {
+ m := structPointer_Interface(structp, p.stype).(Marshaler)
+ data, _ := m.Marshal()
+ n0 := len(p.tagcode)
+ n1 := sizeRawBytes(data)
+ return n0 + n1
+ }
+
+ n0 := len(p.tagcode)
+ n1 := size_struct(p.stype, p.sprop, structp)
+ n2 := sizeVarint(uint64(n1)) // size of encoded length
+ return n0 + n1 + n2
+}
+
// Encode a group struct.
func (o *Buffer) enc_struct_group(p *Properties, base structPointer) error {
var state errorState
@@ -334,6 +457,18 @@
return state.err
}
+func size_struct_group(p *Properties, base structPointer) (n int) {
+ b := structPointer_GetStructPointer(base, p.field)
+ if structPointer_IsNil(b) {
+ return 0
+ }
+
+ n += sizeVarint(uint64((p.Tag << 3) | WireStartGroup))
+ n += size_struct(p.stype, p.sprop, b)
+ n += sizeVarint(uint64((p.Tag << 3) | WireEndGroup))
+ return
+}
+
// Encode a slice of bools ([]bool).
func (o *Buffer) enc_slice_bool(p *Properties, base structPointer) error {
s := *structPointer_BoolSlice(base, p.field)
@@ -352,6 +487,15 @@
return nil
}
+func size_slice_bool(p *Properties, base structPointer) int {
+ s := *structPointer_BoolSlice(base, p.field)
+ l := len(s)
+ if l == 0 {
+ return 0
+ }
+ return l * (len(p.tagcode) + 1) // each bool takes exactly one byte
+}
+
// Encode a slice of bools ([]bool) in packed format.
func (o *Buffer) enc_slice_packed_bool(p *Properties, base structPointer) error {
s := *structPointer_BoolSlice(base, p.field)
@@ -371,6 +515,18 @@
return nil
}
+func size_slice_packed_bool(p *Properties, base structPointer) (n int) {
+ s := *structPointer_BoolSlice(base, p.field)
+ l := len(s)
+ if l == 0 {
+ return 0
+ }
+ n += len(p.tagcode)
+ n += sizeVarint(uint64(l))
+ n += l // each bool takes exactly one byte
+ return
+}
+
// Encode a slice of bytes ([]byte).
func (o *Buffer) enc_slice_byte(p *Properties, base structPointer) error {
s := *structPointer_Bytes(base, p.field)
@@ -382,6 +538,16 @@
return nil
}
+func size_slice_byte(p *Properties, base structPointer) (n int) {
+ s := *structPointer_Bytes(base, p.field)
+ if s == nil {
+ return 0
+ }
+ n += len(p.tagcode)
+ n += sizeRawBytes(s)
+ return
+}
+
// Encode a slice of int32s ([]int32).
func (o *Buffer) enc_slice_int32(p *Properties, base structPointer) error {
s := structPointer_Word32Slice(base, p.field)
@@ -397,6 +563,20 @@
return nil
}
+func size_slice_int32(p *Properties, base structPointer) (n int) {
+ s := structPointer_Word32Slice(base, p.field)
+ l := s.Len()
+ if l == 0 {
+ return 0
+ }
+ for i := 0; i < l; i++ {
+ n += len(p.tagcode)
+ x := s.Index(i)
+ n += p.valSize(uint64(x))
+ }
+ return
+}
+
// Encode a slice of int32s ([]int32) in packed format.
func (o *Buffer) enc_slice_packed_int32(p *Properties, base structPointer) error {
s := structPointer_Word32Slice(base, p.field)
@@ -416,6 +596,23 @@
return nil
}
+func size_slice_packed_int32(p *Properties, base structPointer) (n int) {
+ s := structPointer_Word32Slice(base, p.field)
+ l := s.Len()
+ if l == 0 {
+ return 0
+ }
+ var bufSize int
+ for i := 0; i < l; i++ {
+ bufSize += p.valSize(uint64(s.Index(i)))
+ }
+
+ n += len(p.tagcode)
+ n += sizeVarint(uint64(bufSize))
+ n += bufSize
+ return
+}
+
// Encode a slice of int64s ([]int64).
func (o *Buffer) enc_slice_int64(p *Properties, base structPointer) error {
s := structPointer_Word64Slice(base, p.field)
@@ -430,6 +627,19 @@
return nil
}
+func size_slice_int64(p *Properties, base structPointer) (n int) {
+ s := structPointer_Word64Slice(base, p.field)
+ l := s.Len()
+ if l == 0 {
+ return 0
+ }
+ for i := 0; i < l; i++ {
+ n += len(p.tagcode)
+ n += p.valSize(s.Index(i))
+ }
+ return
+}
+
// Encode a slice of int64s ([]int64) in packed format.
func (o *Buffer) enc_slice_packed_int64(p *Properties, base structPointer) error {
s := structPointer_Word64Slice(base, p.field)
@@ -449,6 +659,23 @@
return nil
}
+func size_slice_packed_int64(p *Properties, base structPointer) (n int) {
+ s := structPointer_Word64Slice(base, p.field)
+ l := s.Len()
+ if l == 0 {
+ return 0
+ }
+ var bufSize int
+ for i := 0; i < l; i++ {
+ bufSize += p.valSize(s.Index(i))
+ }
+
+ n += len(p.tagcode)
+ n += sizeVarint(uint64(bufSize))
+ n += bufSize
+ return
+}
+
// Encode a slice of slice of bytes ([][]byte).
func (o *Buffer) enc_slice_slice_byte(p *Properties, base structPointer) error {
ss := *structPointer_BytesSlice(base, p.field)
@@ -458,24 +685,45 @@
}
for i := 0; i < l; i++ {
o.buf = append(o.buf, p.tagcode...)
- s := ss[i]
- o.EncodeRawBytes(s)
+ o.EncodeRawBytes(ss[i])
}
return nil
}
+func size_slice_slice_byte(p *Properties, base structPointer) (n int) {
+ ss := *structPointer_BytesSlice(base, p.field)
+ l := len(ss)
+ if l == 0 {
+ return 0
+ }
+ n += l * len(p.tagcode)
+ for i := 0; i < l; i++ {
+ n += sizeRawBytes(ss[i])
+ }
+ return
+}
+
// Encode a slice of strings ([]string).
func (o *Buffer) enc_slice_string(p *Properties, base structPointer) error {
ss := *structPointer_StringSlice(base, p.field)
l := len(ss)
for i := 0; i < l; i++ {
o.buf = append(o.buf, p.tagcode...)
- s := ss[i]
- o.EncodeStringBytes(s)
+ o.EncodeStringBytes(ss[i])
}
return nil
}
+func size_slice_string(p *Properties, base structPointer) (n int) {
+ ss := *structPointer_StringSlice(base, p.field)
+ l := len(ss)
+ n += l * len(p.tagcode)
+ for i := 0; i < l; i++ {
+ n += sizeStringBytes(ss[i])
+ }
+ return
+}
+
// Encode a slice of message structs ([]*struct).
func (o *Buffer) enc_slice_struct_message(p *Properties, base structPointer) error {
var state errorState
@@ -522,6 +770,32 @@
return state.err
}
+func size_slice_struct_message(p *Properties, base structPointer) (n int) {
+ s := structPointer_StructPointerSlice(base, p.field)
+ l := s.Len()
+ n += l * len(p.tagcode)
+ for i := 0; i < l; i++ {
+ structp := s.Index(i)
+ if structPointer_IsNil(structp) {
+ return // return the size up to this point
+ }
+
+ // Can the object marshal itself?
+ if p.isMarshaler {
+ m := structPointer_Interface(structp, p.stype).(Marshaler)
+ data, _ := m.Marshal()
+ n += len(p.tagcode)
+ n += sizeRawBytes(data)
+ continue
+ }
+
+ n0 := size_struct(p.stype, p.sprop, structp)
+ n1 := sizeVarint(uint64(n0)) // size of encoded length
+ n += n0 + n1
+ }
+ return
+}
+
// Encode a slice of group structs ([]*struct).
func (o *Buffer) enc_slice_struct_group(p *Properties, base structPointer) error {
var state errorState
@@ -550,6 +824,23 @@
return state.err
}
+func size_slice_struct_group(p *Properties, base structPointer) (n int) {
+ s := structPointer_StructPointerSlice(base, p.field)
+ l := s.Len()
+
+ n += l * sizeVarint(uint64((p.Tag<<3)|WireStartGroup))
+ n += l * sizeVarint(uint64((p.Tag<<3)|WireEndGroup))
+ for i := 0; i < l; i++ {
+ b := s.Index(i)
+ if structPointer_IsNil(b) {
+ return // return size up to this point
+ }
+
+ n += size_struct(p.stype, p.sprop, b)
+ }
+ return
+}
+
// Encode an extension map.
func (o *Buffer) enc_map(p *Properties, base structPointer) error {
v := *structPointer_ExtMap(base, p.field)
@@ -577,6 +868,11 @@
return nil
}
+func size_map(p *Properties, base structPointer) int {
+ v := *structPointer_ExtMap(base, p.field)
+ return sizeExtensionMap(v)
+}
+
// Encode a struct.
func (o *Buffer) enc_struct(t reflect.Type, prop *StructProperties, base structPointer) error {
var state errorState
@@ -610,6 +906,23 @@
return state.err
}
+func size_struct(t reflect.Type, prop *StructProperties, base structPointer) (n int) {
+ for _, i := range prop.order {
+ p := prop.Prop[i]
+ if p.size != nil {
+ n += p.size(p, base)
+ }
+ }
+
+ // Add unrecognized fields at the end.
+ if prop.unrecField.IsValid() {
+ v := *structPointer_Bytes(base, prop.unrecField)
+ n += len(v)
+ }
+
+ return
+}
+
// errorState maintains the first error that occurs and updates that error
// with additional context.
type errorState struct {
diff --git a/proto/extensions.go b/proto/extensions.go
index e730b68..50d72aa 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -183,6 +183,30 @@
return nil
}
+func sizeExtensionMap(m map[int32]Extension) (n int) {
+ for _, e := range m {
+ if e.value == nil || e.desc == nil {
+ // Extension is only in its encoded form.
+ n += len(e.enc)
+ continue
+ }
+
+ // We don't skip extensions that have an encoded form set,
+ // because the extension value may have been mutated after
+ // the last time this function was called.
+
+ et := reflect.TypeOf(e.desc.ExtensionType)
+ props := extensionProperties(e.desc)
+
+ // If e.value has type T, the encoder expects a *struct{ X T }.
+ // Pass a *T with a zero field and hope it all works out.
+ x := reflect.New(et)
+ x.Elem().Set(reflect.ValueOf(e.value))
+ n += props.size(props, toStructPointer(x))
+ }
+ return
+}
+
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
diff --git a/proto/lib.go b/proto/lib.go
index fa6fe22..5d5e345 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -223,6 +223,7 @@
Decode uint64 // number of decodes
Chit uint64 // number of cache hits
Cmiss uint64 // number of cache misses
+ Size uint64 // number of sizes
}
// Set to true to enable stats collection.
diff --git a/proto/properties.go b/proto/properties.go
index 75b3e8d..7177cfc 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -59,7 +59,7 @@
const startSize = 10 // initial slice/string sizes
-// Encoders are defined in encoder.go
+// Encoders are defined in encode.go
// An encoder outputs the full representation of a field, including its
// tag and encoder type.
type encoder func(p *Buffer, prop *Properties, base structPointer) error
@@ -67,6 +67,15 @@
// A valueEncoder encodes a single integer in a particular encoding.
type valueEncoder func(o *Buffer, x uint64) error
+// Sizers are defined in encode.go
+// A sizer returns the encoded size of a field, including its tag and encoder
+// type.
+type sizer func(prop *Properties, base structPointer) int
+
+// A valueSizer returns the encoded size of a single integer in a particular
+// encoding.
+type valueSizer func(x uint64) int
+
// Decoders are defined in decode.go
// A decoder creates a value from its wire representation.
// Unrecognized subelements are saved in unrec.
@@ -126,7 +135,7 @@
}
// Implement the sorting interface so we can sort the fields in tag order, as recommended by the spec.
-// See encoder.go, (*Buffer).enc_struct.
+// See encode.go, (*Buffer).enc_struct.
func (sp *StructProperties) Len() int { return len(sp.order) }
func (sp *StructProperties) Less(i, j int) bool {
@@ -159,6 +168,9 @@
isMarshaler bool
isUnmarshaler bool
+ size sizer
+ valSize valueSizer // set for bool and numeric types only
+
dec decoder
valDec valueDecoder // set for bool and numeric types only
@@ -210,22 +222,27 @@
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeVarint
p.valDec = (*Buffer).DecodeVarint
+ p.valSize = sizeVarint
case "fixed32":
p.WireType = WireFixed32
p.valEnc = (*Buffer).EncodeFixed32
p.valDec = (*Buffer).DecodeFixed32
+ p.valSize = sizeFixed32
case "fixed64":
p.WireType = WireFixed64
p.valEnc = (*Buffer).EncodeFixed64
p.valDec = (*Buffer).DecodeFixed64
+ p.valSize = sizeFixed64
case "zigzag32":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag32
p.valDec = (*Buffer).DecodeZigzag32
+ p.valSize = sizeZigzag32
case "zigzag64":
p.WireType = WireVarint
p.valEnc = (*Buffer).EncodeZigzag64
p.valDec = (*Buffer).DecodeZigzag64
+ p.valSize = sizeZigzag64
case "bytes", "group":
p.WireType = WireBytes
// no numeric converter for non-numeric types
@@ -276,6 +293,7 @@
func (p *Properties) setEncAndDec(typ reflect.Type, lockGetProp bool) {
p.enc = nil
p.dec = nil
+ p.size = nil
switch t1 := typ; t1.Kind() {
default:
@@ -289,21 +307,27 @@
case reflect.Bool:
p.enc = (*Buffer).enc_bool
p.dec = (*Buffer).dec_bool
+ p.size = size_bool
case reflect.Int32, reflect.Uint32:
p.enc = (*Buffer).enc_int32
p.dec = (*Buffer).dec_int32
+ p.size = size_int32
case reflect.Int64, reflect.Uint64:
p.enc = (*Buffer).enc_int64
p.dec = (*Buffer).dec_int64
+ p.size = size_int64
case reflect.Float32:
p.enc = (*Buffer).enc_int32 // can just treat them as bits
p.dec = (*Buffer).dec_int32
+ p.size = size_int32
case reflect.Float64:
p.enc = (*Buffer).enc_int64 // can just treat them as bits
p.dec = (*Buffer).dec_int64
+ p.size = size_int64
case reflect.String:
p.enc = (*Buffer).enc_string
p.dec = (*Buffer).dec_string
+ p.size = size_string
case reflect.Struct:
p.stype = t1.Elem()
p.isMarshaler = isMarshaler(t1)
@@ -311,9 +335,11 @@
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_struct_message
p.dec = (*Buffer).dec_struct_message
+ p.size = size_struct_message
} else {
p.enc = (*Buffer).enc_struct_group
p.dec = (*Buffer).dec_struct_group
+ p.size = size_struct_group
}
}
@@ -325,8 +351,10 @@
case reflect.Bool:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_bool
+ p.size = size_slice_packed_bool
} else {
p.enc = (*Buffer).enc_slice_bool
+ p.size = size_slice_bool
}
p.dec = (*Buffer).dec_slice_bool
p.packedDec = (*Buffer).dec_slice_packed_bool
@@ -335,16 +363,20 @@
case 32:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int32
+ p.size = size_slice_packed_int32
} else {
p.enc = (*Buffer).enc_slice_int32
+ p.size = size_slice_int32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
case 64:
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
+ p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
+ p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
@@ -352,6 +384,7 @@
if t2.Kind() == reflect.Uint8 {
p.enc = (*Buffer).enc_slice_byte
p.dec = (*Buffer).dec_slice_byte
+ p.size = size_slice_byte
}
default:
logNoSliceEnc(t1, t2)
@@ -363,8 +396,10 @@
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int32
+ p.size = size_slice_packed_int32
} else {
p.enc = (*Buffer).enc_slice_int32
+ p.size = size_slice_int32
}
p.dec = (*Buffer).dec_slice_int32
p.packedDec = (*Buffer).dec_slice_packed_int32
@@ -372,8 +407,10 @@
// can just treat them as bits
if p.Packed {
p.enc = (*Buffer).enc_slice_packed_int64
+ p.size = size_slice_packed_int64
} else {
p.enc = (*Buffer).enc_slice_int64
+ p.size = size_slice_int64
}
p.dec = (*Buffer).dec_slice_int64
p.packedDec = (*Buffer).dec_slice_packed_int64
@@ -384,6 +421,7 @@
case reflect.String:
p.enc = (*Buffer).enc_slice_string
p.dec = (*Buffer).dec_slice_string
+ p.size = size_slice_string
case reflect.Ptr:
switch t3 := t2.Elem(); t3.Kind() {
default:
@@ -393,11 +431,14 @@
p.stype = t2.Elem()
p.isMarshaler = isMarshaler(t2)
p.isUnmarshaler = isUnmarshaler(t2)
- p.enc = (*Buffer).enc_slice_struct_group
- p.dec = (*Buffer).dec_slice_struct_group
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_slice_struct_message
p.dec = (*Buffer).dec_slice_struct_message
+ p.size = size_slice_struct_message
+ } else {
+ p.enc = (*Buffer).enc_slice_struct_group
+ p.dec = (*Buffer).dec_slice_struct_group
+ p.size = size_slice_struct_group
}
}
case reflect.Slice:
@@ -408,6 +449,7 @@
case reflect.Uint8:
p.enc = (*Buffer).enc_slice_slice_byte
p.dec = (*Buffer).dec_slice_slice_byte
+ p.size = size_slice_slice_byte
}
}
}
@@ -525,6 +567,7 @@
if f.Name == "XXX_extensions" { // special case
p.enc = (*Buffer).enc_map
p.dec = nil // not needed
+ p.size = size_map
}
if f.Name == "XXX_unrecognized" { // special case
prop.unrecField = toField(&f)
diff --git a/proto/size.go b/proto/size.go
deleted file mode 100644
index ebdc893..0000000
--- a/proto/size.go
+++ /dev/null
@@ -1,193 +0,0 @@
-// Go support for Protocol Buffers - Google's data interchange format
-//
-// Copyright 2012 The Go Authors. All rights reserved.
-// http://code.google.com/p/goprotobuf/
-//
-// Redistribution and use in source and binary forms, with or without
-// modification, are permitted provided that the following conditions are
-// met:
-//
-// * Redistributions of source code must retain the above copyright
-// notice, this list of conditions and the following disclaimer.
-// * Redistributions in binary form must reproduce the above
-// copyright notice, this list of conditions and the following disclaimer
-// in the documentation and/or other materials provided with the
-// distribution.
-// * Neither the name of Google Inc. nor the names of its
-// contributors may be used to endorse or promote products derived from
-// this software without specific prior written permission.
-//
-// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-// Functions to determine the size of an encoded protocol buffer.
-
-package proto
-
-import (
- "log"
- "reflect"
- "strings"
-)
-
-// Size returns the encoded size of a protocol buffer.
-// This function is expensive enough to be avoided unless proven worthwhile with instrumentation.
-func Size(pb Message) int {
- in := reflect.ValueOf(pb)
- if in.IsNil() {
- return 0
- }
- return sizeStruct(in.Elem())
-}
-
-func sizeStruct(x reflect.Value) (n int) {
- sprop := GetProperties(x.Type())
- for _, prop := range sprop.Prop {
- if strings.HasPrefix(prop.Name, "XXX_") { // handled below
- continue
- }
- fi, _ := sprop.decoderTags.get(prop.Tag)
- f := x.Field(fi)
- switch f.Kind() {
- case reflect.Ptr:
- if f.IsNil() {
- continue
- }
- n += len(prop.tagcode)
- f = f.Elem() // avoid a recursion in sizeField
- case reflect.Slice:
- if f.IsNil() {
- continue
- }
- if f.Len() == 0 && f.Type().Elem().Kind() != reflect.Uint8 {
- // short circuit for empty repeated fields.
- // []byte isn't a repeated field.
- continue
- }
- default:
- log.Printf("proto: unknown struct field type %v", f.Type())
- continue
- }
- n += sizeField(f, prop)
- }
-
- if em, ok := x.Addr().Interface().(extendableProto); ok {
- for _, ext := range em.ExtensionMap() {
- ms := len(ext.enc)
- if ext.enc == nil {
- props := new(Properties)
- props.Init(reflect.TypeOf(ext.desc.ExtensionType), "x", ext.desc.Tag, nil)
- ms = len(props.tagcode) + sizeField(reflect.ValueOf(ext.value), props)
- }
- n += ms
- }
- }
-
- if uf := x.FieldByName("XXX_unrecognized"); uf.IsValid() {
- n += uf.Len()
- }
-
- return n
-}
-
-func sizeField(x reflect.Value, prop *Properties) (n int) {
- if x.Type().Kind() == reflect.Slice {
- n := x.Len()
- et := x.Type().Elem()
- if et.Kind() == reflect.Uint8 {
- // []byte is easy.
- return len(prop.tagcode) + sizeVarint(uint64(n)) + n
- }
-
- var nb int
-
- // []bool and repeated fixed integer types are easy.
- switch {
- case et.Kind() == reflect.Bool:
- nb += n
- case prop.WireType == WireFixed64:
- nb += n * 8
- case prop.WireType == WireFixed32:
- nb += n * 4
- default:
- for i := 0; i < n; i++ {
- nb += sizeField(x.Index(i), prop)
- }
- }
- // Non-packed repeated fields have a per-element header of the tagcode.
- // Packed repeated fields only have a single header: the tag code plus a varint of the number of bytes.
- if !prop.Packed {
- nb += len(prop.tagcode) * n
- } else {
- nb += len(prop.tagcode) + sizeVarint(uint64(nb))
- }
- return nb
- }
-
- // easy scalars
- switch prop.WireType {
- case WireFixed64:
- return 8
- case WireFixed32:
- return 4
- }
-
- switch x.Kind() {
- case reflect.Bool:
- return 1
- case reflect.Int32, reflect.Int64:
- if prop.Wire == "varint" {
- return sizeVarint(uint64(x.Int()))
- } else if prop.Wire == "zigzag32" || prop.Wire == "zigzag64" {
- return sizeZigZag(uint64(x.Int()))
- }
- case reflect.Ptr:
- return sizeField(x.Elem(), prop)
- case reflect.String:
- n := x.Len()
- return sizeVarint(uint64(n)) + n
- case reflect.Struct:
- nb := sizeStruct(x)
- if prop.Wire == "group" {
- // Groups have start and end tags instead of a start tag and a length.
- return nb + len(prop.tagcode)
- }
- return sizeVarint(uint64(nb)) + nb
- case reflect.Uint32, reflect.Uint64:
- if prop.Wire == "varint" {
- return sizeVarint(uint64(x.Uint()))
- } else if prop.Wire == "zigzag32" || prop.Wire == "zigzag64" {
- return sizeZigZag(uint64(x.Int()))
- }
- default:
- log.Printf("proto.sizeField: unhandled kind %v", x.Kind())
- }
-
- // unknown type, so not a protocol buffer
- log.Printf("proto: don't know size of %v", x.Type())
- return 0
-}
-
-func sizeVarint(x uint64) (n int) {
- for {
- n++
- x >>= 7
- if x == 0 {
- break
- }
- }
- return n
-}
-
-func sizeZigZag(x uint64) (n int) {
- return sizeVarint(uint64((x << 1) ^ uint64((int64(x) >> 63))))
-}
diff --git a/proto/text.go b/proto/text.go
index ff8a0bb..f8cb9c9 100644
--- a/proto/text.go
+++ b/proto/text.go
@@ -477,7 +477,7 @@
switch wire {
case WireBytes:
buf, e := b.DecodeRawBytes(false)
- if err == nil {
+ if e == nil {
_, err = fmt.Fprintf(w, "%q", buf)
} else {
_, err = fmt.Fprintf(w, "/* %v */", e)