goprotobuf: Support extensions in proto.Equal.

R=r
CC=golang-dev
http://codereview.appspot.com/4950066
diff --git a/proto/equal.go b/proto/equal.go
index 3a12a7f..405d99a 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -37,6 +37,7 @@
 import (
 	"bytes"
 	"log"
+	"os"
 	"reflect"
 	"strings"
 )
@@ -59,7 +60,7 @@
   - Two unknown field sets are equal if their current
     encoded state is equal. (TODO)
   - Two extension sets are equal iff they have corresponding
-    elements that are pairwise equal. (TODO)
+    elements that are pairwise equal.
   - Every other combination of things are not equal.
 
 The return value is undefined if a and b are not protocol buffers.
@@ -101,7 +102,14 @@
 		}
 	}
 
-	// TODO: Deal with XXX_unrecognized and XXX_extensions.
+	if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+		em2 := v2.FieldByName("XXX_extensions")
+		if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+			return false
+		}
+	}
+
+	// TODO: Deal with XXX_unrecognized.
 
 	return true
 }
@@ -148,3 +156,65 @@
 	log.Printf("proto: don't know how to compare %v", v1)
 	return false
 }
+
+// equalExtensions reports whether two extension maps are equal.
+// base is the struct type that the extensions extend; it is used to look up
+// extension descriptors so encoded extensions can be decoded for comparison.
+// em1 and em2 are the extension maps of the two messages being compared.
+func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+	if len(em1) != len(em2) {
+		return false
+	}
+
+	for extNum, e1 := range em1 {
+		e2, ok := em2[extNum]
+		if !ok {
+			return false
+		}
+
+		m1, m2 := e1.value, e2.value
+
+		if m1 != nil && m2 != nil {
+			// Both are unencoded.
+			if !Equal(m1, m2) {
+				return false
+			}
+			continue
+		}
+
+		// At least one is encoded. To do a semantically correct comparison
+		// we need to unmarshal them first.
+		var desc *ExtensionDesc
+		if m := extensionMaps[base]; m != nil {
+			desc = m[extNum]
+		}
+		if desc == nil {
+			// Without a descriptor the extension cannot be decoded. Identical
+			// encodings are certainly equal; anything else cannot be shown
+			// to be equal, so conservatively report inequality rather than
+			// silently treating the extensions as equal.
+			if m1 == nil && m2 == nil && bytes.Equal(e1.enc, e2.enc) {
+				continue
+			}
+			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
+			return false
+		}
+		var err os.Error
+		if m1 == nil {
+			m1, err = decodeExtension(e1.enc, desc)
+		}
+		if m2 == nil && err == nil {
+			m2, err = decodeExtension(e2.enc, desc)
+		}
+		if err != nil {
+			// The encoded form is invalid.
+			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
+			return false
+		}
+		if !Equal(m1, m2) {
+			return false
+		}
+	}
+
+	return true
+}