proto: check for required fields in encoding/decoding

Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode_test.go b/proto/decode_test.go
index feb4ac6..0a94b8a 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -21,18 +21,23 @@
 )
 
 type testProto struct {
-	desc     string
-	decodeTo []proto.Message
-	wire     []byte
+	desc              string
+	decodeTo          []proto.Message
+	wire              []byte
+	partial           bool // wire omits required fields; decode with AllowPartial
+	invalidExtensions bool // extension data is invalid; skip Equal and required-field checks
 }
 
 func TestDecode(t *testing.T) {
 	for _, test := range testProtos {
 		for _, want := range test.decodeTo {
 			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+				opts := proto.UnmarshalOptions{
+					AllowPartial: test.partial,
+				}
 				wire := append(([]byte)(nil), test.wire...)
 				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
-				if err := proto.Unmarshal(wire, got); err != nil {
+				if err := opts.Unmarshal(wire, got); err != nil {
 					t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
 					return
 				}
@@ -43,6 +48,10 @@
 					wire[i] = 0
 				}
 
+				if test.invalidExtensions {
+					// Equal doesn't work on messages containing invalid extension data.
+					return
+				}
 				if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
 					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
 				}
@@ -51,6 +60,26 @@
 	}
 }
 
+// TestDecodeRequiredFieldChecks verifies that plain Unmarshal reports an error
+// for every test case whose wire data is missing required fields.
+func TestDecodeRequiredFieldChecks(t *testing.T) {
+	for _, test := range testProtos {
+		// Missing required fields in extensions just end up in the unknown
+		// fields, so those cases decode without error and are skipped.
+		if !test.partial || test.invalidExtensions {
+			continue
+		}
+		for _, want := range test.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+				got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
+				if proto.Unmarshal(test.wire, got) == nil {
+					t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(got))
+				}
+			})
+		}
+	}
+}
+
 var testProtos = []testProto{
 	{
 		desc: "basic scalar types",
@@ -878,6 +907,258 @@
 			}),
 		}.Marshal(),
 	},
+	{
+		desc:     "required field unset",
+		partial:  true,
+		decodeTo: []proto.Message{&testpb.TestRequired{}},
+	},
+	{
+		desc: "required field set",
+		decodeTo: []proto.Message{&testpb.TestRequired{
+			RequiredField: scalar.Int32(1),
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.VarintType}, pack.Varint(1),
+		}.Marshal(),
+	},
+	{
+		desc:    "required field in optional message unset",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			OptionalMessage: &testpb.TestRequired{},
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+		}.Marshal(),
+	},
+	{
+		desc: "required field in optional message set",
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			OptionalMessage: &testpb.TestRequired{
+				RequiredField: scalar.Int32(1),
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+		}.Marshal(),
+	},
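+	// TODO: Handle this case. The first occurrence of the message omits the required field but the merged result sets it, so decoding should succeed.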
+	/*
+		{
+			desc: "required field in optional message set (split across multiple tags)",
+			decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+				OptionalMessage: &testpb.TestRequired{
+					RequiredField: scalar.Int32(1),
+				},
+			}},
+			wire: pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+				pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(1),
+				}),
+			}.Marshal(),
+		},
+	*/
+	{
+		desc:    "required field in repeated message unset",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			RepeatedMessage: []*testpb.TestRequired{
+				{RequiredField: scalar.Int32(1)},
+				{},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+			pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+		}.Marshal(),
+	},
+	{
+		desc: "required field in repeated message set",
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			RepeatedMessage: []*testpb.TestRequired{
+				{RequiredField: scalar.Int32(1)},
+				{RequiredField: scalar.Int32(2)},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+			pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(2),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc:    "required field in map message unset",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			MapMessage: map[int32]*testpb.TestRequired{
+				1: {RequiredField: scalar.Int32(1)},
+				2: {},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(1),
+				}),
+			}),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(2),
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "required field in map message set",
+		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+			MapMessage: map[int32]*testpb.TestRequired{
+				1: {RequiredField: scalar.Int32(1)},
+				2: {RequiredField: scalar.Int32(2)},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(1),
+				}),
+			}),
+			pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(2),
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{1, pack.VarintType}, pack.Varint(2),
+				}),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc:    "required field in optional group unset",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+			Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{},
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "required field in optional group set",
+		decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+			Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{
+				A: scalar.Int32(1),
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{1, pack.StartGroupType},
+			pack.Tag{2, pack.VarintType}, pack.Varint(1),
+			pack.Tag{1, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc:    "required field in repeated group unset",
+		partial: true,
+		decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+			Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
+				{A: scalar.Int32(1)},
+				{},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{3, pack.StartGroupType},
+			pack.Tag{4, pack.VarintType}, pack.Varint(1),
+			pack.Tag{3, pack.EndGroupType},
+			pack.Tag{3, pack.StartGroupType},
+			pack.Tag{3, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "required field in repeated group set",
+		decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+			Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
+				{A: scalar.Int32(1)},
+				{A: scalar.Int32(2)},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{3, pack.StartGroupType},
+			pack.Tag{4, pack.VarintType}, pack.Varint(1),
+			pack.Tag{3, pack.EndGroupType},
+			pack.Tag{3, pack.StartGroupType},
+			pack.Tag{4, pack.VarintType}, pack.Varint(2),
+			pack.Tag{3, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc:              "required field in extension message unset",
+		partial:           true,
+		invalidExtensions: true,
+		decodeTo: []proto.Message{build(
+			&testpb.TestAllExtensions{},
+			extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1000, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+		}.Marshal(),
+	},
+	{
+		desc: "required field in extension message set",
+		decodeTo: []proto.Message{build(
+			&testpb.TestAllExtensions{},
+			extend(testpb.E_TestRequired_Single, &testpb.TestRequired{
+				RequiredField: scalar.Int32(1),
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1000, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc:              "required field in repeated extension message unset",
+		partial:           true,
+		invalidExtensions: true,
+		decodeTo: []proto.Message{build(
+			&testpb.TestAllExtensions{},
+			extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
+				{RequiredField: scalar.Int32(1)},
+				{},
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+			pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+		}.Marshal(),
+	},
+	{
+		desc: "required field in repeated extension message set",
+		decodeTo: []proto.Message{build(
+			&testpb.TestAllExtensions{},
+			extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
+				{RequiredField: scalar.Int32(1)},
+				{RequiredField: scalar.Int32(2)},
+			}),
+		)},
+		wire: pack.Message{
+			pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(1),
+			}),
+			pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(2),
+			}),
+		}.Marshal(),
+	},
 }
 
 func build(m proto.Message, opts ...buildOpt) proto.Message {