goprotobuf: Support extensions in proto.Equal.
R=r
CC=golang-dev
http://codereview.appspot.com/4950066
diff --git a/proto/equal.go b/proto/equal.go
index 3a12a7f..405d99a 100644
--- a/proto/equal.go
+++ b/proto/equal.go
@@ -37,6 +37,7 @@
import (
"bytes"
"log"
+ "os"
"reflect"
"strings"
)
@@ -59,7 +60,7 @@
- Two unknown field sets are equal if their current
encoded state is equal. (TODO)
- Two extension sets are equal iff they have corresponding
- elements that are pairwise equal. (TODO)
+ elements that are pairwise equal.
- Every other combination of things are not equal.
The return value is undefined if a and b are not protocol buffers.
@@ -101,7 +102,14 @@
}
}
- // TODO: Deal with XXX_unrecognized and XXX_extensions.
+ if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
+ em2 := v2.FieldByName("XXX_extensions")
+ if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
+ return false
+ }
+ }
+
+ // TODO: Deal with XXX_unrecognized.
return true
}
@@ -148,3 +156,76 @@
log.Printf("proto: don't know how to compare %v", v1)
return false
}
+
+// equalExtensions reports whether the extension maps em1 and em2 hold
+// pairwise-equal extensions under the same extension numbers.
+//
+// base is the struct type that the extensions are based on; it selects the
+// ExtensionDesc table in extensionMaps used to decode extensions that are
+// still in encoded form.
+//
+// An extension held in encoded form with byte-identical data on both sides is
+// equal without decoding. Otherwise any encoded side is decoded first: if no
+// descriptor is found the extension is logged and skipped, and a decode error
+// makes the maps unequal.
+//
+// NOTE(review): values are compared with Equal; confirm Equal copes with
+// non-message (scalar or repeated) extension values.
+func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
+	if len(em1) != len(em2) {
+		return false
+	}
+
+	for extNum, e1 := range em1 {
+		e2, ok := em2[extNum]
+		if !ok {
+			return false
+		}
+
+		m1, m2 := e1.value, e2.value
+
+		if m1 == nil && m2 == nil && bytes.Equal(e1.enc, e2.enc) {
+			// Both are encoded with identical bytes. Skip the decode,
+			// which also avoids needing a descriptor for this extension.
+			continue
+		}
+
+		if m1 != nil && m2 != nil {
+			// Both are unencoded.
+			if !Equal(m1, m2) {
+				return false
+			}
+			continue
+		}
+
+		// At least one is encoded. To do a semantically correct comparison
+		// we need to unmarshal them first.
+		var desc *ExtensionDesc
+		if m := extensionMaps[base]; m != nil {
+			desc = m[extNum]
+		}
+		if desc == nil {
+			// Without a descriptor the extension cannot be decoded, so it
+			// is not counted as a difference.
+			log.Printf("proto: don't know how to compare extension %d of %v", extNum, base)
+			continue
+		}
+		var err os.Error
+		if m1 == nil {
+			m1, err = decodeExtension(e1.enc, desc)
+		}
+		if m2 == nil && err == nil {
+			m2, err = decodeExtension(e2.enc, desc)
+		}
+		if err != nil {
+			// The encoded form is invalid.
+			log.Printf("proto: badly encoded extension %d of %v: %v", extNum, base, err)
+			return false
+		}
+		if !Equal(m1, m2) {
+			return false
+		}
+	}
+
+	return true
+}