proto: add Equal
Add support for basic equality comparison of messages.
Messages are equal if they have the same type and marshal to the
same bytes with deterministic serialization, with some exceptions:
- Messages with different registered extensions are unequal.
- NaN is not equal to itself.
Unlike the v1 Equal, a nil message is equal to an empty message of
the same type.
Change-Id: Ibabdadd8c767b801051b8241aeae1ba077e58121
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/174277
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/encode_test.go b/proto/encode_test.go
index d670edf..6ed5250 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -8,6 +8,7 @@
protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/proto"
+ pref "github.com/golang/protobuf/v2/reflect/protoreflect"
"github.com/google/go-cmp/cmp"
test3pb "github.com/golang/protobuf/v2/internal/testprotos/test3"
@@ -30,7 +31,7 @@
t.Errorf("Size and marshal disagree: Size(m)=%v; len(Marshal(m))=%v\nMessage:\n%v", size, len(wire), marshalText(want))
}
- got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+ got := newMessage(want)
uopts := proto.UnmarshalOptions{
AllowPartial: test.partial,
}
@@ -43,7 +44,7 @@
// Equal doesn't work on messages containing invalid extension data.
return
}
- if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+ if !proto.Equal(got, want) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
}
})
@@ -71,7 +72,7 @@
t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
}
- got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+ got := newMessage(want)
uopts := proto.UnmarshalOptions{
AllowPartial: test.partial,
}
@@ -84,7 +85,7 @@
// Equal doesn't work on messages containing invalid extension data.
return
}
- if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+ if !proto.Equal(got, want) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
@@ -100,12 +101,12 @@
if !isErrInvalidUTF8(err) {
t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(want))
}
- got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+ got := newMessage(want)
if err := proto.Unmarshal(wire, got); !isErrInvalidUTF8(err) {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
- if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
+ if !proto.Equal(got, want) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
})
@@ -142,3 +143,13 @@
t.Fatalf("MarshalAppend modified prefix: got %v, want prefix %v", got, want)
}
}
+
+// newMessage returns a new, empty message of the same concrete type as m (which must be a pointer) with every extension type registered on m also registered on the result, so that proto.Equal, which treats messages with different registered extensions as unequal, can compare the two.
+func newMessage(m proto.Message) proto.Message {
+	n := reflect.New(reflect.TypeOf(m).Elem()).Interface().(proto.Message)
+	m.ProtoReflect().KnownFields().ExtensionTypes().Range(func(xt pref.ExtensionType) bool {
+		n.ProtoReflect().KnownFields().ExtensionTypes().Register(xt)
+		return true
+	})
+	return n
+}