goprotobuf: Avoid reflect+interface conversion for Marshaler/Unmarshaler check.

R=r
CC=golang-dev, rsc
http://codereview.appspot.com/5824045
diff --git a/proto/decode.go b/proto/decode.go
index 17f95f2..f259b7d 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -354,8 +354,7 @@
 		fieldnum, ok := prop.tags[tag]
 		if !ok {
 			// Maybe it's an extension?
-			o.ptr = base // copy the address here to avoid a heap allocation.
-			iv := reflect.NewAt(t, unsafe.Pointer(&o.ptr)).Elem().Interface()
+			iv := reflect.NewAt(st, unsafe.Pointer(base)).Interface()
 			if e, ok := iv.(extendableProto); ok && isExtensionField(e, int32(tag)) {
 				if err = o.skip(st, tag, wire); err == nil {
 					e.ExtensionMap()[int32(tag)] = Extension{enc: append([]byte(nil), o.buf[oi:o.index]...)}
@@ -654,13 +653,13 @@
 	ptr := (**struct{})(unsafe.Pointer(base + p.offset))
 	typ := p.stype.Elem()
 	bas := reflect.New(typ).Pointer()
-	structv := unsafe.Pointer(bas)
-	*ptr = (*struct{})(structv)
+	structp := unsafe.Pointer(bas)
+	*ptr = (*struct{})(structp)
 
 	// If the object can unmarshal itself, let it.
-	iv := reflect.NewAt(p.stype, unsafe.Pointer(ptr)).Elem().Interface()
-	if u, ok := iv.(Unmarshaler); ok {
-		return u.Unmarshal(raw)
+	if p.isUnmarshaler {
+		iv := reflect.NewAt(typ, structp).Interface() // *typ == p.stype, the type isUnmarshaler checked
+		return iv.(Unmarshaler).Unmarshal(raw)
 	}
 
 	obuf := o.buf
@@ -693,8 +692,8 @@
 
 	typ := p.stype.Elem()
 	bas := reflect.New(typ).Pointer()
-	structv := unsafe.Pointer(bas)
-	y = append(y, (*struct{})(structv))
+	structp := unsafe.Pointer(bas)
+	y = append(y, (*struct{})(structp))
 	*v = y
 
 	if is_group {
@@ -708,9 +707,9 @@
 	}
 
 	// If the object can unmarshal itself, let it.
-	iv := reflect.NewAt(p.stype, unsafe.Pointer(&y[len(y)-1])).Elem().Interface()
-	if u, ok := iv.(Unmarshaler); ok {
-		return u.Unmarshal(raw)
+	if p.isUnmarshaler {
+		iv := reflect.NewAt(typ, structp).Interface() // *typ == p.stype, the type isUnmarshaler checked
+		return iv.(Unmarshaler).Unmarshal(raw)
 	}
 
 	obuf := o.buf
diff --git a/proto/encode.go b/proto/encode.go
index a1a9f1d..a64de1f 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -272,12 +272,16 @@
 
 // Encode a message struct.
 func (o *Buffer) enc_struct_message(p *Properties, base uintptr) error {
+	structp := *(*unsafe.Pointer)(unsafe.Pointer(base + p.offset))
+	if structp == nil { // checked before the Marshaler branch, so both paths report ErrNil
+		return ErrNil
+	}
+
+	typ := p.stype.Elem()
+
 	// Can the object marshal itself?
-	iv := reflect.NewAt(p.stype, unsafe.Pointer(base+p.offset)).Elem().Interface()
-	if m, ok := iv.(Marshaler); ok {
-		if isNil(reflect.ValueOf(iv)) {
-			return ErrNil
-		}
+	if p.isMarshaler {
+		m := reflect.NewAt(typ, structp).Interface().(Marshaler) // *typ == p.stype, the type isMarshaler checked
 		data, err := m.Marshal()
 		if err != nil {
 			return err
@@ -286,19 +290,13 @@
 		o.EncodeRawBytes(data)
 		return nil
 	}
-	v := *(**struct{})(unsafe.Pointer(base + p.offset))
-	if v == nil {
-		return ErrNil
-	}
 
 	// need the length before we can write out the message itself,
 	// so marshal into a separate byte buffer first.
 	obuf := o.buf
 	o.buf = o.bufalloc()
 
-	b := uintptr(unsafe.Pointer(v))
-	typ := p.stype.Elem()
-	err := o.enc_struct(typ, b)
+	err := o.enc_struct(typ, uintptr(structp))
 
 	nbuf := o.buf
 	o.buf = obuf
@@ -473,22 +471,19 @@
 
 // Encode a slice of message structs ([]*struct).
 func (o *Buffer) enc_slice_struct_message(p *Properties, base uintptr) error {
-	s := *(*[]*struct{})(unsafe.Pointer(base + p.offset))
+	s := *(*[]unsafe.Pointer)(unsafe.Pointer(base + p.offset))
 	l := len(s)
 	typ := p.stype.Elem()
 
 	for i := 0; i < l; i++ {
-		v := s[i]
-		if v == nil {
+		structp := s[i]
+		if structp == nil {
 			return ErrRepeatedHasNil
 		}
 
 		// Can the object marshal itself?
-		iv := reflect.NewAt(p.stype, unsafe.Pointer(&s[i])).Elem().Interface()
-		if m, ok := iv.(Marshaler); ok {
-			if isNil(reflect.ValueOf(iv)) {
-				return ErrNil
-			}
+		if p.isMarshaler { // structp is non-nil here: nil elements already returned ErrRepeatedHasNil
+			m := reflect.NewAt(typ, structp).Interface().(Marshaler)
 			data, err := m.Marshal()
 			if err != nil {
 				return err
@@ -501,8 +496,7 @@
 		obuf := o.buf
 		o.buf = o.bufalloc()
 
-		b := uintptr(unsafe.Pointer(v))
-		err := o.enc_struct(typ, b)
+		err := o.enc_struct(typ, uintptr(structp))
 
 		nbuf := o.buf
 		o.buf = obuf
diff --git a/proto/lib.go b/proto/lib.go
index a6949df..86187e0 100644
--- a/proto/lib.go
+++ b/proto/lib.go
@@ -198,7 +198,6 @@
 	index     int        // write point
 	freelist  [10][]byte // list of available buffers
 	nfreelist int        // number of free buffers
-	ptr       uintptr    // used to avoid a heap allocation.
 	// pools of basic types to amortize allocation.
 	bools  []bool
 	int32s []int32
diff --git a/proto/properties.go b/proto/properties.go
index 47bc425..c68936d 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -108,12 +108,14 @@
 	Default    string // default value
 	def_uint64 uint64
 
-	enc     encoder
-	valEnc  valueEncoder // set for bool and numeric types only
-	offset  uintptr
-	tagcode []byte // encoding of EncodeVarint((Tag<<3)|WireType)
-	tagbuf  [8]byte
-	stype   reflect.Type
+	enc           encoder
+	valEnc        valueEncoder // set for bool and numeric types only
+	offset        uintptr
+	tagcode       []byte // encoding of EncodeVarint((Tag<<3)|WireType)
+	tagbuf        [8]byte
+	stype         reflect.Type
+	isMarshaler   bool // stype implements Marshaler
+	isUnmarshaler bool // stype implements Unmarshaler
 
 	dec    decoder
 	valDec valueDecoder // set for bool and numeric types only
@@ -261,6 +263,8 @@
 			p.dec = (*Buffer).dec_string
 		case reflect.Struct:
 			p.stype = t1
+			p.isMarshaler = isMarshaler(t1)
+			p.isUnmarshaler = isUnmarshaler(t1)
 			if p.Wire == "bytes" {
 				p.enc = (*Buffer).enc_struct_message
 				p.dec = (*Buffer).dec_struct_message
@@ -344,6 +348,8 @@
 				break
 			case reflect.Struct:
 				p.stype = t2
+				p.isMarshaler = isMarshaler(t2)
+				p.isUnmarshaler = isUnmarshaler(t2)
 				p.enc = (*Buffer).enc_slice_struct_group
 				p.dec = (*Buffer).dec_slice_struct_group
 				if p.Wire == "bytes" {
@@ -378,6 +384,33 @@
 	p.tagcode = p.tagbuf[0 : i+1]
 }
 
+var (
+	marshalerType   = reflect.TypeOf((*Marshaler)(nil)).Elem()
+	unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
+)
+
+// isMarshaler reports whether t, which must be a pointer type, implements Marshaler.
+func isMarshaler(t reflect.Type) bool {
+	// We're checking for (likely) pointer-receiver methods,
+	// so if t is not a pointer, something is very wrong.
+	// The calls above only invoke isMarshaler on pointer types.
+	if t.Kind() != reflect.Ptr {
+		panic("proto: misuse of isMarshaler")
+	}
+	return t.Implements(marshalerType)
+}
+
+// isUnmarshaler reports whether t, which must be a pointer type, implements Unmarshaler.
+func isUnmarshaler(t reflect.Type) bool {
+	// We're checking for (likely) pointer-receiver methods,
+	// so if t is not a pointer, something is very wrong.
+	// The calls above only invoke isUnmarshaler on pointer types.
+	if t.Kind() != reflect.Ptr {
+		panic("proto: misuse of isUnmarshaler")
+	}
+	return t.Implements(unmarshalerType)
+}
+
 // Init populates the properties from a protocol buffer struct tag.
 func (p *Properties) Init(typ reflect.Type, name, tag string, offset uintptr) {
 	// "bytes,49,opt,def=hello!"