internal/impl: add message validator

This adds an experimental function to the internal/impl package which
validates a wire-format message against a message type. The validator
reports whether the message can be successfully unmarshaled, and whether
the result is initialized (all required fields are set). In some cases,
the validator returns ambiguous results when full validation would be
expensive.

The validator is unused outside of tests. In the future, it may be used
to permit lazy unmarshaling of some data. It is being added now for
testing; in particular, the wire fuzzer now checks the validator output
for consistency with the unmarshaler.

The validator adds a small amount of per-MessageType state that is unused outside of tests. If
this becomes a concern, we could conditionalize it with a build tag.

Change-Id: I4216ef81d6a9ed975302eed189b02d08608858b4
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/212302
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/proto/testmessages_test.go b/proto/testmessages_test.go
index 8206ed2..8c62dbb 100644
--- a/proto/testmessages_test.go
+++ b/proto/testmessages_test.go
@@ -6,6 +6,8 @@
 
 import (
 	"google.golang.org/protobuf/internal/encoding/pack"
+	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/impl"
 	"google.golang.org/protobuf/proto"
 
 	legacypb "google.golang.org/protobuf/internal/testprotos/legacy"
@@ -15,11 +17,12 @@
 )
 
 type testProto struct {
-	desc     string
-	decodeTo []proto.Message
-	wire     []byte
-	partial  bool
-	noEncode bool
+	desc             string
+	decodeTo         []proto.Message
+	wire             []byte
+	partial          bool
+	noEncode         bool
+	validationStatus impl.ValidationStatus
 }
 
 var testValidMessages = []testProto{
@@ -1162,6 +1165,19 @@
 		}.Marshal(),
 	},
 	{
+		desc:    "required field with incompatible wire type",
+		partial: true,
+		decodeTo: []proto.Message{build(
+			&testpb.TestRequired{},
+			unknown(pack.Message{
+				pack.Tag{1, pack.Fixed32Type}, pack.Int32(2),
+			}.Marshal()),
+		)},
+		wire: pack.Message{
+			pack.Tag{1, pack.Fixed32Type}, pack.Int32(2),
+		}.Marshal(),
+	},
+	{
 		desc:    "required field in optional message unset",
 		partial: true,
 		decodeTo: []proto.Message{&testpb.TestRequiredForeign{
@@ -1197,6 +1213,7 @@
 				pack.Tag{1, pack.VarintType}, pack.Varint(1),
 			}),
 		}.Marshal(),
+		validationStatus: impl.ValidationValidMaybeUninitalized,
 	},
 	{
 		desc:    "required field in repeated message unset",
@@ -1483,6 +1500,7 @@
 				}),
 			}),
 		}.Marshal(),
+		validationStatus: impl.ValidationUnknown,
 	},
 	{
 		desc: "first reserved field number",
@@ -1579,31 +1597,43 @@
 		}.Marshal(),
 	},
 	{
-		desc:     "invalid field number zero",
-		decodeTo: []proto.Message{(*testpb.TestAllTypes)(nil)},
+		desc: "invalid field number zero",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
 		wire: pack.Message{
 			pack.Tag{pack.MinValidNumber - 1, pack.VarintType}, pack.Varint(1001),
 		}.Marshal(),
 	},
 	{
-		desc:     "invalid field numbers zero and one",
-		decodeTo: []proto.Message{(*testpb.TestAllTypes)(nil)},
+		desc: "invalid field numbers zero and one",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
 		wire: pack.Message{
 			pack.Tag{pack.MinValidNumber - 1, pack.VarintType}, pack.Varint(1002),
 			pack.Tag{pack.MinValidNumber, pack.VarintType}, pack.Varint(1003),
 		}.Marshal(),
 	},
 	{
-		desc:     "invalid field numbers max and max+1",
-		decodeTo: []proto.Message{(*testpb.TestAllTypes)(nil)},
+		desc: "invalid field numbers max and max+1",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
 		wire: pack.Message{
 			pack.Tag{pack.MaxValidNumber, pack.VarintType}, pack.Varint(1006),
 			pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1007),
 		}.Marshal(),
 	},
 	{
-		desc:     "invalid field number max+1",
-		decodeTo: []proto.Message{(*testpb.TestAllTypes)(nil)},
+		desc: "invalid field number max+1",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
 		wire: pack.Message{
 			pack.Tag{pack.MaxValidNumber + 1, pack.VarintType}, pack.Varint(1008),
 		}.Marshal(),
@@ -1619,4 +1649,266 @@
 			}),
 		}.Marshal(),
 	},
+	{
+		desc: "invalid tag varint",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: []byte{0xff},
+	},
+	{
+		desc: "field number too small",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{0, pack.VarintType}, pack.Varint(0),
+		}.Marshal(),
+	},
+	{
+		desc: "field number too large",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{wire.MaxValidNumber + 1, pack.VarintType}, pack.Varint(0),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid tag varint in message field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Raw{0xff},
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid tag varint in repeated message field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{48, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Raw{0xff},
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid varint in group field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{16, pack.StartGroupType},
+			pack.Tag{1000, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Raw{0xff},
+			}),
+			pack.Tag{16, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid varint in repeated group field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{46, pack.StartGroupType},
+			pack.Tag{1001, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Raw{0xff},
+			}),
+			pack.Tag{46, pack.EndGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "unterminated repeated group field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{46, pack.StartGroupType},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid tag varint in map item",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{56, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(0),
+				pack.Tag{2, pack.VarintType}, pack.Varint(0),
+				pack.Raw{0xff},
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid tag varint in map message value",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{71, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.VarintType}, pack.Varint(0),
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Raw{0xff},
+				}),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed int32 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{31, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed int64 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{32, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed uint32 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{33, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed uint64 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{34, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed sint32 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{35, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed sint64 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{36, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed fixed32 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{37, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed fixed64 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{38, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed sfixed32 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{39, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed sfixed64 field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{40, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed float field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{41, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed double field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{42, pack.BytesType}, pack.Bytes{0x00},
+		}.Marshal(),
+	},
+	{
+		desc: "invalid packed bool field",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{43, pack.BytesType}, pack.Bytes{0xff},
+		}.Marshal(),
+	},
+	{
+		desc: "bytes field overruns message",
+		decodeTo: []proto.Message{
+			(*testpb.TestAllTypes)(nil),
+			(*testpb.TestAllExtensions)(nil),
+		},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix{pack.Message{
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix{pack.Message{
+					pack.Tag{15, pack.BytesType}, pack.Varint(2),
+				}},
+				pack.Tag{1, pack.VarintType}, pack.Varint(0),
+			}},
+		}.Marshal(),
+	},
 }