testing/prototest: refactor prototest API

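For consistency with other options types in the protobuf module, make
the test function a method of the options: the TestMessage function is
replaced by the Message.Test method, which takes a
protoreflect.MessageType instead of a proto.Message.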

Drop the ExtensionTypes option and instead look up the extension types
to test with in the provided Resolver, which defaults to
protoregistry.GlobalTypes.

Change-Id: I7918bd10b7c003e4af56d27521d30218653d5b4d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/219142
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index 0c6b700..4fccb70 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -17,47 +17,42 @@
 	"google.golang.org/protobuf/internal/encoding/wire"
 	"google.golang.org/protobuf/proto"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
-	preg "google.golang.org/protobuf/reflect/protoregistry"
+	"google.golang.org/protobuf/reflect/protoregistry"
 )
 
 // TODO: Test invalid field descriptors or oneof descriptors.
 // TODO: This should test the functionality that can be provided by fast-paths.
 
-// MessageOptions configure message tests.
-type MessageOptions struct {
-	// ExtensionTypes is a list of types to test with.
-	//
-	// If nil, TestMessage will look for extension types in the global registry.
-	ExtensionTypes []pref.ExtensionType
-
-	// Resolver is used for looking up types when unmarshaling extension fields.
+// Message tests a message implementation.
+type Message struct {
+	// Resolver is used to determine the list of extension fields to test with and to resolve extensions when unmarshaling.
 	// If nil, this defaults to using protoregistry.GlobalTypes.
 	Resolver interface {
-		preg.ExtensionTypeResolver
+		FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
+		FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
+		RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
 	}
 }
 
-// TODO(blocks): TestMessage should not take in MessageOptions,
-// but have a MessageOptions.Test method instead.
+// Test performs tests on a MessageType implementation.
+func (test Message) Test(t testing.TB, mt pref.MessageType) {
+	testType(t, mt)
 
-// TestMessage runs the provided m through a series of tests
-// exercising the protobuf reflection API.
-func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) {
-	testType(t, m)
-
-	md := m.ProtoReflect().Descriptor()
-	m1 := m.ProtoReflect().New()
+	md := mt.Descriptor()
+	m1 := mt.New()
 	for i := 0; i < md.Fields().Len(); i++ {
 		fd := md.Fields().Get(i)
 		testField(t, m1, fd)
 	}
-	if opts.ExtensionTypes == nil {
-		preg.GlobalTypes.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
-			opts.ExtensionTypes = append(opts.ExtensionTypes, e)
-			return true
-		})
+	if test.Resolver == nil {
+		test.Resolver = protoregistry.GlobalTypes
 	}
-	for _, xt := range opts.ExtensionTypes {
+	var extTypes []pref.ExtensionType
+	test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
+		extTypes = append(extTypes, e)
+		return true
+	})
+	for _, xt := range extTypes {
 		testField(t, m1, xt.TypeDescriptor())
 	}
 	for i := 0; i < md.Oneofs().Len(); i++ {
@@ -66,9 +61,9 @@
 	testUnknown(t, m1)
 
 	// Test round-trip marshal/unmarshal.
-	m2 := m.ProtoReflect().New().Interface()
+	m2 := mt.New().Interface()
 	populateMessage(m2.ProtoReflect(), 1, nil)
-	for _, xt := range opts.ExtensionTypes {
+	for _, xt := range extTypes {
 		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
 	}
 	b, err := proto.MarshalOptions{
@@ -77,10 +72,10 @@
 	if err != nil {
 		t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
 	}
-	m3 := m.ProtoReflect().New().Interface()
+	m3 := mt.New().Interface()
 	if err := (proto.UnmarshalOptions{
 		AllowPartial: true,
-		Resolver:     opts.Resolver,
+		Resolver:     test.Resolver,
 	}.Unmarshal(b, m3)); err != nil {
 		t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
 	}
@@ -89,7 +84,8 @@
 	}
 }
 
-func testType(t testing.TB, m proto.Message) {
+func testType(t testing.TB, mt pref.MessageType) {
+	m := mt.New().Interface()
 	want := reflect.TypeOf(m)
 	if got := reflect.TypeOf(m.ProtoReflect().Interface()); got != want {
 		t.Errorf("type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v", got, want)