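proto: validate UTF-8 in proto3 strings

Marshal and Unmarshal now report invalid UTF-8 in proto3 string fields
(singular, repeated, and map keys/values) as a non-fatal error; the
string data itself is still encoded or decoded as-is.
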
Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode.go b/proto/decode.go
index 3e00074..0b1aa3f 100644
--- a/proto/decode.go
+++ b/proto/decode.go
@@ -86,7 +86,7 @@
 		case fieldType.Cardinality() != protoreflect.Repeated:
 			valLen, err = o.unmarshalScalarField(b[tagLen:], wtyp, num, knownFields, fieldType)
 		case !fieldType.IsMap():
-			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType.Kind())
+			valLen, err = o.unmarshalList(b[tagLen:], wtyp, num, knownFields.Get(num).List(), fieldType)
 		default:
 			valLen, err = o.unmarshalMap(b[tagLen:], wtyp, num, knownFields.Get(num).Map(), fieldType)
 		}
@@ -105,8 +105,9 @@
 }
 
 func (o UnmarshalOptions) unmarshalScalarField(b []byte, wtyp wire.Type, num wire.Number, knownFields protoreflect.KnownFields, field protoreflect.FieldDescriptor) (n int, err error) {
-	v, n, err := o.unmarshalScalar(b, wtyp, num, field.Kind())
-	if err != nil {
+	var nerr errors.NonFatal
+	v, n, err := o.unmarshalScalar(b, wtyp, num, field)
+	if !nerr.Merge(err) {
 		return 0, err
 	}
 	switch field.Kind() {
@@ -124,12 +125,14 @@
 			knownFields.Set(num, protoreflect.ValueOf(m))
 		}
 		// Pass up errors (fatal and otherwise).
-		err = o.unmarshalMessage(v.Bytes(), m)
+		if err := o.unmarshalMessage(v.Bytes(), m); !nerr.Merge(err) {
+			return n, err
+		}
 	default:
 		// Non-message scalars replace the previous value.
 		knownFields.Set(num, v)
 	}
-	return n, err
+	return n, nerr.E
 }
 
 func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp wire.Type, num wire.Number, mapv protoreflect.Map, field protoreflect.FieldDescriptor) (n int, err error) {
@@ -164,17 +167,19 @@
 		err = errUnknown
 		switch num {
 		case 1:
-			key, n, err = o.unmarshalScalar(b, wtyp, num, keyField.Kind())
-			if err != nil {
+			key, n, err = o.unmarshalScalar(b, wtyp, num, keyField)
+			if !nerr.Merge(err) {
 				break
 			}
+			err = nil
 			haveKey = true
 		case 2:
 			var v protoreflect.Value
-			v, n, err = o.unmarshalScalar(b, wtyp, num, valField.Kind())
-			if err != nil {
+			v, n, err = o.unmarshalScalar(b, wtyp, num, valField)
+			if !nerr.Merge(err) {
 				break
 			}
+			err = nil
 			switch valField.Kind() {
 			case protoreflect.GroupKind, protoreflect.MessageKind:
 				if err := o.unmarshalMessage(v.Bytes(), val.Message()); !nerr.Merge(err) {
@@ -190,7 +195,7 @@
 			if n < 0 {
 				return 0, wire.ParseError(n)
 			}
-		} else if !nerr.Merge(err) {
+		} else if err != nil {
 			return 0, err
 		}
 		b = b[n:]
diff --git a/proto/decode_gen.go b/proto/decode_gen.go
index 51b85d7..1a3ef15 100644
--- a/proto/decode_gen.go
+++ b/proto/decode_gen.go
@@ -8,6 +8,7 @@
 
 import (
 	"math"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/errors"
@@ -17,8 +18,8 @@
 // unmarshalScalar decodes a value of the given kind.
 //
 // Message values are decoded into a []byte which aliases the input data.
-func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, kind protoreflect.Kind) (val protoreflect.Value, n int, err error) {
-	switch kind {
+func (o UnmarshalOptions) unmarshalScalar(b []byte, wtyp wire.Type, num wire.Number, field protoreflect.FieldDescriptor) (val protoreflect.Value, n int, err error) {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp != wire.VarintType {
 			return val, 0, errUnknown
@@ -153,6 +154,11 @@
 		if n < 0 {
 			return val, 0, wire.ParseError(n)
 		}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			var nerr errors.NonFatal
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+			return protoreflect.ValueOf(string(v)), n, nerr.E
+		}
 		return protoreflect.ValueOf(string(v)), n, nil
 	case protoreflect.BytesKind:
 		if wtyp != wire.BytesType {
@@ -186,9 +192,9 @@
 	}
 }
 
-func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, kind protoreflect.Kind) (n int, err error) {
+func (o UnmarshalOptions) unmarshalList(b []byte, wtyp wire.Type, num wire.Number, list protoreflect.List, field protoreflect.FieldDescriptor) (n int, err error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		if wtyp == wire.BytesType {
 			buf, n := wire.ConsumeBytes(b)
@@ -547,6 +553,9 @@
 		if n < 0 {
 			return 0, wire.ParseError(n)
 		}
+		if field.Syntax() == protoreflect.Proto3 && !utf8.Valid(v) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
 		list.Append(protoreflect.ValueOf(string(v)))
 		return n, nerr.E
 	case protoreflect.BytesKind:
diff --git a/proto/decode_test.go b/proto/decode_test.go
index dda4db1..2c95f6b 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -12,6 +12,7 @@
 	protoV1 "github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/encoding/textpb"
 	"github.com/golang/protobuf/v2/internal/encoding/pack"
+	"github.com/golang/protobuf/v2/internal/errors"
 	"github.com/golang/protobuf/v2/internal/scalar"
 	"github.com/golang/protobuf/v2/proto"
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -80,6 +81,23 @@
 	}
 }
 
+// TestDecodeInvalidUTF8 checks that Unmarshal flags invalid UTF-8 yet still decodes the whole message.
+func TestDecodeInvalidUTF8(t *testing.T) {
+	for _, tc := range invalidUTF8TestProtos {
+		for _, wantMsg := range tc.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", tc.desc, wantMsg), func(t *testing.T) {
+				gotMsg := reflect.New(reflect.TypeOf(wantMsg).Elem()).Interface().(proto.Message)
+				if err := proto.Unmarshal(tc.wire, gotMsg); !isErrInvalidUTF8(err) {
+					t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(wantMsg))
+				}
+				if !protoV1.Equal(gotMsg.(protoV1.Message), wantMsg.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(gotMsg), marshalText(wantMsg))
+				}
+			})
+		}
+	}
+}
+
 var testProtos = []testProto{
 	{
 		desc: "basic scalar types",
@@ -1158,6 +1176,69 @@
 	},
 }
 
+// invalidUTF8TestProtos holds proto3 messages whose string data is not valid
+// UTF-8; both Marshal and Unmarshal are expected to report it as an error.
+var invalidUTF8TestProtos = []testProto{
+	{
+		desc:     "invalid UTF-8 in optional string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{OptionalString: "abc\xff"}},
+		wire: pack.Message{
+			pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in repeated string field",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			RepeatedString: []string{"foo", "abc\xff"},
+		}},
+		wire: pack.Message{
+			pack.Tag{44, pack.BytesType}, pack.String("foo"),
+			pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in nested message",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
+				Corecursive: &test3pb.TestAllTypes{
+					OptionalString: "abc\xff",
+				},
+			},
+		}},
+		wire: pack.Message{
+			pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{
+					pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+				}),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in map key",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key\xff": "val"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key\xff"),
+				pack.Tag{2, pack.BytesType}, pack.String("val"),
+			}),
+		}.Marshal(),
+	},
+	{
+		desc: "invalid UTF-8 in map value",
+		decodeTo: []proto.Message{&test3pb.TestAllTypes{
+			MapStringString: map[string]string{"key": "val\xff"},
+		}},
+		wire: pack.Message{
+			pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{
+				pack.Tag{1, pack.BytesType}, pack.String("key"),
+				pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
+			}),
+		}.Marshal(),
+	},
+}
+
 func build(m proto.Message, opts ...buildOpt) proto.Message {
 	for _, opt := range opts {
 		opt(m)
@@ -1185,3 +1266,17 @@
 	b, _ := textpb.Marshal(m)
 	return string(b)
 }
+
+// isErrInvalidUTF8 reports whether err is a non-empty errors.NonFatalErrors whose every element has InvalidUTF8() == true.
+func isErrInvalidUTF8(err error) bool {
+	nerrs, ok := err.(errors.NonFatalErrors)
+	if !ok || len(nerrs) == 0 {
+		return false
+	}
+	for _, e := range nerrs {
+		if u, ok := e.(interface{ InvalidUTF8() bool }); !ok || !u.InvalidUTF8() {
+			return false
+		}
+	}
+	return true
+}
diff --git a/proto/encode.go b/proto/encode.go
index b294392..8635790 100644
--- a/proto/encode.go
+++ b/proto/encode.go
@@ -182,13 +182,13 @@
 	switch {
 	case field.Cardinality() != protoreflect.Repeated:
 		b = wire.AppendTag(b, num, wireTypes[kind])
-		return o.marshalSingular(b, num, kind, value)
+		return o.marshalSingular(b, num, field, value)
 	case field.IsMap():
 		return o.marshalMap(b, num, kind, field.MessageType(), value.Map())
 	case field.IsPacked():
-		return o.marshalPacked(b, num, kind, value.List())
+		return o.marshalPacked(b, num, field, value.List())
 	default:
-		return o.marshalList(b, num, kind, value.List())
+		return o.marshalList(b, num, field, value.List())
 	}
 }
 
@@ -229,13 +229,13 @@
 	mapsort.Range(mapv, kind, f)
 }
 
-func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+func (o MarshalOptions) marshalPacked(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
 	b = wire.AppendTag(b, num, wire.BytesType)
 	b, pos := appendSpeculativeLength(b)
 	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
-		b, err = o.marshalSingular(b, num, kind, list.Get(i))
+		b, err = o.marshalSingular(b, num, field, list.Get(i))
 		if !nerr.Merge(err) {
 			return b, err
 		}
@@ -244,12 +244,13 @@
 	return b, nerr.E
 }
 
-func (o MarshalOptions) marshalList(b []byte, num wire.Number, kind protoreflect.Kind, list protoreflect.List) ([]byte, error) {
+func (o MarshalOptions) marshalList(b []byte, num wire.Number, field protoreflect.FieldDescriptor, list protoreflect.List) ([]byte, error) {
+	kind := field.Kind()
 	var nerr errors.NonFatal
 	for i, llen := 0, list.Len(); i < llen; i++ {
 		var err error
 		b = wire.AppendTag(b, num, wireTypes[kind])
-		b, err = o.marshalSingular(b, num, kind, list.Get(i))
+		b, err = o.marshalSingular(b, num, field, list.Get(i))
 		if !nerr.Merge(err) {
 			return b, err
 		}
diff --git a/proto/encode_gen.go b/proto/encode_gen.go
index 46621c8..4919b96 100644
--- a/proto/encode_gen.go
+++ b/proto/encode_gen.go
@@ -8,6 +8,7 @@
 
 import (
 	"math"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/v2/internal/encoding/wire"
 	"github.com/golang/protobuf/v2/internal/errors"
@@ -35,9 +36,9 @@
 	protoreflect.GroupKind:    wire.StartGroupType,
 }
 
-func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, kind protoreflect.Kind, v protoreflect.Value) ([]byte, error) {
+func (o MarshalOptions) marshalSingular(b []byte, num wire.Number, field protoreflect.FieldDescriptor, v protoreflect.Value) ([]byte, error) {
 	var nerr errors.NonFatal
-	switch kind {
+	switch field.Kind() {
 	case protoreflect.BoolKind:
 		b = wire.AppendVarint(b, wire.EncodeBool(v.Bool()))
 	case protoreflect.EnumKind:
@@ -67,6 +68,9 @@
 	case protoreflect.DoubleKind:
 		b = wire.AppendFixed64(b, math.Float64bits(v.Float()))
 	case protoreflect.StringKind:
+		if field.Syntax() == protoreflect.Proto3 && !utf8.ValidString(v.String()) {
+			nerr.AppendInvalidUTF8(string(field.FullName()))
+		}
 		b = wire.AppendBytes(b, []byte(v.String()))
 	case protoreflect.BytesKind:
 		b = wire.AppendBytes(b, v.Bytes())
@@ -87,7 +91,7 @@
 		}
 		b = wire.AppendVarint(b, wire.EncodeTag(num, wire.EndGroupType))
 	default:
-		return b, errors.New("invalid kind %v", kind)
+		return b, errors.New("invalid kind %v", field.Kind())
 	}
 	return b, nerr.E
 }
diff --git a/proto/encode_test.go b/proto/encode_test.go
index 30722e0..d670edf 100644
--- a/proto/encode_test.go
+++ b/proto/encode_test.go
@@ -92,6 +92,27 @@
 	}
 }
 
+// TestEncodeInvalidUTF8 checks that Marshal and the following Unmarshal both flag invalid UTF-8 and the message round-trips.
+func TestEncodeInvalidUTF8(t *testing.T) {
+	for _, tc := range invalidUTF8TestProtos {
+		for _, wantMsg := range tc.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", tc.desc, wantMsg), func(t *testing.T) {
+				b, err := proto.Marshal(wantMsg)
+				if !isErrInvalidUTF8(err) {
+					t.Errorf("Marshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(wantMsg))
+				}
+				gotMsg := reflect.New(reflect.TypeOf(wantMsg).Elem()).Interface().(proto.Message)
+				if err := proto.Unmarshal(b, gotMsg); !isErrInvalidUTF8(err) {
+					t.Fatalf("Unmarshal error: %v\nMessage:\n%v", err, marshalText(wantMsg))
+				}
+				if !protoV1.Equal(gotMsg.(protoV1.Message), wantMsg.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(gotMsg), marshalText(wantMsg))
+				}
+			})
+		}
+	}
+}
+
 func TestEncodeRequiredFieldChecks(t *testing.T) {
 	for _, test := range testProtos {
 		if !test.partial {