testing/prototest: test extensions, unknown fields, and type checks

Change-Id: Ia9a7e0670f5f4655a564f9b5425fd63991c76960
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/183377
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index 1ca4e36..7d8c91d 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -12,37 +12,47 @@
 	"sort"
 	"testing"
 
-	prototext "google.golang.org/protobuf/encoding/prototext"
+	"google.golang.org/protobuf/encoding/prototext"
+	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
+	preg "google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // TODO: Test read-only properties of unpopulated composite values.
 // TODO: Test invalid field descriptors or oneof descriptors.
 // TODO: This should test the functionality that can be provided by fast-paths.
 
+// MessageOptions configures the tests run by TestMessage.
+type MessageOptions struct {
+	// ExtensionTypes lists the extension types whose fields are tested.
+	// If nil, TestMessage tests every extension of the message found in
+	// preg.GlobalTypes; a non-nil empty slice disables extension tests.
+	ExtensionTypes []pref.ExtensionType
+}
+
 // TestMessage runs the provided m through a series of tests
 // exercising the protobuf reflection API.
-func TestMessage(t testing.TB, m proto.Message) {
+func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
 	md := m.ProtoReflect().Descriptor()
 	m1 := m.ProtoReflect().New()
 	for i := 0; i < md.Fields().Len(); i++ {
 		fd := md.Fields().Get(i)
-		switch {
-		case fd.IsList():
-			testFieldList(t, m1, fd)
-		case fd.IsMap():
-			testFieldMap(t, m1, fd)
-		case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
-			testFieldFloat(t, m1, fd)
-		}
 		testField(t, m1, fd)
 	}
+	if opts.ExtensionTypes == nil {
+		preg.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
+			opts.ExtensionTypes = append(opts.ExtensionTypes, e)
+			return true
+		})
+	}
+	for _, xt := range opts.ExtensionTypes {
+		testField(t, m1, xt)
+	}
 	for i := 0; i < md.Oneofs().Len(); i++ {
 		testOneof(t, m1, md.Oneofs().Get(i))
 	}
-
-	// TODO: Extensions, unknown fields.
+	testUnknown(t, m1)
 
 	// Test round-trip marshal/unmarshal.
 	m2 := m.ProtoReflect().New().Interface()
@@ -70,6 +80,15 @@
 	name := fd.FullName()
 	num := fd.Number()
 
+	switch {
+	case fd.IsList():
+		testFieldList(t, m, fd)
+	case fd.IsMap():
+		testFieldMap(t, m, fd)
+	case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
+		testFieldFloat(t, m, fd)
+	}
+
 	// Set to a non-zero value, the zero value, different non-zero values.
 	for _, n := range []seed{1, 0, minVal, maxVal} {
 		v := newValue(m, fd, n, nil)
@@ -82,6 +101,9 @@
 			if fd.Cardinality() == pref.Repeated {
 				wantHas = false
 			}
+			if fd.IsExtension() {
+				wantHas = true
+			}
 			if fd.ContainingOneof() != nil {
 				wantHas = true
 			}
@@ -92,6 +114,20 @@
 		if got, want := m.Get(fd), v; !valueEqual(got, want) {
 			t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
 		}
+		found := false
+		m.Range(func(d pref.FieldDescriptor, got pref.Value) bool {
+			if fd != d {
+				return true
+			}
+			found = true
+			if want := v; !valueEqual(got, want) {
+				t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want))
+			}
+			return true
+		})
+		if got, want := found, wantHas; got != want { // Range must visit fd iff Has is expected
+			t.Errorf("after setting %q:\nMessage.Range saw field: %v, want %v", name, got, want)
+		}
 	}
 
 	m.Clear(fd)
@@ -112,6 +148,17 @@
 			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
 		}
 	}
+
+	// Set to the wrong type.
+	v := pref.ValueOf("")
+	if fd.Kind() == pref.StringKind {
+		v = pref.ValueOf(int32(0))
+	}
+	if !panics(func() {
+		m.Set(fd, v)
+	}) {
+		t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface())
+	}
 }
 
 // testFieldMap tests set/get/has/clear of entries in a map field.
@@ -200,7 +247,7 @@
 	// Append values.
 	var want pref.List = &testList{}
 	for i, n := range []seed{1, 0, minVal, maxVal} {
-		if got, want := m.Has(fd), i > 0; got != want {
+		if got, want := m.Has(fd), i > 0 || fd.IsExtension(); got != want {
 			t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
 		}
 		v := newListElement(fd, list, n, nil)
@@ -227,7 +274,7 @@
 		n := want.Len() - 1
 		want.Truncate(n)
 		list.Truncate(n)
-		if got, want := m.Has(fd), want.Len() > 0; got != want {
+		if got, want := m.Has(fd), want.Len() > 0 || fd.IsExtension(); got != want {
 			t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
 		}
 		if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
@@ -287,6 +334,17 @@
 	}
 }
 
+// testUnknown tests the behavior of unknown fields.
+func testUnknown(t testing.TB, m pref.Message) {
+	var b []byte
+	b = wire.AppendTag(b, 1000, wire.VarintType)
+	b = wire.AppendVarint(b, 1001)
+	m.SetUnknown(pref.RawFields(b))
+	if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) {
+		t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want)
+	}
+}
+
 func formatValue(v pref.Value) string {
 	switch v := v.Interface().(type) {
 	case pref.List:
@@ -543,3 +601,13 @@
 	}
 	return pref.ValueOf(m)
 }
+
+// panics reports whether calling f panics. Unlike checking recover() != nil,
+// this also detects panic(nil), for which recover returns nil before Go 1.21.
+func panics(f func()) (didPanic bool) {
+	// Assume a panic; the return statement below clears this only if f returns.
+	didPanic = true
+	defer func() { recover() }()
+	f()
+	return false
+}