More extensive testing in extensions_test.go.
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index 9c8a6bc..451ad87 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -92,3 +92,46 @@
t.Errorf("GetExtension() not stable after unmarshaling")
}
}
+
+// TestExtensionsRoundTrip exercises Has/Set/Get/ClearExtension on pb.E_Ext_More.
+func TestExtensionsRoundTrip(t *testing.T) {
+	msg := &pb.MyMessage{}
+	ext1 := &pb.Ext{Data: proto.String("hi")}
+	ext2 := &pb.Ext{Data: proto.String("there")}
+	if proto.HasExtension(msg, pb.E_Ext_More) {
+		t.Error("Extension More present unexpectedly")
+	}
+	if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil {
+		t.Error(err)
+	}
+	// Setting the same extension a second time must replace ext1 with ext2.
+	if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil {
+		t.Error(err)
+	}
+	e, err := proto.GetExtension(msg, pb.E_Ext_More)
+	if err != nil {
+		t.Error(err)
+	}
+	x, ok := e.(*pb.Ext)
+	if !ok {
+		t.Errorf("e has type %T, expected testdata.Ext", e)
+	} else if x.Data == nil || *x.Data != "there" {
+		t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x)
+	}
+	proto.ClearExtension(msg, pb.E_Ext_More)
+	// Report err (not the stale e) so a failure shows what GetExtension returned.
+	if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension {
+		t.Errorf("got %v, expected ErrMissingExtension", err)
+	}
+	// E_X215 must not be usable on MyMessage: both Get and Set should fail.
+	if _, err := proto.GetExtension(msg, pb.E_X215); err == nil {
+		t.Error("expected bad extension error, got nil")
+	}
+	if err := proto.SetExtension(msg, pb.E_X215, 12); err == nil {
+		t.Error("expected extension err")
+	}
+	// A plain int is not a *pb.Ext, so SetExtension must reject it.
+	if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil {
+		t.Error("expected some sort of type mismatch error, got nil")
+	}
+}