reflect/protoregistry: add (*Types).Register{Message,Enum,Extension}

Add type-safe methods to register message, enum, and extension types.
Deprecate the NewTypes function and the (*Types).Register method.

Add (*File).RegisterFile and deprecate the NewFiles function and
the (*File).Register method.

Updates golang/protobuf#963

Change-Id: Ie89e77526e0874539e9bd929ca0ba8d758e65a6e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/199898
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/compiler/protogen/protogen.go b/compiler/protogen/protogen.go
index 5dbdaf9..8ffccca 100644
--- a/compiler/protogen/protogen.go
+++ b/compiler/protogen/protogen.go
@@ -157,7 +157,7 @@
 	gen := &Plugin{
 		Request:        req,
 		FilesByPath:    make(map[string]*File),
-		fileReg:        protoregistry.NewFiles(),
+		fileReg:        new(protoregistry.Files),
 		enumsByName:    make(map[protoreflect.FullName]*Enum),
 		messagesByName: make(map[protoreflect.FullName]*Message),
 		opts:           opts,
@@ -440,7 +440,7 @@
 	if err != nil {
 		return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
 	}
-	if err := gen.fileReg.Register(desc); err != nil {
+	if err := gen.fileReg.RegisterFile(desc); err != nil {
 		return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
 	}
 	f := &File{
diff --git a/internal/filedesc/build.go b/internal/filedesc/build.go
index 3719455..f4af7ad 100644
--- a/internal/filedesc/build.go
+++ b/internal/filedesc/build.go
@@ -44,7 +44,7 @@
 	FileRegistry interface {
 		FindFileByPath(string) (protoreflect.FileDescriptor, error)
 		FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
-		Register(...pref.FileDescriptor) error
+		RegisterFile(pref.FileDescriptor) error
 	}
 }
 
@@ -107,7 +107,7 @@
 	out.Extensions = fd.allExtensions
 	out.Services = fd.allServices
 
-	if err := db.FileRegistry.Register(fd); err != nil {
+	if err := db.FileRegistry.RegisterFile(fd); err != nil {
 		panic(err)
 	}
 	return out
diff --git a/internal/filetype/build.go b/internal/filetype/build.go
index f421a62..0a0dd35 100644
--- a/internal/filetype/build.go
+++ b/internal/filetype/build.go
@@ -108,7 +108,9 @@
 	// TypeRegistry is the registry to register each type descriptor.
 	// If nil, it uses protoregistry.GlobalTypes.
 	TypeRegistry interface {
-		Register(...preg.Type) error
+		RegisterMessage(pref.MessageType) error
+		RegisterEnum(pref.EnumType) error
+		RegisterExtension(pref.ExtensionType) error
 	}
 }
 
@@ -149,7 +151,7 @@
 				Desc:          &fbOut.Enums[i],
 			}
 			// Register enum types.
-			if err := tb.TypeRegistry.Register(&tb.EnumInfos[i]); err != nil {
+			if err := tb.TypeRegistry.RegisterEnum(&tb.EnumInfos[i]); err != nil {
 				panic(err)
 			}
 		}
@@ -170,7 +172,7 @@
 			tb.MessageInfos[i].Desc = &fbOut.Messages[i]
 
 			// Register message types.
-			if err := tb.TypeRegistry.Register(&tb.MessageInfos[i]); err != nil {
+			if err := tb.TypeRegistry.RegisterMessage(&tb.MessageInfos[i]); err != nil {
 				panic(err)
 			}
 		}
@@ -232,7 +234,7 @@
 		pimpl.InitExtensionInfo(&tb.ExtensionInfos[i], &fbOut.Extensions[i], goType)
 
 		// Register extension types.
-		if err := tb.TypeRegistry.Register(&tb.ExtensionInfos[i]); err != nil {
+		if err := tb.TypeRegistry.RegisterExtension(&tb.ExtensionInfos[i]); err != nil {
 			panic(err)
 		}
 	}
@@ -274,7 +276,7 @@
 	fileRegistry interface {
 		FindFileByPath(string) (pref.FileDescriptor, error)
 		FindDescriptorByName(pref.FullName) (pref.Descriptor, error)
-		Register(...pref.FileDescriptor) error
+		RegisterFile(pref.FileDescriptor) error
 	}
 )
 
diff --git a/internal/impl/legacy_file.go b/internal/impl/legacy_file.go
index bccaede..b61a135 100644
--- a/internal/impl/legacy_file.go
+++ b/internal/impl/legacy_file.go
@@ -70,4 +70,4 @@
 	*protoregistry.Files
 }
 
-func (resolverOnly) Register(...protoreflect.FileDescriptor) error { return nil }
+func (resolverOnly) Register(protoreflect.FileDescriptor) error { return nil }
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index f55c0d3..8e63ba6 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -52,8 +52,8 @@
 
 func init() {
 	mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
-	preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
-	preg.GlobalTypes.Register(mt)
+	preg.GlobalFiles.RegisterFile(mt.Descriptor().ParentFile())
+	preg.GlobalTypes.RegisterMessage(mt)
 }
 
 func mustMakeExtensionType(fileDesc, extDesc string, t reflect.Type, r pdesc.Resolver) pref.ExtensionType {
@@ -82,7 +82,7 @@
 	testMessageV1Desc = pimpl.Export{}.MessageDescriptorOf((*proto2_20180125.Message_ChildMessage)(nil))
 	testMessageV2Desc = enumMessagesType.Desc
 
-	depReg = preg.NewFiles(
+	depReg = newFileRegistry(
 		testParentDesc.ParentFile(),
 		testEnumV1Desc.ParentFile(),
 		testMessageV1Desc.ParentFile(),
diff --git a/internal/impl/message_reflect_test.go b/internal/impl/message_reflect_test.go
index c5a4a6a..82d5f81 100644
--- a/internal/impl/message_reflect_test.go
+++ b/internal/impl/message_reflect_test.go
@@ -990,7 +990,7 @@
 			{name:"F7Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_ENUM    type_name:".EnumProto3"}]   options:{map_entry:true}},
 			{name:"F8Entry" field:[{name:"key" number:1 label:LABEL_OPTIONAL type:TYPE_STRING}, {name:"value" number:2 label:LABEL_OPTIONAL type:TYPE_MESSAGE type_name:".ScalarProto3"}] options:{map_entry:true}}
 		]
-	`, protoregistry.NewFiles(
+	`, newFileRegistry(
 	EnumProto2(0).Descriptor().ParentFile(),
 	EnumProto3(0).Descriptor().ParentFile(),
 	((*ScalarProto2)(nil)).ProtoReflect().Descriptor().ParentFile(),
@@ -999,6 +999,14 @@
 )),
 }
 
+func newFileRegistry(files ...pref.FileDescriptor) *protoregistry.Files {
+	r := new(protoregistry.Files)
+	for _, file := range files {
+		r.RegisterFile(file)
+	}
+	return r
+}
+
 func (m *EnumMessages) ProtoReflect() pref.Message { return enumMessagesType.MessageOf(m) }
 
 func (*EnumMessages) XXX_OneofWrappers() []interface{} {
diff --git a/internal/protolegacy/proto.go b/internal/protolegacy/proto.go
index ffca87a..b9f023a 100644
--- a/internal/protolegacy/proto.go
+++ b/internal/protolegacy/proto.go
@@ -61,7 +61,7 @@
 
 func RegisterType(m Message, s string) {
 	mt := protoimpl.X.LegacyMessageTypeOf(m, protoreflect.FullName(s))
-	if err := protoregistry.GlobalTypes.Register(mt); err != nil {
+	if err := protoregistry.GlobalTypes.RegisterMessage(mt); err != nil {
 		panic(err)
 	}
 }
@@ -75,7 +75,7 @@
 }
 
 func RegisterExtension(d *ExtensionDesc) {
-	if err := protoregistry.GlobalTypes.Register(d); err != nil {
+	if err := protoregistry.GlobalTypes.RegisterExtension(d); err != nil {
 		panic(err)
 	}
 }
diff --git a/reflect/protodesc/file_test.go b/reflect/protodesc/file_test.go
index 38e8b4c..9ee6c8d 100644
--- a/reflect/protodesc/file_test.go
+++ b/reflect/protodesc/file_test.go
@@ -916,7 +916,7 @@
 				if err != nil {
 					t.Fatalf("dependency %d: unexpected NewFile() error: %v", i, err)
 				}
-				if err := r.Register(f); err != nil {
+				if err := r.RegisterFile(f); err != nil {
 					t.Fatalf("dependency %d: unexpected Register() error: %v", i, err)
 				}
 			}
diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go
index 4689f6b..116bd68 100644
--- a/reflect/protoregistry/registry.go
+++ b/reflect/protoregistry/registry.go
@@ -75,20 +75,37 @@
 
 // NewFiles returns a registry initialized with the provided set of files.
 // Files with a namespace conflict with an pre-existing file are not registered.
+//
+// Deprecated: Use Register.
 func NewFiles(files ...protoreflect.FileDescriptor) *Files {
 	r := new(Files)
-	r.Register(files...) // ignore errors; first takes precedence
+	for _, file := range files {
+		r.RegisterFile(file) // ignore errors; first takes precedence
+	}
 	return r
 }
 
 // Register registers the provided list of file descriptors.
 //
-// If any descriptor within a file conflicts with the descriptor of any
+// Deprecated: Use RegisterFile.
+func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
+	var firstErr error
+	for _, file := range files {
+		if err := r.RegisterFile(file); err != nil && firstErr == nil {
+			firstErr = err
+		}
+	}
+	return firstErr
+}
+
+// RegisterFile registers the provided file descriptor.
+//
+// If any descriptor within the file conflicts with the descriptor of any
 // previously registered file (e.g., two enums with the same full name),
-// then that file is not registered and an error is returned.
+// then the file is not registered and an error is returned.
 //
 // It is permitted for multiple files to have the same file path.
-func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
+func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
 	if r == GlobalFiles {
 		globalMutex.Lock()
 		defer globalMutex.Unlock()
@@ -99,32 +116,23 @@
 		}
 		r.filesByPath = make(map[string]protoreflect.FileDescriptor)
 	}
-	var firstErr error
-	for _, file := range files {
-		if err := r.registerFile(file); err != nil && firstErr == nil {
-			firstErr = err
-		}
-	}
-	return firstErr
-}
-func (r *Files) registerFile(fd protoreflect.FileDescriptor) error {
-	path := fd.Path()
+	path := file.Path()
 	if prev := r.filesByPath[path]; prev != nil {
-		err := errors.New("file %q is already registered", fd.Path())
-		err = amendErrorWithCaller(err, prev, fd)
-		if r == GlobalFiles && ignoreConflict(fd, err) {
+		err := errors.New("file %q is already registered", file.Path())
+		err = amendErrorWithCaller(err, prev, file)
+		if r == GlobalFiles && ignoreConflict(file, err) {
 			err = nil
 		}
 		return err
 	}
 
-	for name := fd.Package(); name != ""; name = name.Parent() {
+	for name := file.Package(); name != ""; name = name.Parent() {
 		switch prev := r.descsByName[name]; prev.(type) {
 		case nil, *packageDescriptor:
 		default:
-			err := errors.New("file %q has a package name conflict over %v", fd.Path(), name)
-			err = amendErrorWithCaller(err, prev, fd)
-			if r == GlobalFiles && ignoreConflict(fd, err) {
+			err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
+			err = amendErrorWithCaller(err, prev, file)
+			if r == GlobalFiles && ignoreConflict(file, err) {
 				err = nil
 			}
 			return err
@@ -132,11 +140,11 @@
 	}
 	var err error
 	var hasConflict bool
-	rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
+	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
 		if prev := r.descsByName[d.FullName()]; prev != nil {
 			hasConflict = true
-			err = errors.New("file %q has a name conflict over %v", fd.Path(), d.FullName())
-			err = amendErrorWithCaller(err, prev, fd)
+			err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
+			err = amendErrorWithCaller(err, prev, file)
 			if r == GlobalFiles && ignoreConflict(d, err) {
 				err = nil
 			}
@@ -146,17 +154,17 @@
 		return err
 	}
 
-	for name := fd.Package(); name != ""; name = name.Parent() {
+	for name := file.Package(); name != ""; name = name.Parent() {
 		if r.descsByName[name] == nil {
 			r.descsByName[name] = &packageDescriptor{}
 		}
 	}
-	p := r.descsByName[fd.Package()].(*packageDescriptor)
-	p.files = append(p.files, fd)
-	rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
+	p := r.descsByName[file.Package()].(*packageDescriptor)
+	p.files = append(p.files, file)
+	rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
 		r.descsByName[d.FullName()] = d
 	})
-	r.filesByPath[path] = fd
+	r.filesByPath[path] = file
 	return nil
 }
 
@@ -361,6 +369,10 @@
 }
 
 // A Type is a protoreflect.EnumType, protoreflect.MessageType, or protoreflect.ExtensionType.
+//
+// Deprecated: Do not use.
+//
+// TODO: Remove.
 type Type interface{}
 
 // MessageTypeResolver is an interface for looking up messages.
@@ -443,13 +455,15 @@
 }
 
 type (
-	typesByName         map[protoreflect.FullName]Type
+	typesByName         map[protoreflect.FullName]interface{}
 	extensionsByMessage map[protoreflect.FullName]extensionsByNumber
 	extensionsByNumber  map[protoreflect.FieldNumber]protoreflect.ExtensionType
 )
 
 // NewTypes returns a registry initialized with the provided set of types.
 // If there are conflicts, the first one takes precedence.
+//
+// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
 func NewTypes(typs ...Type) *Types {
 	r := new(Types)
 	r.Register(typs...) // ignore errors; first takes precedence
@@ -458,88 +472,109 @@
 
 // Register registers the provided list of descriptor types.
 //
-// If a registration conflict occurs for enum, message, or extension types
-// (e.g., two different types have the same full name),
-// then the first type takes precedence and an error is returned.
+// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
 func (r *Types) Register(typs ...Type) error {
+	var firstErr error
+	for _, typ := range typs {
+		var err error
+		switch t := typ.(type) {
+		case protoreflect.EnumType:
+			err = r.RegisterEnum(t)
+		case protoreflect.MessageType:
+			err = r.RegisterMessage(t)
+		case protoreflect.ExtensionType:
+			err = r.RegisterExtension(t)
+		default:
+			panic(fmt.Sprintf("invalid type: %T", t))
+		}
+		if firstErr == nil {
+			firstErr = err
+		}
+	}
+	return firstErr
+}
+
+// RegisterMessage registers the provided message type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
 	if r == GlobalTypes {
 		globalMutex.Lock()
 		defer globalMutex.Unlock()
 	}
-	var firstErr error
-typeLoop:
-	for _, typ := range typs {
-		switch typ.(type) {
-		case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
-			// Check for conflicts in typesByName.
-			var desc protoreflect.Descriptor
-			var pcnt *int
-			switch t := typ.(type) {
-			case protoreflect.EnumType:
-				desc = t.Descriptor()
-				pcnt = &r.numEnums
-			case protoreflect.MessageType:
-				desc = t.Descriptor()
-				pcnt = &r.numMessages
-			case protoreflect.ExtensionType:
-				desc = t.TypeDescriptor()
-				pcnt = &r.numExtensions
-			default:
-				panic(fmt.Sprintf("invalid type: %T", t))
-			}
-			name := desc.FullName()
-			if prev := r.typesByName[name]; prev != nil {
-				err := errors.New("%v %v is already registered", typeName(typ), name)
-				err = amendErrorWithCaller(err, prev, typ)
-				if r == GlobalTypes && ignoreConflict(desc, err) {
-					err = nil
-				}
-				if firstErr == nil {
-					firstErr = err
-				}
-				continue typeLoop
-			}
 
-			// Check for conflicts in extensionsByMessage.
-			if xt, _ := typ.(protoreflect.ExtensionType); xt != nil {
-				xd := xt.TypeDescriptor()
-				field := xd.Number()
-				message := xd.ContainingMessage().FullName()
-				if prev := r.extensionsByMessage[message][field]; prev != nil {
-					err := errors.New("extension number %d is already registered on message %v", field, message)
-					err = amendErrorWithCaller(err, prev, typ)
-					if r == GlobalTypes && ignoreConflict(xd, err) {
-						err = nil
-					}
-					if firstErr == nil {
-						firstErr = err
-					}
-					continue typeLoop
-				}
+	if err := r.register("message", mt.Descriptor(), mt); err != nil {
+		return err
+	}
+	r.numMessages++
+	return nil
+}
 
-				// Update extensionsByMessage.
-				if r.extensionsByMessage == nil {
-					r.extensionsByMessage = make(extensionsByMessage)
-				}
-				if r.extensionsByMessage[message] == nil {
-					r.extensionsByMessage[message] = make(extensionsByNumber)
-				}
-				r.extensionsByMessage[message][field] = xt
-			}
+// RegisterEnum registers the provided enum type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
+	if r == GlobalTypes {
+		globalMutex.Lock()
+		defer globalMutex.Unlock()
+	}
 
-			// Update typesByName and the count.
-			if r.typesByName == nil {
-				r.typesByName = make(typesByName)
-			}
-			r.typesByName[name] = typ
-			(*pcnt)++
-		default:
-			if firstErr == nil {
-				firstErr = errors.New("invalid type: %v", typeName(typ))
-			}
+	if err := r.register("enum", et.Descriptor(), et); err != nil {
+		return err
+	}
+	r.numEnums++
+	return nil
+}
+
+// RegisterExtension registers the provided extension type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
+	if r == GlobalTypes {
+		globalMutex.Lock()
+		defer globalMutex.Unlock()
+	}
+
+	xd := xt.TypeDescriptor()
+	field := xd.Number()
+	message := xd.ContainingMessage().FullName()
+	if prev := r.extensionsByMessage[message][field]; prev != nil {
+		err := errors.New("extension number %d is already registered on message %v", field, message)
+		err = amendErrorWithCaller(err, prev, xt)
+		if !(r == GlobalTypes && ignoreConflict(xd, err)) {
+			return err
 		}
 	}
-	return firstErr
+
+	if err := r.register("extension", xt.TypeDescriptor(), xt); err != nil {
+		return err
+	}
+	if r.extensionsByMessage == nil {
+		r.extensionsByMessage = make(extensionsByMessage)
+	}
+	if r.extensionsByMessage[message] == nil {
+		r.extensionsByMessage[message] = make(extensionsByNumber)
+	}
+	r.extensionsByMessage[message][field] = xt
+	r.numExtensions++
+	return nil
+}
+
+func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
+	name := desc.FullName()
+	prev := r.typesByName[name]
+	if prev != nil {
+		err := errors.New("%v %v is already registered", kind, name)
+		err = amendErrorWithCaller(err, prev, typ)
+		if !(r == GlobalTypes && ignoreConflict(desc, err)) {
+			return err
+		}
+	}
+	if r.typesByName == nil {
+		r.typesByName = make(typesByName)
+	}
+	r.typesByName[name] = typ
+	return nil
 }
 
 // FindEnumByName looks up an enum by its full name.
diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go
index 8d70243..22fbdad 100644
--- a/reflect/protoregistry/registry_test.go
+++ b/reflect/protoregistry/registry_test.go
@@ -282,7 +282,7 @@
 		t.Run("", func(t *testing.T) {
 			var files preg.Files
 			for i, tc := range tt.files {
-				gotErr := files.Register(tc.inFile)
+				gotErr := files.RegisterFile(tc.inFile)
 				if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
 					t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
 				}
@@ -332,8 +332,17 @@
 	xt1 := testpb.E_StringField
 	xt2 := testpb.E_Message4_MessageField
 	registry := new(preg.Types)
-	if err := registry.Register(mt1, et1, xt1, xt2); err != nil {
-		t.Fatalf("registry.Register() returns unexpected error: %v", err)
+	if err := registry.RegisterMessage(mt1); err != nil {
+		t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err)
+	}
+	if err := registry.RegisterEnum(et1); err != nil {
+		t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err)
+	}
+	if err := registry.RegisterExtension(xt1); err != nil {
+		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err)
+	}
+	if err := registry.RegisterExtension(xt2); err != nil {
+		t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err)
 	}
 
 	t.Run("FindMessageByName", func(t *testing.T) {