Fix handling of RequiredNotSetError returned by fields that implement Marshaler.
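A RequiredNotSetError from a field's Marshal method is now treated as
non-fatal: encoding continues and the error is still returned.
This may break some code, but only code that was already silently broken.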
diff --git a/proto/all_test.go b/proto/all_test.go
index 5a9b6a4..b787d58 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -401,17 +401,18 @@
err error
}
-func (f fakeMarshaler) Marshal() ([]byte, error) {
- return f.b, f.err
+func (f *fakeMarshaler) Marshal() ([]byte, error) { return f.b, f.err }
+func (f *fakeMarshaler) String() string { return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err) }
+func (f *fakeMarshaler) ProtoMessage() {}
+func (f *fakeMarshaler) Reset() {}
+
+type msgWithFakeMarshaler struct {
+ M *fakeMarshaler `protobuf:"bytes,1,opt,name=fake"`
}
-func (f fakeMarshaler) String() string {
- return fmt.Sprintf("Bytes: %v Error: %v", f.b, f.err)
-}
-
-func (f fakeMarshaler) ProtoMessage() {}
-
-func (f fakeMarshaler) Reset() {}
+func (m *msgWithFakeMarshaler) String() string { return CompactTextString(m) }
+func (m *msgWithFakeMarshaler) ProtoMessage() {}
+func (m *msgWithFakeMarshaler) Reset() {}
// Simple tests for proto messages that implement the Marshaler interface.
func TestMarshalerEncoding(t *testing.T) {
@@ -423,7 +424,7 @@
}{
{
name: "Marshaler that fails",
- m: fakeMarshaler{
+ m: &fakeMarshaler{
err: errors.New("some marshal err"),
b: []byte{5, 6, 7},
},
@@ -432,8 +433,24 @@
wantErr: errors.New("some marshal err"),
},
{
+ name: "Marshaler that fails with RequiredNotSetError",
+ m: &msgWithFakeMarshaler{
+ M: &fakeMarshaler{
+ err: &RequiredNotSetError{},
+ b: []byte{5, 6, 7},
+ },
+ },
+ // RequiredNotSetError is non-fatal: encoding continues past it,
+ // so the buffer should still be written.
+ want: []byte{
+ 10, 3, // for &msgWithFakeMarshaler
+ 5, 6, 7, // for &fakeMarshaler
+ },
+ wantErr: &RequiredNotSetError{},
+ },
+ {
name: "Marshaler that succeeds",
- m: fakeMarshaler{
+ m: &fakeMarshaler{
b: []byte{0, 1, 2, 3, 4, 127, 255},
},
want: []byte{0, 1, 2, 3, 4, 127, 255},
@@ -443,6 +460,10 @@
for _, test := range tests {
b := NewBuffer(nil)
err := b.Marshal(test.m)
+ if _, ok := err.(*RequiredNotSetError); ok {
+ // Outside package proto we can't build a RequiredNotSetError with matching contents, so normalize err and compare only its type.
+ err = &RequiredNotSetError{}
+ }
if !reflect.DeepEqual(test.wantErr, err) {
t.Errorf("%s: got err %v wanted %v", test.name, err, test.wantErr)
}