proto: check for required fields in encoding/decoding

Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/encode_test.go b/proto/encode_test.go
index d467b74..b9e04b9 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -15,7 +15,10 @@
 	for _, test := range testProtos {
 		for _, want := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
-				wire, err := proto.Marshal(want)
+				opts := proto.MarshalOptions{
+					AllowPartial: test.partial,
+				}
+				wire, err := opts.Marshal(want)
 				if err != nil {
 					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
 				}
@@ -26,11 +29,18 @@
 				}
 
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
-				if err := proto.Unmarshal(wire, got); err != nil {
+				uopts := proto.UnmarshalOptions{
+					AllowPartial: test.partial,
+				}
+				if err := uopts.Unmarshal(wire, got); err != nil {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, protoV1.MarshalTextString(want.(protoV1.Message)))
 					return
 				}
 
+				if test.invalidExtensions {
+					// Equal doesn't work on messages containing invalid extension data.
+					return
+				}
 				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", protoV1.MarshalTextString(got.(protoV1.Message)), protoV1.MarshalTextString(want.(protoV1.Message)))
 				}
@@ -43,26 +53,35 @@
 	for _, test := range testProtos {
 		for _, want := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
-				wire, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
+				opts := proto.MarshalOptions{
+					Deterministic: true,
+					AllowPartial:  test.partial,
+				}
+				wire, err := opts.Marshal(want)
 				if err != nil {
 					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
 				}
-
-				wire2, err := proto.MarshalOptions{Deterministic: true}.Marshal(want)
+				wire2, err := opts.Marshal(want)
 				if err != nil {
 					t.Fatalf("Marshal error: %v\nMessage:\n%v", err, marshalText(want))
 				}
-
 				if !bytes.Equal(wire, wire2) {
 					t.Fatalf("deterministic marshal returned varying results:\n%v", cmp.Diff(wire, wire2))
 				}
 
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
-				if err := proto.Unmarshal(wire, got); err != nil {
+				uopts := proto.UnmarshalOptions{
+					AllowPartial: test.partial,
+				}
+				if err := uopts.Unmarshal(wire, got); err != nil {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 				}
 
+				if test.invalidExtensions {
+					// Equal doesn't work on messages containing invalid extension data.
+					return
+				}
 				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
@@ -70,3 +89,19 @@
 		}
 	}
 }
+
+func TestEncodeRequiredFieldChecks(t *testing.T) {
+	for _, test := range testProtos {
+		if !test.partial {
+			continue
+		}
+		for _, m := range test.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", test.desc, m), func(t *testing.T) {
+				_, err := proto.Marshal(m)
+				if err == nil {
+					t.Fatalf("Marshal succeeded (want error)\nMessage:\n%v", marshalText(m))
+				}
+			})
+		}
+	}
+}