reflect/protoregistry: add (*Types).Register{Message,Enum,Extension}
Add type-safe methods to register message, enum, and extension types.
Deprecate the NewTypes function and the (*Types).Register method.
Add (*Files).RegisterFile and deprecate the NewFiles function and
the (*Files).Register method.
Updates golang/protobuf#963
Change-Id: Ie89e77526e0874539e9bd929ca0ba8d758e65a6e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/199898
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/reflect/protodesc/file_test.go b/reflect/protodesc/file_test.go
index 38e8b4c..9ee6c8d 100644
--- a/reflect/protodesc/file_test.go
+++ b/reflect/protodesc/file_test.go
@@ -916,7 +916,7 @@
if err != nil {
t.Fatalf("dependency %d: unexpected NewFile() error: %v", i, err)
}
- if err := r.Register(f); err != nil {
+ if err := r.RegisterFile(f); err != nil {
t.Fatalf("dependency %d: unexpected Register() error: %v", i, err)
}
}
diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go
index 4689f6b..116bd68 100644
--- a/reflect/protoregistry/registry.go
+++ b/reflect/protoregistry/registry.go
@@ -75,20 +75,37 @@
// NewFiles returns a registry initialized with the provided set of files.
// Files with a namespace conflict with an pre-existing file are not registered.
+//
+// Deprecated: Use RegisterFile.
func NewFiles(files ...protoreflect.FileDescriptor) *Files {
r := new(Files)
- r.Register(files...) // ignore errors; first takes precedence
+ for _, file := range files {
+ r.RegisterFile(file) // ignore errors; first takes precedence
+ }
return r
}
// Register registers the provided list of file descriptors.
//
-// If any descriptor within a file conflicts with the descriptor of any
+// Deprecated: Use RegisterFile.
+func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
+ var firstErr error
+ for _, file := range files {
+ if err := r.RegisterFile(file); err != nil && firstErr == nil {
+ firstErr = err
+ }
+ }
+ return firstErr
+}
+
+// RegisterFile registers the provided file descriptor.
+//
+// If any descriptor within the file conflicts with the descriptor of any
// previously registered file (e.g., two enums with the same full name),
-// then that file is not registered and an error is returned.
+// then the file is not registered and an error is returned.
//
// It is permitted for multiple files to have the same file path.
-func (r *Files) Register(files ...protoreflect.FileDescriptor) error {
+func (r *Files) RegisterFile(file protoreflect.FileDescriptor) error {
if r == GlobalFiles {
globalMutex.Lock()
defer globalMutex.Unlock()
@@ -99,32 +116,23 @@
}
r.filesByPath = make(map[string]protoreflect.FileDescriptor)
}
- var firstErr error
- for _, file := range files {
- if err := r.registerFile(file); err != nil && firstErr == nil {
- firstErr = err
- }
- }
- return firstErr
-}
-func (r *Files) registerFile(fd protoreflect.FileDescriptor) error {
- path := fd.Path()
+ path := file.Path()
if prev := r.filesByPath[path]; prev != nil {
- err := errors.New("file %q is already registered", fd.Path())
- err = amendErrorWithCaller(err, prev, fd)
- if r == GlobalFiles && ignoreConflict(fd, err) {
+ err := errors.New("file %q is already registered", file.Path())
+ err = amendErrorWithCaller(err, prev, file)
+ if r == GlobalFiles && ignoreConflict(file, err) {
err = nil
}
return err
}
- for name := fd.Package(); name != ""; name = name.Parent() {
+ for name := file.Package(); name != ""; name = name.Parent() {
switch prev := r.descsByName[name]; prev.(type) {
case nil, *packageDescriptor:
default:
- err := errors.New("file %q has a package name conflict over %v", fd.Path(), name)
- err = amendErrorWithCaller(err, prev, fd)
- if r == GlobalFiles && ignoreConflict(fd, err) {
+ err := errors.New("file %q has a package name conflict over %v", file.Path(), name)
+ err = amendErrorWithCaller(err, prev, file)
+ if r == GlobalFiles && ignoreConflict(file, err) {
err = nil
}
return err
@@ -132,11 +140,11 @@
}
var err error
var hasConflict bool
- rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
+ rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
if prev := r.descsByName[d.FullName()]; prev != nil {
hasConflict = true
- err = errors.New("file %q has a name conflict over %v", fd.Path(), d.FullName())
- err = amendErrorWithCaller(err, prev, fd)
+ err = errors.New("file %q has a name conflict over %v", file.Path(), d.FullName())
+ err = amendErrorWithCaller(err, prev, file)
if r == GlobalFiles && ignoreConflict(d, err) {
err = nil
}
@@ -146,17 +154,17 @@
return err
}
- for name := fd.Package(); name != ""; name = name.Parent() {
+ for name := file.Package(); name != ""; name = name.Parent() {
if r.descsByName[name] == nil {
r.descsByName[name] = &packageDescriptor{}
}
}
- p := r.descsByName[fd.Package()].(*packageDescriptor)
- p.files = append(p.files, fd)
- rangeTopLevelDescriptors(fd, func(d protoreflect.Descriptor) {
+ p := r.descsByName[file.Package()].(*packageDescriptor)
+ p.files = append(p.files, file)
+ rangeTopLevelDescriptors(file, func(d protoreflect.Descriptor) {
r.descsByName[d.FullName()] = d
})
- r.filesByPath[path] = fd
+ r.filesByPath[path] = file
return nil
}
@@ -361,6 +369,10 @@
}
// A Type is a protoreflect.EnumType, protoreflect.MessageType, or protoreflect.ExtensionType.
+//
+// Deprecated: Do not use.
+//
+// TODO: Remove.
type Type interface{}
// MessageTypeResolver is an interface for looking up messages.
@@ -443,13 +455,15 @@
}
type (
- typesByName map[protoreflect.FullName]Type
+ typesByName map[protoreflect.FullName]interface{}
extensionsByMessage map[protoreflect.FullName]extensionsByNumber
extensionsByNumber map[protoreflect.FieldNumber]protoreflect.ExtensionType
)
// NewTypes returns a registry initialized with the provided set of types.
// If there are conflicts, the first one takes precedence.
+//
+// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
func NewTypes(typs ...Type) *Types {
r := new(Types)
r.Register(typs...) // ignore errors; first takes precedence
@@ -458,88 +472,109 @@
// Register registers the provided list of descriptor types.
//
-// If a registration conflict occurs for enum, message, or extension types
-// (e.g., two different types have the same full name),
-// then the first type takes precedence and an error is returned.
+// Deprecated: Use RegisterMessage, RegisterEnum, or RegisterExtension.
func (r *Types) Register(typs ...Type) error {
+ var firstErr error
+ for _, typ := range typs {
+ var err error
+ switch t := typ.(type) {
+ case protoreflect.EnumType:
+ err = r.RegisterEnum(t)
+ case protoreflect.MessageType:
+ err = r.RegisterMessage(t)
+ case protoreflect.ExtensionType:
+ err = r.RegisterExtension(t)
+ default:
+ panic(fmt.Sprintf("invalid type: %T", t))
+ }
+ if firstErr == nil {
+ firstErr = err
+ }
+ }
+ return firstErr
+}
+
+// RegisterMessage registers the provided message type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterMessage(mt protoreflect.MessageType) error {
if r == GlobalTypes {
globalMutex.Lock()
defer globalMutex.Unlock()
}
- var firstErr error
-typeLoop:
- for _, typ := range typs {
- switch typ.(type) {
- case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
- // Check for conflicts in typesByName.
- var desc protoreflect.Descriptor
- var pcnt *int
- switch t := typ.(type) {
- case protoreflect.EnumType:
- desc = t.Descriptor()
- pcnt = &r.numEnums
- case protoreflect.MessageType:
- desc = t.Descriptor()
- pcnt = &r.numMessages
- case protoreflect.ExtensionType:
- desc = t.TypeDescriptor()
- pcnt = &r.numExtensions
- default:
- panic(fmt.Sprintf("invalid type: %T", t))
- }
- name := desc.FullName()
- if prev := r.typesByName[name]; prev != nil {
- err := errors.New("%v %v is already registered", typeName(typ), name)
- err = amendErrorWithCaller(err, prev, typ)
- if r == GlobalTypes && ignoreConflict(desc, err) {
- err = nil
- }
- if firstErr == nil {
- firstErr = err
- }
- continue typeLoop
- }
- // Check for conflicts in extensionsByMessage.
- if xt, _ := typ.(protoreflect.ExtensionType); xt != nil {
- xd := xt.TypeDescriptor()
- field := xd.Number()
- message := xd.ContainingMessage().FullName()
- if prev := r.extensionsByMessage[message][field]; prev != nil {
- err := errors.New("extension number %d is already registered on message %v", field, message)
- err = amendErrorWithCaller(err, prev, typ)
- if r == GlobalTypes && ignoreConflict(xd, err) {
- err = nil
- }
- if firstErr == nil {
- firstErr = err
- }
- continue typeLoop
- }
+ if err := r.register("message", mt.Descriptor(), mt); err != nil {
+ return err
+ }
+ r.numMessages++
+ return nil
+}
- // Update extensionsByMessage.
- if r.extensionsByMessage == nil {
- r.extensionsByMessage = make(extensionsByMessage)
- }
- if r.extensionsByMessage[message] == nil {
- r.extensionsByMessage[message] = make(extensionsByNumber)
- }
- r.extensionsByMessage[message][field] = xt
- }
+// RegisterEnum registers the provided enum type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterEnum(et protoreflect.EnumType) error {
+ if r == GlobalTypes {
+ globalMutex.Lock()
+ defer globalMutex.Unlock()
+ }
- // Update typesByName and the count.
- if r.typesByName == nil {
- r.typesByName = make(typesByName)
- }
- r.typesByName[name] = typ
- (*pcnt)++
- default:
- if firstErr == nil {
- firstErr = errors.New("invalid type: %v", typeName(typ))
- }
+ if err := r.register("enum", et.Descriptor(), et); err != nil {
+ return err
+ }
+ r.numEnums++
+ return nil
+}
+
+// RegisterExtension registers the provided extension type.
+//
+// If a naming conflict occurs, the type is not registered and an error is returned.
+func (r *Types) RegisterExtension(xt protoreflect.ExtensionType) error {
+ if r == GlobalTypes {
+ globalMutex.Lock()
+ defer globalMutex.Unlock()
+ }
+
+ xd := xt.TypeDescriptor()
+ field := xd.Number()
+ message := xd.ContainingMessage().FullName()
+ if prev := r.extensionsByMessage[message][field]; prev != nil {
+ err := errors.New("extension number %d is already registered on message %v", field, message)
+ err = amendErrorWithCaller(err, prev, xt)
+ if !(r == GlobalTypes && ignoreConflict(xd, err)) {
+ return err
}
}
- return firstErr
+
+ if err := r.register("extension", xt.TypeDescriptor(), xt); err != nil {
+ return err
+ }
+ if r.extensionsByMessage == nil {
+ r.extensionsByMessage = make(extensionsByMessage)
+ }
+ if r.extensionsByMessage[message] == nil {
+ r.extensionsByMessage[message] = make(extensionsByNumber)
+ }
+ r.extensionsByMessage[message][field] = xt
+ r.numExtensions++
+ return nil
+}
+
+func (r *Types) register(kind string, desc protoreflect.Descriptor, typ interface{}) error {
+ name := desc.FullName()
+ prev := r.typesByName[name]
+ if prev != nil {
+ err := errors.New("%v %v is already registered", kind, name)
+ err = amendErrorWithCaller(err, prev, typ)
+ if !(r == GlobalTypes && ignoreConflict(desc, err)) {
+ return err
+ }
+ }
+ if r.typesByName == nil {
+ r.typesByName = make(typesByName)
+ }
+ r.typesByName[name] = typ
+ return nil
}
// FindEnumByName looks up an enum by its full name.
diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go
index 8d70243..22fbdad 100644
--- a/reflect/protoregistry/registry_test.go
+++ b/reflect/protoregistry/registry_test.go
@@ -282,7 +282,7 @@
t.Run("", func(t *testing.T) {
var files preg.Files
for i, tc := range tt.files {
- gotErr := files.Register(tc.inFile)
+ gotErr := files.RegisterFile(tc.inFile)
if ((gotErr == nil) != (tc.wantErr == "")) || !strings.Contains(fmt.Sprint(gotErr), tc.wantErr) {
t.Errorf("file %d, Register() = %v, want %v", i, gotErr, tc.wantErr)
}
@@ -332,8 +332,17 @@
xt1 := testpb.E_StringField
xt2 := testpb.E_Message4_MessageField
registry := new(preg.Types)
- if err := registry.Register(mt1, et1, xt1, xt2); err != nil {
- t.Fatalf("registry.Register() returns unexpected error: %v", err)
+ if err := registry.RegisterMessage(mt1); err != nil {
+ t.Fatalf("registry.RegisterMessage(%v) returns unexpected error: %v", mt1.Descriptor().FullName(), err)
+ }
+ if err := registry.RegisterEnum(et1); err != nil {
+ t.Fatalf("registry.RegisterEnum(%v) returns unexpected error: %v", et1.Descriptor().FullName(), err)
+ }
+ if err := registry.RegisterExtension(xt1); err != nil {
+ t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt1.TypeDescriptor().FullName(), err)
+ }
+ if err := registry.RegisterExtension(xt2); err != nil {
+ t.Fatalf("registry.RegisterExtension(%v) returns unexpected error: %v", xt2.TypeDescriptor().FullName(), err)
}
t.Run("FindMessageByName", func(t *testing.T) {