proto: check for required fields in encoding/decoding
Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/encode_test.go b/proto/encode_test.go
index d467b74..b9e04b9 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -15,7 +15,10 @@
for _, test := range testProtos {
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
- wire, err := proto.Marshal(want)
+ opts := proto.MarshalOptions{
+ AllowPartial: test.partial,
+ }
+ wire, err := opts.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
@@ -26,11 +29,18 @@
}
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
- if err := proto.Unmarshal(wire, got); err != nil {
+ uopts := proto.UnmarshalOptions{
+ AllowPartial: test.partial,
+ }
+ if err := uopts.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
return
}
+ if test.invalidExtensions {
+ // Equal doesn't work on messages containing invalid extension data.
+ return
+ }
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
}
@@ -43,26 +53,35 @@
for _, test := range testProtos {
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
- wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
+ opts := proto.MarshalOptions{
+ Deterministic: true,
+ AllowPartial: test.partial,
+ }
+ wire, err := opts.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
-
- wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
+ wire2, err := opts.Marshal(want)
if err != nil {
t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
}
-
if !bytes.Equal(wire, wire2) {
t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
}
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
- if err := proto.Unmarshal(wire, got); err != nil {
+ uopts := proto.UnmarshalOptions{
+ AllowPartial: test.partial,
+ }
+ if err := uopts.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
+ if test.invalidExtensions {
+ // Equal doesn't work on messages containing invalid extension data.
+ return
+ }
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
@@ -70,3 +89,19 @@
}
}
}
+
+func TestEncodeRequiredFieldChecks(t *testing.T) { // checks that marshaling with default options rejects every test case flagged partial
+ for _, test := range testProtos {
+ if !test.partial { // only cases that need AllowPartial in the round-trip tests are relevant here
+ continue
+ }
+ for _, m := range test.decodeTo { // each case may decode into several message types; check all of them
+ t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
+ _, err := proto.Marshal(m) // no MarshalOptions here, so AllowPartial is not set (presumably same as MarshalOptions{} — confirm)
+ if err == nil { // presumably partial cases have unset required fields, which Marshal must report as an error
+ t.Fatalf("Marshal succeeded (want error)\nMessage:\n%v", marshalText(m))
+ }
+ })
+ }
+ }
+}