goprotobuf: Detect and return a useful error if someone passes T (instead of *T) to proto.Marshal.
R=r
CC=golang-dev
http://codereview.appspot.com/4805047
diff --git a/proto/all_test.go b/proto/all_test.go
index eb32155..dd7f66e 100644
--- a/proto/all_test.go
+++ b/proto/all_test.go
@@ -1212,6 +1212,15 @@
}
}
+// Check that passing a struct to Marshal returns a good error,
+// rather than panicking.
+func TestStructMarshaling(t *testing.T) {
+ _, err := Marshal(OtherMessage{})
+ if err != ErrNotPtr {
+ t.Errorf("got %v, expected %v", err, ErrNotPtr)
+ }
+}
+
func BenchmarkMarshal(b *testing.B) {
b.StopTimer()
diff --git a/proto/encode.go b/proto/encode.go
index 7f242e1..3952dc4 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -52,15 +52,20 @@
}
func (e *ErrRequiredNotSet) String() string {
- return "required fields not set in " + e.t.String()
+ return "proto: required fields not set in " + e.t.String()
}
-// ErrRepeatedHasNil is the error returned if Marshal is called with
-// a protocol buffer struct with a repeated field containing a nil element.
-var ErrRepeatedHasNil = os.NewError("repeated field has nil")
+var (
+ // ErrRepeatedHasNil is the error returned if Marshal is called with
+ // a struct with a repeated field containing a nil element.
+ ErrRepeatedHasNil = os.NewError("proto: repeated field has nil")
-// ErrNil is the error returned if Marshal is called with nil.
-var ErrNil = os.NewError("marshal called with nil")
+ // ErrNil is the error returned if Marshal is called with nil.
+ ErrNil = os.NewError("proto: Marshal called with nil")
+
+ // ErrNotPtr is the error returned if Marshal is called with a non-pointer.
+ ErrNotPtr = os.NewError("proto: Marshal called with a non-pointer")
+)
// The fundamental encoders that put bytes on the wire.
// Those that take integer types all accept uint64 and are
@@ -213,6 +218,9 @@
mstat := runtime.MemStats.Mallocs
t, b, err := getbase(pb)
+ if t.Kind() != reflect.Ptr {
+ return ErrNotPtr
+ }
if err == nil {
err = p.enc_struct(t.Elem(), b)
}
diff --git a/proto/properties.go b/proto/properties.go
index 61d691e..44f4bc8 100644
--- a/proto/properties.go
+++ b/proto/properties.go
@@ -523,7 +523,7 @@
return prop.Prop[x[0]]
}
-// Get the address and type of a pointer to the structure from an interface.
+// Get the address and type of a pointer to a struct from an interface.
// unsafe.Reflect can do this, but does multiple mallocs.
func getbase(pb interface{}) (t reflect.Type, b uintptr, err os.Error) {
// get pointer