Return a descriptive error when SetExtension() is passed a nil extension value.
Previously the nil value was accepted, and the encoder then silently dropped all extensions from the encoded message.
diff --git a/proto/extensions.go b/proto/extensions.go
index f7667fa..5f62dff 100644
--- a/proto/extensions.go
+++ b/proto/extensions.go
@@ -37,6 +37,7 @@
import (
"errors"
+ "fmt"
"reflect"
"strconv"
"sync"
@@ -321,6 +322,14 @@
if typ != reflect.TypeOf(value) {
return errors.New("proto: bad extension value type")
}
+ // nil extension values need to be caught early, because the
+ // encoder can't distinguish an ErrNil due to a nil extension
+ // from an ErrNil due to a missing field. Extensions are
+ // always optional, so the encoder would just swallow the error
+ // and drop all the extensions from the encoded message.
+ if reflect.ValueOf(value).IsNil() {
+ return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
+ }
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
diff --git a/proto/extensions_test.go b/proto/extensions_test.go
index 451ad87..6495f56 100644
--- a/proto/extensions_test.go
+++ b/proto/extensions_test.go
@@ -135,3 +135,19 @@
t.Error("expected some sort of type mismatch error, got nil")
}
}
+
+func TestNilExtension(t *testing.T) {
+ msg := &pb.MyMessage{
+ Count: proto.Int32(1),
+ }
+ if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil {
+ t.Fatal(err)
+ }
+ if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil {
+ t.Error("expected SetExtension to fail due to a nil extension")
+ } else if want := "proto: SetExtension called with nil value of type *testdata.Ext"; err.Error() != want {
+ t.Errorf("expected error %v, got %v", want, err)
+ }
+ // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update
+ // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal.
+}