reflect/protodesc: add validation for NewFile
This covers most of the validation items listed in the TODO comment. I left
open the ones that we don't yet have a clear consensus on.
Change-Id: I336c53173ee8d7447558b1e3a0c1ef945e986cd5
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175140
Reviewed-by: Joe Tsai <joetsai@google.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/reflect/protodesc/protodesc.go b/reflect/protodesc/protodesc.go
index 406e37f..c1ed334 100644
--- a/reflect/protodesc/protodesc.go
+++ b/reflect/protodesc/protodesc.go
@@ -23,14 +23,7 @@
// that we don't directly use?
//
// For example:
-// * That field numbers don't overlap with reserved numbers.
-// * That field names don't overlap with reserved names.
-// * That enum numbers don't overlap with reserved numbers.
-// * That enum names don't overlap with reserved names.
-// * That "extendee" is not set for a message field.
-// * That "oneof_index" is not set for an extension field.
// * That "json_name" is not set for an extension field. Maybe, maybe not.
-// * That "type_name" is not set on a field for non-enums and non-messages.
// * That "weak" is not set for an extension field (double check this).
// TODO: Store the input file descriptor to implement:
@@ -96,8 +89,10 @@
}
}
+ imps := importedFiles(f.Imports)
+
var err error
- f.Messages, err = messagesFromDescriptorProto(fd.GetMessageType(), f.Syntax, r)
+ f.Messages, err = messagesFromDescriptorProto(fd.GetMessageType(), imps, r)
if err != nil {
return nil, err
}
@@ -105,11 +100,11 @@
if err != nil {
return nil, err
}
- f.Extensions, err = extensionsFromDescriptorProto(fd.GetExtension(), r)
+ f.Extensions, err = extensionsFromDescriptorProto(fd.GetExtension(), imps, r)
if err != nil {
return nil, err
}
- f.Services, err = servicesFromDescriptorProto(fd.GetService(), r)
+ f.Services, err = servicesFromDescriptorProto(fd.GetService(), imps, r)
if err != nil {
return nil, err
}
@@ -117,16 +112,71 @@
return prototype.NewFile(&f)
}
-func messagesFromDescriptorProto(mds []*descriptorpb.DescriptorProto, syntax protoreflect.Syntax, r *protoregistry.Files) (ms []prototype.Message, err error) {
+// importSet is a set of file descriptors, used to record which files a file
+// being constructed is allowed to reference types from.
+type importSet map[protoreflect.FileDescriptor]bool
+
+// importedFiles builds the set of files visible to a file with the given
+// imports: every directly imported file, plus every file reachable from those
+// through public imports (expanded by addPublicImports).
+func importedFiles(imps []protoreflect.FileImport) importSet {
+	set := make(importSet)
+	for i := range imps {
+		dep := imps[i].FileDescriptor
+		set[dep] = true
+		addPublicImports(dep, set)
+	}
+	return set
+}
+
+// addPublicImports adds to out every file reachable from fd by following only
+// public imports; non-public imports of fd are not visible to fd's importers.
+//
+// Every file inserted into out has its own public imports expanded right
+// after insertion, by this function or by importedFiles. A file that is
+// already in out is therefore skipped rather than walked again. This keeps
+// the traversal linear in the size of the import graph. It also guarantees
+// termination even if the public-import graph contains a cycle (possible with
+// hand-built descriptors).
+func addPublicImports(fd protoreflect.FileDescriptor, out importSet) {
+	imps := fd.Imports()
+	for i := 0; i < imps.Len(); i++ {
+		imp := imps.Get(i)
+		if !imp.IsPublic || out[imp.FileDescriptor] {
+			continue
+		}
+		out[imp.FileDescriptor] = true
+		addPublicImports(imp.FileDescriptor, out)
+	}
+}
+
+func messagesFromDescriptorProto(mds []*descriptorpb.DescriptorProto, imps importSet, r *protoregistry.Files) (ms []prototype.Message, err error) {
for _, md := range mds {
var m prototype.Message
m.Name = protoreflect.Name(md.GetName())
m.Options = md.GetOptions()
m.IsMapEntry = md.GetOptions().GetMapEntry()
+
+ for _, s := range md.GetReservedName() {
+ m.ReservedNames = append(m.ReservedNames, protoreflect.Name(s))
+ }
+ for _, rr := range md.GetReservedRange() {
+ m.ReservedRanges = append(m.ReservedRanges, [2]protoreflect.FieldNumber{
+ protoreflect.FieldNumber(rr.GetStart()),
+ protoreflect.FieldNumber(rr.GetEnd()),
+ })
+ }
+ for _, xr := range md.GetExtensionRange() {
+ m.ExtensionRanges = append(m.ExtensionRanges, [2]protoreflect.FieldNumber{
+ protoreflect.FieldNumber(xr.GetStart()),
+ protoreflect.FieldNumber(xr.GetEnd()),
+ })
+ m.ExtensionRangeOptions = append(m.ExtensionRangeOptions, xr.GetOptions())
+ }
+ resNames := prototype.Names(m.ReservedNames)
+ resRanges := prototype.FieldRanges(m.ReservedRanges)
+ extRanges := prototype.FieldRanges(m.ExtensionRanges)
+
for _, fd := range md.GetField() {
+ if fd.GetExtendee() != "" {
+ return nil, errors.New("message field may not have extendee")
+ }
var f prototype.Field
f.Name = protoreflect.Name(fd.GetName())
+ if resNames.Has(f.Name) {
+ return nil, errors.New("%v contains field with reserved name %q", m.Name, f.Name)
+ }
f.Number = protoreflect.FieldNumber(fd.GetNumber())
+ if resRanges.Has(f.Number) {
+ return nil, errors.New("%v contains field with reserved number %d", m.Name, f.Number)
+ }
+ if extRanges.Has(f.Number) {
+ return nil, errors.New("%v contains field with number %d in extension range", m.Name, f.Number)
+ }
f.Cardinality = protoreflect.Cardinality(fd.GetLabel())
f.Kind = protoreflect.Kind(fd.GetType())
opts := fd.GetOptions()
@@ -155,7 +205,7 @@
}
switch f.Kind {
case protoreflect.EnumKind:
- f.EnumType, err = findEnumDescriptor(fd.GetTypeName(), r)
+ f.EnumType, err = findEnumDescriptor(fd.GetTypeName(), imps, r)
if err != nil {
return nil, err
}
@@ -163,13 +213,17 @@
f.EnumType = prototype.PlaceholderEnum(f.EnumType.FullName())
}
case protoreflect.MessageKind, protoreflect.GroupKind:
- f.MessageType, err = findMessageDescriptor(fd.GetTypeName(), r)
+ f.MessageType, err = findMessageDescriptor(fd.GetTypeName(), imps, r)
if err != nil {
return nil, err
}
if opts.GetWeak() && !f.MessageType.IsPlaceholder() {
f.MessageType = prototype.PlaceholderMessage(f.MessageType.FullName())
}
+ default:
+ if fd.GetTypeName() != "" {
+ return nil, errors.New("field of kind %v has type_name set", f.Kind)
+ }
}
m.Fields = append(m.Fields, f)
}
@@ -179,24 +233,8 @@
Options: od.Options,
})
}
- for _, s := range md.GetReservedName() {
- m.ReservedNames = append(m.ReservedNames, protoreflect.Name(s))
- }
- for _, rr := range md.GetReservedRange() {
- m.ReservedRanges = append(m.ReservedRanges, [2]protoreflect.FieldNumber{
- protoreflect.FieldNumber(rr.GetStart()),
- protoreflect.FieldNumber(rr.GetEnd()),
- })
- }
- for _, xr := range md.GetExtensionRange() {
- m.ExtensionRanges = append(m.ExtensionRanges, [2]protoreflect.FieldNumber{
- protoreflect.FieldNumber(xr.GetStart()),
- protoreflect.FieldNumber(xr.GetEnd()),
- })
- m.ExtensionRangeOptions = append(m.ExtensionRangeOptions, xr.GetOptions())
- }
- m.Messages, err = messagesFromDescriptorProto(md.GetNestedType(), syntax, r)
+ m.Messages, err = messagesFromDescriptorProto(md.GetNestedType(), imps, r)
if err != nil {
return nil, err
}
@@ -204,7 +242,7 @@
if err != nil {
return nil, err
}
- m.Extensions, err = extensionsFromDescriptorProto(md.GetExtension(), r)
+ m.Extensions, err = extensionsFromDescriptorProto(md.GetExtension(), imps, r)
if err != nil {
return nil, err
}
@@ -219,13 +257,6 @@
var e prototype.Enum
e.Name = protoreflect.Name(ed.GetName())
e.Options = ed.GetOptions()
- for _, vd := range ed.GetValue() {
- e.Values = append(e.Values, prototype.EnumValue{
- Name: protoreflect.Name(vd.GetName()),
- Number: protoreflect.EnumNumber(vd.GetNumber()),
- Options: vd.Options,
- })
- }
for _, s := range ed.GetReservedName() {
e.ReservedNames = append(e.ReservedNames, protoreflect.Name(s))
}
@@ -235,13 +266,33 @@
protoreflect.EnumNumber(rr.GetEnd()),
})
}
+ resNames := prototype.Names(e.ReservedNames)
+ resRanges := prototype.EnumRanges(e.ReservedRanges)
+
+ for _, vd := range ed.GetValue() {
+ v := prototype.EnumValue{
+ Name: protoreflect.Name(vd.GetName()),
+ Number: protoreflect.EnumNumber(vd.GetNumber()),
+ Options: vd.Options,
+ }
+ if resNames.Has(v.Name) {
+ return nil, errors.New("enum %v contains value with reserved name %q", e.Name, v.Name)
+ }
+ if resRanges.Has(v.Number) {
+ return nil, errors.New("enum %v contains value with reserved number %d", e.Name, v.Number)
+ }
+ e.Values = append(e.Values, v)
+ }
es = append(es, e)
}
return es, nil
}
-func extensionsFromDescriptorProto(xds []*descriptorpb.FieldDescriptorProto, r *protoregistry.Files) (xs []prototype.Extension, err error) {
+func extensionsFromDescriptorProto(xds []*descriptorpb.FieldDescriptorProto, imps importSet, r *protoregistry.Files) (xs []prototype.Extension, err error) {
for _, xd := range xds {
+ if xd.OneofIndex != nil {
+ return nil, errors.New("extension may not have oneof_index")
+ }
var x prototype.Extension
x.Name = protoreflect.Name(xd.GetName())
x.Number = protoreflect.FieldNumber(xd.GetNumber())
@@ -256,17 +307,21 @@
}
switch x.Kind {
case protoreflect.EnumKind:
- x.EnumType, err = findEnumDescriptor(xd.GetTypeName(), r)
+ x.EnumType, err = findEnumDescriptor(xd.GetTypeName(), imps, r)
if err != nil {
return nil, err
}
case protoreflect.MessageKind, protoreflect.GroupKind:
- x.MessageType, err = findMessageDescriptor(xd.GetTypeName(), r)
+ x.MessageType, err = findMessageDescriptor(xd.GetTypeName(), imps, r)
if err != nil {
return nil, err
}
+ default:
+ if xd.GetTypeName() != "" {
+ return nil, errors.New("extension of kind %v has type_name set", x.Kind)
+ }
}
- x.ExtendedType, err = findMessageDescriptor(xd.GetExtendee(), r)
+ x.ExtendedType, err = findMessageDescriptor(xd.GetExtendee(), imps, r)
if err != nil {
return nil, err
}
@@ -275,7 +330,7 @@
return xs, nil
}
-func servicesFromDescriptorProto(sds []*descriptorpb.ServiceDescriptorProto, r *protoregistry.Files) (ss []prototype.Service, err error) {
+func servicesFromDescriptorProto(sds []*descriptorpb.ServiceDescriptorProto, imps importSet, r *protoregistry.Files) (ss []prototype.Service, err error) {
for _, sd := range sds {
var s prototype.Service
s.Name = protoreflect.Name(sd.GetName())
@@ -284,11 +339,11 @@
var m prototype.Method
m.Name = protoreflect.Name(md.GetName())
m.Options = md.GetOptions()
- m.InputType, err = findMessageDescriptor(md.GetInputType(), r)
+ m.InputType, err = findMessageDescriptor(md.GetInputType(), imps, r)
if err != nil {
return nil, err
}
- m.OutputType, err = findMessageDescriptor(md.GetOutputType(), r)
+ m.OutputType, err = findMessageDescriptor(md.GetOutputType(), imps, r)
if err != nil {
return nil, err
}
@@ -306,7 +361,7 @@
// simplifies our implementation as we won't need to implement C++'s namespace
// scoping rules.
-func findMessageDescriptor(s string, r *protoregistry.Files) (protoreflect.MessageDescriptor, error) {
+func findMessageDescriptor(s string, imps importSet, r *protoregistry.Files) (protoreflect.MessageDescriptor, error) {
if !strings.HasPrefix(s, ".") {
return nil, errors.New("identifier name must be fully qualified with a leading dot: %v", s)
}
@@ -315,10 +370,33 @@
if err != nil {
return prototype.PlaceholderMessage(name), nil
}
+ if err := validateFileInImports(md, imps); err != nil {
+ return nil, err
+ }
return md, nil
}
-func findEnumDescriptor(s string, r *protoregistry.Files) (protoreflect.EnumDescriptor, error) {
+// validateFileInImports returns an error unless the file that declares d is a
+// member of imps. It also errors if no enclosing FileDescriptor can be found
+// for d.
+func validateFileInImports(d protoreflect.Descriptor, imps importSet) error {
+	switch fd := fileDescriptor(d); {
+	case fd == nil:
+		return errors.New("%v has no parent FileDescriptor", d.FullName())
+	case !imps[fd]:
+		return errors.New("reference to type %v without import of %v", d.FullName(), fd.Path())
+	default:
+		return nil
+	}
+}
+
+// fileDescriptor walks the Parent chain starting at d itself and returns the
+// first descriptor that is a FileDescriptor. It returns nil if the chain ends
+// without reaching one.
+func fileDescriptor(d protoreflect.Descriptor) protoreflect.FileDescriptor {
+	for d != nil {
+		if fd, ok := d.(protoreflect.FileDescriptor); ok {
+			return fd
+		}
+		d, _ = d.Parent()
+	}
+	return nil
+}
+
+func findEnumDescriptor(s string, imps importSet, r *protoregistry.Files) (protoreflect.EnumDescriptor, error) {
if !strings.HasPrefix(s, ".") {
return nil, errors.New("identifier name must be fully qualified with a leading dot: %v", s)
}
@@ -327,6 +405,9 @@
if err != nil {
return prototype.PlaceholderEnum(name), nil
}
+ if err := validateFileInImports(ed, imps); err != nil {
+ return nil, err
+ }
return ed, nil
}