goprotobuf: Support extensions in proto.Equal.

R=r
CC=golang-dev
http://codereview.appspot.com/4950066
diff --git a/proto/equal.go b/proto/equal.go
index 3a12a7f..405d99a 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -37,6 +37,7 @@
 import (
 	"bytes"
 	"log"
+	"os"
 	"reflect"
 	"strings"
 )
@@ -59,7 +60,7 @@
   - Two unknown field sets are equal if their current
     encoded state is equal. (TODO)
   - Two extension sets are equal iff they have corresponding
-    elements that are pairwise equal. (TODO)
+    elements that are pairwise equal.
   - Every other combination of things are not equal.
 
 The return value is undefined if a and b are not protocol buffers.
@@ -101,7 +102,14 @@
 		}
 	}
 
-	// TODO: Deal with XXX_unrecognized and XXX_extensions.
+	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_extensions")
+		if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+			return false
+		}
+	}
+
+	// TODO: Deal with XXX_unrecognized.
 
 	return true
 }
@@ -148,3 +156,82 @@
 	log.Printf("proto: don't know how to compare %v", v1)
 	return false
 }
+
+// equalExtensions reports whether two extension maps hold pairwise-equal
+// extensions.
+//
+// base is the struct type that the extensions are based on; it is used to
+// find extension descriptors in extensionMaps when an extension is only held
+// in encoded form. em1 and em2 are the extension maps to compare.
+//
+// Extensions whose encoded bytes are identical are equal without decoding.
+// Otherwise an encoded extension is decoded with its descriptor before being
+// compared; if no descriptor can be found, or decoding fails, the maps are
+// reported as not equal.
+func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+	if len(em1) != len(em2) {
+		return false
+	}
+
+	for extNum, e1 := range em1 {
+		e2, ok := em2[extNum]
+		if !ok {
+			return false
+		}
+
+		m1, m2 := e1.value, e2.value
+
+		if m1 != nil && m2 != nil {
+			// Both are unencoded.
+			// NOTE(review): Equal is documented as undefined for values that
+			// are not protocol buffers; confirm scalar extensions compare right.
+			if !Equal(m1, m2) {
+				return false
+			}
+			continue
+		}
+
+		if m1 == nil && m2 == nil && bytes.Equal(e1.enc, e2.enc) {
+			// Both are encoded with identical bytes, so they hold the same
+			// value; no descriptor or decoding is needed.
+			continue
+		}
+
+		// At least one is encoded. To do a semantically correct comparison
+		// we need to unmarshal them first, which requires a descriptor.
+		// Prefer one already attached to either extension over the registry.
+		desc := e1.desc
+		if desc == nil {
+			desc = e2.desc
+		}
+		if desc == nil {
+			if m := extensionMaps[base]; m != nil {
+				desc = m[extNum]
+			}
+		}
+		if desc == nil {
+			// Without a descriptor the values cannot be compared. Treating
+			// them as equal would hide real differences, so report inequality,
+			// matching the fallback for other values Equal cannot compare.
+			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
+			return false
+		}
+		var err os.Error
+		if m1 == nil {
+			m1, err = decodeExtension(e1.enc, desc)
+		}
+		if m2 == nil && err == nil {
+			m2, err = decodeExtension(e2.enc, desc)
+		}
+		if err != nil {
+			// The encoded form is invalid.
+			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
+			return false
+		}
+		if !Equal(m1, m2) {
+			return false
+		}
+	}
+
+	return true
+}
diff --git a/proto/equal_test.go b/proto/equal_test.go
index e22c13b..a2d7670 100644
--- a/proto/equal_test.go
+++ b/proto/equal_test.go
@@ -32,12 +32,48 @@
 package proto_test
 
 import (
+	"log"
 	"testing"
 
 	. "goprotobuf.googlecode.com/hg/proto"
 	pb "./testdata/_obj/test_proto"
 )
 
+// Four base messages with identical regular fields.
+// init attaches the pb.E_Ext_More extension to all but messageWithoutExtension.
+var messageWithoutExtension = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension1a = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension1b = &pb.MyMessage{Count: Int32(7)}
+var messageWithExtension2 = &pb.MyMessage{Count: Int32(7)}
+
+func init() {
+	ext1 := &pb.Ext{Data: String("Kirk")}
+	ext2 := &pb.Ext{Data: String("Picard")}
+
+	// messageWithExtension1a holds ext1 exactly as set; it is never marshaled.
+	if err := SetExtension(messageWithExtension1a, pb.E_Ext_More, ext1); err != nil {
+		log.Panicf("SetExtension on 1a failed: %v", err)
+	}
+
+	// messageWithExtension1b gets ext1 through a Marshal/Unmarshal round trip.
+	if err := SetExtension(messageWithExtension1b, pb.E_Ext_More, ext1); err != nil {
+		log.Panicf("SetExtension on 1b failed: %v", err)
+	}
+	buf, err := Marshal(messageWithExtension1b)
+	if err != nil {
+		log.Panicf("Marshal of 1b failed: %v", err)
+	}
+	messageWithExtension1b.Reset()
+	if err := Unmarshal(buf, messageWithExtension1b); err != nil {
+		log.Panicf("Unmarshal of 1b failed: %v", err)
+	}
+
+	// messageWithExtension2 holds ext2, whose Data differs from ext1's.
+	if err := SetExtension(messageWithExtension2, pb.E_Ext_More, ext2); err != nil {
+		log.Panicf("SetExtension on 2 failed: %v", err)
+	}
+}
+
 var EqualTests = []struct {
 	desc string
 	a, b interface{}
@@ -70,6 +106,10 @@
 		&pb.GoTest{RequiredField: &pb.GoTestField{Label: String("wow")}},
 		true,
 	},
+
+	{"extension vs. no extension", messageWithoutExtension, messageWithExtension1a, false},
+	{"extension vs. same extension", messageWithExtension1a, messageWithExtension1b, true},
+	{"extension vs. different extension", messageWithExtension1a, messageWithExtension2, false},
 }
 
 func TestEqual(t *testing.T) {
diff --git a/proto/extensions.go b/proto/extensions.go
index e929c8b..d329b90 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -176,9 +176,24 @@
 		return e.value, nil
 	}
 
+	v, err := decodeExtension(e.enc, extension)
+	if err != nil {
+		return nil, err
+	}
+
+	// Remember the decoded version and drop the encoded version.
+	// That way it is safe to mutate what we return.
+	e.value = v
+	e.desc = extension
+	e.enc = nil
+	return e.value, nil
+}
+
+// decodeExtension decodes an extension encoded in b.
+func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, os.Error) {
 	// Discard wire type and field number varint. It isn't needed.
-	_, n := DecodeVarint(e.enc)
-	o := NewBuffer(e.enc[n:])
+	_, n := DecodeVarint(b)
+	o := NewBuffer(b[n:])
 
 	t := reflect.TypeOf(extension.ExtensionType)
 	props := &Properties{}
@@ -195,12 +210,7 @@
 	if err := props.dec(o, props, uintptr(base), sbase); err != nil {
 		return nil, err
 	}
-	// Remember the decoded version and drop the encoded version.
-	// That way it is safe to mutate what we return.
-	e.value = unsafe.Unreflect(t, base)
-	e.desc = extension
-	e.enc = nil
-	return e.value, nil
+	return unsafe.Unreflect(t, base), nil
 }
 
 // GetExtensions returns a slice of the extensions present in pb that are also listed in es.