goprotobuf: Avoid reflect+interface conversion for Marshaler/Unmarshaler check.
R=r
CC=golang-dev, rsc
http://codereview.appspot.com/5824045
diff --git a/proto/decode.go b/proto/decode.go
index 17f95f2..f259b7d 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -354,8 +354,7 @@
fieldnum, ok := prop.tags[tag]
if !ok {
// Maybe it's an extension?
- o.ptr = base // copy the address here to avoid a heap allocation.
- iv := reflect.NewAt(t, unsafe.Pointer(&o.ptr)).Elem().Interface()
+ iv := reflect.NewAt(st, unsafe.Pointer(base)).Interface()
if e, ok := iv.(extendableProto); ok && isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
e.ExtensionMap()[int32(tag)] = Extension{enc: append([]byte(nil), o.buf[oi:o.index]...)}
@@ -654,13 +653,13 @@
ptr := (**struct{})(unsafe.Pointer(base + p.offset))
typ := p.stype.Elem()
bas := reflect.New(typ).Pointer()
- structv := unsafe.Pointer(bas)
- *ptr = (*struct{})(structv)
+ structp := unsafe.Pointer(bas)
+ *ptr = (*struct{})(structp)
// If the object can unmarshal itself, let it.
- iv := reflect.NewAt(p.stype, unsafe.Pointer(ptr)).Elem().Interface()
- if u, ok := iv.(Unmarshaler); ok {
- return u.Unmarshal(raw)
+	if p.isUnmarshaler { // must be the Unmarshaler flag: the body asserts iv.(Unmarshaler)
+		iv := reflect.NewAt(p.stype.Elem(), structp).Interface()
+		return iv.(Unmarshaler).Unmarshal(raw)
 	}
obuf := o.buf
@@ -693,8 +692,8 @@
typ := p.stype.Elem()
bas := reflect.New(typ).Pointer()
- structv := unsafe.Pointer(bas)
- y = append(y, (*struct{})(structv))
+ structp := unsafe.Pointer(bas)
+ y = append(y, (*struct{})(structp))
*v = y
if is_group {
@@ -708,9 +707,9 @@
}
// If the object can unmarshal itself, let it.
- iv := reflect.NewAt(p.stype, unsafe.Pointer(&y[len(y)-1])).Elem().Interface()
- if u, ok := iv.(Unmarshaler); ok {
- return u.Unmarshal(raw)
+ if p.isUnmarshaler {
+ iv := reflect.NewAt(typ, structp).Interface()
+ return iv.(Unmarshaler).Unmarshal(raw)
}
obuf := o.buf
diff --git a/proto/encode.go b/proto/encode.go
index a1a9f1d..a64de1f 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -272,12 +272,16 @@
// Encode a message struct.
func (o *Buffer) enc_struct_message(p *Properties, base uintptr) error {
+ structp := *(*unsafe.Pointer)(unsafe.Pointer(base + p.offset))
+ if structp == nil {
+ return ErrNil
+ }
+
+ typ := p.stype.Elem()
+
// Can the object marshal itself?
- iv := reflect.NewAt(p.stype, unsafe.Pointer(base+p.offset)).Elem().Interface()
- if m, ok := iv.(Marshaler); ok {
- if isNil(reflect.ValueOf(iv)) {
- return ErrNil
- }
+ if p.isMarshaler {
+ m := reflect.NewAt(typ, structp).Interface().(Marshaler)
data, err := m.Marshal()
if err != nil {
return err
@@ -286,19 +290,13 @@
o.EncodeRawBytes(data)
return nil
}
- v := *(**struct{})(unsafe.Pointer(base + p.offset))
- if v == nil {
- return ErrNil
- }
// need the length before we can write out the message itself,
// so marshal into a separate byte buffer first.
obuf := o.buf
o.buf = o.bufalloc()
- b := uintptr(unsafe.Pointer(v))
- typ := p.stype.Elem()
- err := o.enc_struct(typ, b)
+ err := o.enc_struct(typ, uintptr(structp))
nbuf := o.buf
o.buf = obuf
@@ -473,22 +471,19 @@
// Encode a slice of message structs ([]*struct).
func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
- s := *(*[]*struct{})(unsafe.Pointer(base + p.offset))
+ s := *(*[]unsafe.Pointer)(unsafe.Pointer(base + p.offset))
l := len(s)
typ := p.stype.Elem()
for i := 0; i < l; i++ {
- v := s[i]
- if v == nil {
+ structp := s[i]
+ if structp == nil {
return ErrRepeatedHasNil
}
// Can the object marshal itself?
- iv := reflect.NewAt(p.stype, unsafe.Pointer(&s[i])).Elem().Interface()
- if m, ok := iv.(Marshaler); ok {
- if isNil(reflect.ValueOf(iv)) {
- return ErrNil
- }
+ if p.isMarshaler {
+ m := reflect.NewAt(typ, structp).Interface().(Marshaler)
data, err := m.Marshal()
if err != nil {
return err
@@ -501,8 +496,7 @@
obuf := o.buf
o.buf = o.bufalloc()
- b := uintptr(unsafe.Pointer(v))
- err := o.enc_struct(typ, b)
+ err := o.enc_struct(typ, uintptr(structp))
nbuf := o.buf
o.buf = obuf
diff --git a/proto/lib.go b/proto/lib.go
index a6949df..86187e0 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -198,7 +198,6 @@
index int // write point
freelist [10][]byte // list of available buffers
nfreelist int // number of free buffers
- ptr uintptr // used to avoid a heap allocation.
// pools of basic types to amortize allocation.
bools []bool
int32s []int32
diff --git a/proto/properties.go b/proto/properties.go
index 47bc425..c68936d 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -108,12 +108,14 @@
Default string // default value
def_uint64 uint64
- enc encoder
- valEnc valueEncoder // set for bool and numeric types only
- offset uintptr
- tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
- tagbuf [8]byte
- stype reflect.Type
+ enc encoder
+ valEnc valueEncoder // set for bool and numeric types only
+ offset uintptr
+ tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
+ tagbuf [8]byte
+ stype reflect.Type
+ isMarshaler bool
+ isUnmarshaler bool
dec decoder
valDec valueDecoder // set for bool and numeric types only
@@ -261,6 +263,8 @@
p.dec = (*Buffer).dec_string
case reflect.Struct:
p.stype = t1
+ p.isMarshaler = isMarshaler(t1)
+ p.isUnmarshaler = isUnmarshaler(t1)
if p.Wire == "bytes" {
p.enc = (*Buffer).enc_struct_message
p.dec = (*Buffer).dec_struct_message
@@ -344,6 +348,8 @@
break
case reflect.Struct:
p.stype = t2
+ p.isMarshaler = isMarshaler(t2)
+ p.isUnmarshaler = isUnmarshaler(t2)
p.enc = (*Buffer).enc_slice_struct_group
p.dec = (*Buffer).dec_slice_struct_group
if p.Wire == "bytes" {
@@ -378,6 +384,33 @@
p.tagcode = p.tagbuf[0 : i+1]
}
+// marshalerType is the reflect.Type of the Marshaler interface itself.
+var marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
+// unmarshalerType is the reflect.Type of the Unmarshaler interface itself.
+var unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
+
+// isMarshaler reports whether the pointer type t satisfies Marshaler.
+func isMarshaler(t reflect.Type) bool {
+	// Marshal is (likely) declared on a pointer receiver, so a non-pointer
+	// t cannot be checked meaningfully. Callers are expected to pass only
+	// pointer types; anything else is a programming error, hence the panic.
+	if t.Kind() == reflect.Ptr {
+		return t.Implements(marshalerType)
+	}
+	panic("proto: misuse of isMarshaler")
+}
+
+// isUnmarshaler reports whether the pointer type t satisfies Unmarshaler.
+func isUnmarshaler(t reflect.Type) bool {
+	// Unmarshal is (likely) declared on a pointer receiver, so a non-pointer
+	// t cannot be checked meaningfully. Callers are expected to pass only
+	// pointer types; anything else is a programming error, hence the panic.
+	if t.Kind() == reflect.Ptr {
+		return t.Implements(unmarshalerType)
+	}
+	panic("proto: misuse of isUnmarshaler")
+}
+
// Init populates the properties from a protocol buffer struct tag.
func (p *Properties) Init(typ reflect.Type, name, tag string, offset uintptr) {
// "bytes,49,opt,def=hello!"