proto: validate UTF-8 in proto3 strings

Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode_test.go b/proto/decode_test.go
index dda4db1..2c95f6b 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -12,6 +12,7 @@
 	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -80,6 +81,23 @@
 	}
 }
 
+// TestDecodeInvalidUTF8 checks that Unmarshal flags invalid UTF-8 but still decodes the raw bytes.
+func TestDecodeInvalidUTF8(t *testing.T) {
+	for _, tc := range invalidUTF8TestProtos {
+		for _, expected := range tc.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", tc.desc, expected), func(t *testing.T) {
+				got := reflect.New(reflect.TypeOf(expected).Elem()).Interface().(proto.Message)
+				if err := proto.Unmarshal(tc.wire, got); !isErrInvalidUTF8(err) {
+					t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(expected))
+				}
+				if !protoV1.Equal(got.(protoV1.Message), expected.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(got), marshalText(expected))
+				}
+			})
+		}
+	}
+}
+
 var testProtos = []testProto{
 	{
 		desc: "basic scalar types",
@@ -1158,6 +1176,69 @@
 	},
 }
 
+// invalidUTF8TestProtos lists proto3 messages whose string data contains the
+// byte 0xff, which can never appear in valid UTF-8. Each wire encoding must
+// unmarshal into its decodeTo message while reporting an InvalidUTF8 error.
+var invalidUTF8TestProtos = []testProto{
+	{
+		desc:     "invalid UTF-8 in optional string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{OptionalString: "abc\xff"}},
+		wire: pack.Message{
+			pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc:     "invalid UTF-8 in repeated string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{RepeatedString: []string{"foo", "abc\xff"}}},
+		wire: pack.Message{
+			pack.Tag{44, pack.BytesType}, pack.String("foo"),
+			pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in nested message",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
+				Corecursive: &test3pb.TestAllTypes{
+					OptionalString: "abc\xff",
+				},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+				}),
+			}),
+		}.Marshal(),
+	},
+	// Map entries below encode the key as field 1 and the value as field 2.
+	{
+		desc: "invalid UTF-8 in map key",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key\xff": "val"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key\xff"),
+				pack.Tag{2, pack.BytesType}, pack.String("val"),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in map value",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key": "val\xff"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key"),
+				pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
+			}),
+		}.Marshal(),
+	},
+}
+
 func build(m proto.Message, opts ...buildOpt) proto.Message {
 	for _, opt := range opts {
 		opt(m)
@@ -1185,3 +1266,17 @@
 	b, _ := textpb.Marshal(m)
 	return string(b)
 }
+
+// isErrInvalidUTF8 reports whether err is a non-empty errors.NonFatalErrors
+// in which every element has an InvalidUTF8 method that returns true.
+// Any other error, including nil, yields false.
+func isErrInvalidUTF8(err error) bool {
+	// A failed assertion leaves nerr empty, so the loop is skipped.
+	nerr, _ := err.(errors.NonFatalErrors)
+	for _, e := range nerr {
+		if u, ok := e.(interface{ InvalidUTF8() bool }); !ok || !u.InvalidUTF8() {
+			return false
+		}
+	}
+	return len(nerr) > 0
+}