proto: validate UTF-8 in proto3 strings
Change-Id: I6a495730c3f438e7b2c4ca86edade7d6f25aa47d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/171700
Reviewed-by: Herbie Ong <herbie@google.com>
diff --git a/proto/decode_test.go b/proto/decode_test.go
index dda4db1..2c95f6b 100644
--- a/proto/decode_test.go
+++ b/proto/decode_test.go
@@ -12,6 +12,7 @@
protoV1 "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/encoding/textpb"
"github.com/golang/protobuf/v2/internal/encoding/pack"
+ "github.com/golang/protobuf/v2/internal/errors"
"github.com/golang/protobuf/v2/internal/scalar"
"github.com/golang/protobuf/v2/proto"
pref "github.com/golang/protobuf/v2/reflect/protoreflect"
@@ -80,6 +81,23 @@
}
}
+// TestDecodeInvalidUTF8 unmarshals every invalidUTF8TestProtos entry and expects an invalid-UTF-8 error plus the decoded message.
+func TestDecodeInvalidUTF8(t *testing.T) {
+	for _, tc := range invalidUTF8TestProtos {
+		for _, expected := range tc.decodeTo {
+			t.Run(fmt.Sprintf("%s (%T)", tc.desc, expected), func(t *testing.T) {
+				decoded := reflect.New(reflect.TypeOf(expected).Elem()).Interface().(proto.Message)
+				if err := proto.Unmarshal(tc.wire, decoded); !isErrInvalidUTF8(err) {
+					t.Errorf("Unmarshal did not return expected error for invalid UTF8: %v\nMessage:\n%v", err, marshalText(expected))
+				}
+				if !protoV1.Equal(decoded.(protoV1.Message), expected.(protoV1.Message)) {
+					t.Errorf("Unmarshal returned unexpected result; got:\n%v\nwant:\n%v", marshalText(decoded), marshalText(expected))
+				}
+			})
+		}
+	}
+}
+
var testProtos = []testProto{
{
desc: "basic scalar types",
@@ -1158,6 +1176,69 @@
},
}
+var invalidUTF8TestProtos = []testProto{ // inputs for TestDecodeInvalidUTF8: wire bytes with "\xff" in a string field, plus the message expected after decoding
+ {
+ desc: "invalid UTF-8 in optional string field",
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
+ OptionalString: "abc\xff",
+ }},
+ wire: pack.Message{
+ pack.Tag{14, pack.BytesType}, pack.String("abc\xff"), // the byte 0xff cannot occur in valid UTF-8
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in repeated string field",
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
+ RepeatedString: []string{"foo", "abc\xff"},
+ }},
+ wire: pack.Message{
+ pack.Tag{44, pack.BytesType}, pack.String("foo"), // valid element precedes the invalid one
+ pack.Tag{44, pack.BytesType}, pack.String("abc\xff"),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in nested message",
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
+ OptionalNestedMessage: &test3pb.TestAllTypes_NestedMessage{
+ Corecursive: &test3pb.TestAllTypes{
+ OptionalString: "abc\xff",
+ },
+ },
+ }},
+ wire: pack.Message{
+ pack.Tag{18, pack.BytesType}, pack.LengthPrefix(pack.Message{ // outer level, paired with OptionalNestedMessage in decodeTo
+ pack.Tag{2, pack.BytesType}, pack.LengthPrefix(pack.Message{ // inner level, paired with Corecursive in decodeTo
+ pack.Tag{14, pack.BytesType}, pack.String("abc\xff"),
+ }),
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in map key",
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
+ MapStringString: map[string]string{"key\xff": "val"},
+ }},
+ wire: pack.Message{
+ pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{ // one map entry written as a length-prefixed submessage
+ pack.Tag{1, pack.BytesType}, pack.String("key\xff"), // entry field 1 carries the key
+ pack.Tag{2, pack.BytesType}, pack.String("val"), // entry field 2 carries the value
+ }),
+ }.Marshal(),
+ },
+ {
+ desc: "invalid UTF-8 in map value",
+ decodeTo: []proto.Message{&test3pb.TestAllTypes{
+ MapStringString: map[string]string{"key": "val\xff"},
+ }},
+ wire: pack.Message{
+ pack.Tag{69, pack.BytesType}, pack.LengthPrefix(pack.Message{ // same entry layout as the map-key case; here the value is invalid
+ pack.Tag{1, pack.BytesType}, pack.String("key"),
+ pack.Tag{2, pack.BytesType}, pack.String("val\xff"),
+ }),
+ }.Marshal(),
+ },
+}
+
func build(m proto.Message, opts ...buildOpt) proto.Message {
for _, opt := range opts {
opt(m)
@@ -1185,3 +1266,17 @@
b, _ := textpb.Marshal(m)
return string(b)
}
+
+// isErrInvalidUTF8 reports whether err is a non-empty errors.NonFatalErrors whose every element reports InvalidUTF8.
+func isErrInvalidUTF8(err error) bool {
+	list, isList := err.(errors.NonFatalErrors)
+	if !isList || len(list) == 0 {
+		return false
+	}
+	for _, item := range list {
+		if u, ok := item.(interface{ InvalidUTF8() bool }); !ok || !u.InvalidUTF8() {
+			return false // any other kind of error disqualifies the whole list
+		}
+	}
+	return true
+}