proto: check for required fields in encoding/decoding
Change-Id: I0555a92e0399782f075b1dcd248e880dd48c7d6d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/170579
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode_test.go b/proto/decode_test.go
index feb4ac6..0a94b8a 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -21,18 +21,23 @@
)
type testProto struct {
- desc string
- decodeTo []proto.Message
- wire []byte
+ desc string // test case description, used in subtest names
+ decodeTo []proto.Message // messages wire is expected to decode into (one subtest per message type)
+ wire []byte // encoded input
+ partial bool // wire omits required fields; decoding succeeds only with AllowPartial
+ invalidExtensions bool // decoded result holds invalid extension data, so Equal can't compare it
}
func TestDecode(t *testing.T) {
for _, test := range testProtos {
for _, want := range test.decodeTo {
t.Run(fmt.Sprintf("%s (%T)", test.desc, want), func(t *testing.T) {
+ opts := proto.UnmarshalOptions{
+ AllowPartial: test.partial, // partial cases omit required fields, which would otherwise be an error
+ }
wire := append(([]byte)(nil), test.wire...)
got := reflect.New(reflect.TypeOf(want).Elem()).Interface().(proto.Message)
- if err := proto.Unmarshal(wire, got); err != nil {
+ if err := opts.Unmarshal(wire, got); err != nil {
t.Errorf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(want))
return
}
@@ -43,6 +48,10 @@
wire[i] = 0
}
+ if test.invalidExtensions {
+ // Equal doesn't work on messages containing invalid extension data.
+ return
+ }
if !protoV1.Equal(got.(protoV1.Message), want.(protoV1.Message)) {
t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(want))
}
@@ -51,6 +60,26 @@
}
}
+// TestDecodeRequiredFieldChecks verifies that, when AllowPartial is not set,
+// Unmarshal rejects every test input that is missing a required field.
+func TestDecodeRequiredFieldChecks(t *testing.T) {
+	for _, test := range testProtos {
+		// Only partial inputs lack required fields. Missing required fields in
+		// extensions just end up in the unknown fields, so skip those as well.
+		if !test.partial || test.invalidExtensions {
+			continue
+		}
+		for _, m := range test.decodeTo {
+			name := fmt.Sprintf("%s (%T)", test.desc, m)
+			t.Run(name, func(t *testing.T) {
+				msgType := reflect.TypeOf(m).Elem()
+				got := reflect.New(msgType).Interface().(proto.Message)
+				err := proto.Unmarshal(test.wire, got)
+				if err == nil {
+					t.Fatalf("Unmarshal succeeded (want error)\nMessage:\n%v", marshalText(got))
+				}
+			})
+		}
+	}
+}
+
var testProtos = []testProto{
{
desc: "basic scalar types",
@@ -878,6 +907,258 @@
}),
}.Marshal(),
},
+ {
+ desc: "required field unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequired{}},
+ },
+ {
+ desc: "required field set",
+ decodeTo: []proto.Message{&testpb.TestRequired{
+ RequiredField: scalar.Int32(1),
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in optional message unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ OptionalMessage: &testpb.TestRequired{},
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in optional message set",
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ OptionalMessage: &testpb.TestRequired{
+ RequiredField: scalar.Int32(1),
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ }.Marshal(),
+ },
+ // TODO: Handle this case: field 1 appears twice on the wire (first empty, then with the required field set) and the occurrences should merge into one message with the required field set; presumably decoding currently fails the required-field check here — confirm.
+ /*
+ {
+ desc: "required field in optional message set (split across multiple tags)",
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ OptionalMessage: &testpb.TestRequired{
+ RequiredField: scalar.Int32(1),
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ pack.Tag{1, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ }.Marshal(),
+ },
+ */
+ {
+ desc: "required field in repeated message unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ RepeatedMessage: []*testpb.TestRequired{
+ {RequiredField: scalar.Int32(1)},
+ {},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in repeated message set",
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ RepeatedMessage: []*testpb.TestRequired{
+ {RequiredField: scalar.Int32(1)},
+ {RequiredField: scalar.Int32(2)},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in map message unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ MapMessage: map[int32]*testpb.TestRequired{
+ 1: {RequiredField: scalar.Int32(1)},
+ 2: {},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ }),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in map message set",
+ decodeTo: []proto.Message{&testpb.TestRequiredForeign{
+ MapMessage: map[int32]*testpb.TestRequired{
+ 1: {RequiredField: scalar.Int32(1)},
+ 2: {RequiredField: scalar.Int32(2)},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ }),
+ pack.Tag{3, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ }),
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in optional group unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+ Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{},
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "required field in optional group set",
+ decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+ Optionalgroup: &testpb.TestRequiredGroupFields_OptionalGroup{
+ A: scalar.Int32(1),
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{1, pack.StartGroupType},
+ pack.Tag{2, pack.VarintType}, pack.Varint(1),
+ pack.Tag{1, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "required field in repeated group unset",
+ partial: true,
+ decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+ Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
+ {A: scalar.Int32(1)},
+ {},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.StartGroupType},
+ pack.Tag{4, pack.VarintType}, pack.Varint(1),
+ pack.Tag{3, pack.EndGroupType},
+ pack.Tag{3, pack.StartGroupType},
+ pack.Tag{3, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "required field in repeated group set",
+ decodeTo: []proto.Message{&testpb.TestRequiredGroupFields{
+ Repeatedgroup: []*testpb.TestRequiredGroupFields_RepeatedGroup{
+ {A: scalar.Int32(1)},
+ {A: scalar.Int32(2)},
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{3, pack.StartGroupType},
+ pack.Tag{4, pack.VarintType}, pack.Varint(1),
+ pack.Tag{3, pack.EndGroupType},
+ pack.Tag{3, pack.StartGroupType},
+ pack.Tag{4, pack.VarintType}, pack.Varint(2),
+ pack.Tag{3, pack.EndGroupType},
+ }.Marshal(),
+ },
+ {
+ desc: "required field in extension message unset",
+ partial: true,
+ invalidExtensions: true,
+ decodeTo: []proto.Message{build(
+ &testpb.TestAllExtensions{},
+ extend(testpb.E_TestRequired_Single, &testpb.TestRequired{}),
+ )},
+ wire: pack.Message{
+ pack.Tag{1000, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in extension message set",
+ decodeTo: []proto.Message{build(
+ &testpb.TestAllExtensions{},
+ extend(testpb.E_TestRequired_Single, &testpb.TestRequired{
+ RequiredField: scalar.Int32(1),
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1000, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in repeated extension message unset",
+ partial: true,
+ invalidExtensions: true,
+ decodeTo: []proto.Message{build(
+ &testpb.TestAllExtensions{},
+ extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
+ {RequiredField: scalar.Int32(1)},
+ {},
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{}),
+ }.Marshal(),
+ },
+ {
+ desc: "required field in repeated extension message set",
+ decodeTo: []proto.Message{build(
+ &testpb.TestAllExtensions{},
+ extend(testpb.E_TestRequired_Multi, []*testpb.TestRequired{
+ {RequiredField: scalar.Int32(1)},
+ {RequiredField: scalar.Int32(2)},
+ }),
+ )},
+ wire: pack.Message{
+ pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(1),
+ }),
+ pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+ pack.Tag{1, pack.VarintType}, pack.Varint(2),
+ }),
+ }.Marshal(),
+ },
}
func build(m proto.Message, opts ...buildOpt) proto.Message {