reflect/protodesc: add validation for NewFile

This covers most of the validation TODO list. Items on which we don't
yet have clear consensus are left open.

Change-Id: I336c53173ee8d7447558b1e3a0c1ef945e986cd5
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175140
Reviewed-by: Joe Tsai <joetsai@google.com>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/reflect/protodesc/protodesc.go b/reflect/protodesc/protodesc.go
index 406e37f..c1ed334 100644
--- a/reflect/protodesc/protodesc.go
+++ b/reflect/protodesc/protodesc.go
@@ -23,14 +23,7 @@
 // that we don't directly use?
 //
 // For example:
-//	* That field numbers don't overlap with reserved numbers.
-//	* That field names don't overlap with reserved names.
-//	* That enum numbers don't overlap with reserved numbers.
-//	* That enum names don't overlap with reserved names.
-//	* That "extendee" is not set for a message field.
-//	* That "oneof_index" is not set for an extension field.
 //	* That "json_name" is not set for an extension field. Maybe, maybe not.
-//	* That "type_name" is not set on a field for non-enums and non-messages.
 //	* That "weak" is not set for an extension field (double check this).
 
 // TODO: Store the input file descriptor to implement:
@@ -96,8 +89,10 @@
 		}
 	}
 
+	imps := importedFiles(f.Imports)
+
 	var err error
-	f.Messages, err = messagesFromDescriptorProto(fd.GetMessageType(), f.Syntax, r)
+	f.Messages, err = messagesFromDescriptorProto(fd.GetMessageType(), imps, r)
 	if err != nil {
 		return nil, err
 	}
@@ -105,11 +100,11 @@
 	if err != nil {
 		return nil, err
 	}
-	f.Extensions, err = extensionsFromDescriptorProto(fd.GetExtension(), r)
+	f.Extensions, err = extensionsFromDescriptorProto(fd.GetExtension(), imps, r)
 	if err != nil {
 		return nil, err
 	}
-	f.Services, err = servicesFromDescriptorProto(fd.GetService(), r)
+	f.Services, err = servicesFromDescriptorProto(fd.GetService(), imps, r)
 	if err != nil {
 		return nil, err
 	}
@@ -117,16 +112,71 @@
 	return prototype.NewFile(&f)
 }
 
-func messagesFromDescriptorProto(mds []*descriptorpb.DescriptorProto, syntax protoreflect.Syntax, r *protoregistry.Files) (ms []prototype.Message, err error) {
+type importSet map[protoreflect.FileDescriptor]bool // files whose types may be referenced
+// importedFiles returns the direct imports plus, transitively, their public imports.
+func importedFiles(imps []protoreflect.FileImport) importSet {
+	set := importSet{}
+	for i := range imps {
+		set[imps[i].FileDescriptor] = true
+		addPublicImports(imps[i].FileDescriptor, set)
+	}
+	return set
+}
+// addPublicImports adds fd's transitive public imports to out, visiting each file once.
+func addPublicImports(fd protoreflect.FileDescriptor, out importSet) {
+	imps := fd.Imports()
+	for i := 0; i < imps.Len(); i++ {
+		imp := imps.Get(i)
+		if imp.IsPublic && !out[imp.FileDescriptor] { // files in out were (or are being) walked
+			out[imp.FileDescriptor] = true
+			addPublicImports(imp.FileDescriptor, out)
+		}
+	}
+}
+
+func messagesFromDescriptorProto(mds []*descriptorpb.DescriptorProto, imps importSet, r *protoregistry.Files) (ms []prototype.Message, err error) {
 	for _, md := range mds {
 		var m prototype.Message
 		m.Name = protoreflect.Name(md.GetName())
 		m.Options = md.GetOptions()
 		m.IsMapEntry = md.GetOptions().GetMapEntry()
+
+		for _, s := range md.GetReservedName() {
+			m.ReservedNames = append(m.ReservedNames, protoreflect.Name(s))
+		}
+		for _, rr := range md.GetReservedRange() {
+			m.ReservedRanges = append(m.ReservedRanges, [2]protoreflect.FieldNumber{
+				protoreflect.FieldNumber(rr.GetStart()),
+				protoreflect.FieldNumber(rr.GetEnd()),
+			})
+		}
+		for _, xr := range md.GetExtensionRange() {
+			m.ExtensionRanges = append(m.ExtensionRanges, [2]protoreflect.FieldNumber{
+				protoreflect.FieldNumber(xr.GetStart()),
+				protoreflect.FieldNumber(xr.GetEnd()),
+			})
+			m.ExtensionRangeOptions = append(m.ExtensionRangeOptions, xr.GetOptions())
+		}
+		resNames := prototype.Names(m.ReservedNames)
+		resRanges := prototype.FieldRanges(m.ReservedRanges)
+		extRanges := prototype.FieldRanges(m.ExtensionRanges)
+
 		for _, fd := range md.GetField() {
+			if fd.GetExtendee() != "" {
+				return nil, errors.New("message field may not have extendee")
+			}
 			var f prototype.Field
 			f.Name = protoreflect.Name(fd.GetName())
+			if resNames.Has(f.Name) {
+				return nil, errors.New("%v contains field with reserved name %q", m.Name, f.Name)
+			}
 			f.Number = protoreflect.FieldNumber(fd.GetNumber())
+			if resRanges.Has(f.Number) {
+				return nil, errors.New("%v contains field with reserved number %d", m.Name, f.Number)
+			}
+			if extRanges.Has(f.Number) {
+				return nil, errors.New("%v contains field with number %d in extension range", m.Name, f.Number)
+			}
 			f.Cardinality = protoreflect.Cardinality(fd.GetLabel())
 			f.Kind = protoreflect.Kind(fd.GetType())
 			opts := fd.GetOptions()
@@ -155,7 +205,7 @@
 			}
 			switch f.Kind {
 			case protoreflect.EnumKind:
-				f.EnumType, err = findEnumDescriptor(fd.GetTypeName(), r)
+				f.EnumType, err = findEnumDescriptor(fd.GetTypeName(), imps, r)
 				if err != nil {
 					return nil, err
 				}
@@ -163,13 +213,17 @@
 					f.EnumType = prototype.PlaceholderEnum(f.EnumType.FullName())
 				}
 			case protoreflect.MessageKind, protoreflect.GroupKind:
-				f.MessageType, err = findMessageDescriptor(fd.GetTypeName(), r)
+				f.MessageType, err = findMessageDescriptor(fd.GetTypeName(), imps, r)
 				if err != nil {
 					return nil, err
 				}
 				if opts.GetWeak() && !f.MessageType.IsPlaceholder() {
 					f.MessageType = prototype.PlaceholderMessage(f.MessageType.FullName())
 				}
+			default:
+				if fd.GetTypeName() != "" {
+					return nil, errors.New("field of kind %v has type_name set", f.Kind)
+				}
 			}
 			m.Fields = append(m.Fields, f)
 		}
@@ -179,24 +233,8 @@
 				Options: od.Options,
 			})
 		}
-		for _, s := range md.GetReservedName() {
-			m.ReservedNames = append(m.ReservedNames, protoreflect.Name(s))
-		}
-		for _, rr := range md.GetReservedRange() {
-			m.ReservedRanges = append(m.ReservedRanges, [2]protoreflect.FieldNumber{
-				protoreflect.FieldNumber(rr.GetStart()),
-				protoreflect.FieldNumber(rr.GetEnd()),
-			})
-		}
-		for _, xr := range md.GetExtensionRange() {
-			m.ExtensionRanges = append(m.ExtensionRanges, [2]protoreflect.FieldNumber{
-				protoreflect.FieldNumber(xr.GetStart()),
-				protoreflect.FieldNumber(xr.GetEnd()),
-			})
-			m.ExtensionRangeOptions = append(m.ExtensionRangeOptions, xr.GetOptions())
-		}
 
-		m.Messages, err = messagesFromDescriptorProto(md.GetNestedType(), syntax, r)
+		m.Messages, err = messagesFromDescriptorProto(md.GetNestedType(), imps, r)
 		if err != nil {
 			return nil, err
 		}
@@ -204,7 +242,7 @@
 		if err != nil {
 			return nil, err
 		}
-		m.Extensions, err = extensionsFromDescriptorProto(md.GetExtension(), r)
+		m.Extensions, err = extensionsFromDescriptorProto(md.GetExtension(), imps, r)
 		if err != nil {
 			return nil, err
 		}
@@ -219,13 +257,6 @@
 		var e prototype.Enum
 		e.Name = protoreflect.Name(ed.GetName())
 		e.Options = ed.GetOptions()
-		for _, vd := range ed.GetValue() {
-			e.Values = append(e.Values, prototype.EnumValue{
-				Name:    protoreflect.Name(vd.GetName()),
-				Number:  protoreflect.EnumNumber(vd.GetNumber()),
-				Options: vd.Options,
-			})
-		}
 		for _, s := range ed.GetReservedName() {
 			e.ReservedNames = append(e.ReservedNames, protoreflect.Name(s))
 		}
@@ -235,13 +266,33 @@
 				protoreflect.EnumNumber(rr.GetEnd()),
 			})
 		}
+		resNames := prototype.Names(e.ReservedNames)
+		resRanges := prototype.EnumRanges(e.ReservedRanges)
+
+		for _, vd := range ed.GetValue() {
+			v := prototype.EnumValue{
+				Name:    protoreflect.Name(vd.GetName()),
+				Number:  protoreflect.EnumNumber(vd.GetNumber()),
+				Options: vd.Options,
+			}
+			if resNames.Has(v.Name) {
+				return nil, errors.New("enum %v contains value with reserved name %q", e.Name, v.Name)
+			}
+			if resRanges.Has(v.Number) {
+				return nil, errors.New("enum %v contains value with reserved number %d", e.Name, v.Number)
+			}
+			e.Values = append(e.Values, v)
+		}
 		es = append(es, e)
 	}
 	return es, nil
 }
 
-func extensionsFromDescriptorProto(xds []*descriptorpb.FieldDescriptorProto, r *protoregistry.Files) (xs []prototype.Extension, err error) {
+func extensionsFromDescriptorProto(xds []*descriptorpb.FieldDescriptorProto, imps importSet, r *protoregistry.Files) (xs []prototype.Extension, err error) {
 	for _, xd := range xds {
+		if xd.OneofIndex != nil {
+			return nil, errors.New("extension may not have oneof_index")
+		}
 		var x prototype.Extension
 		x.Name = protoreflect.Name(xd.GetName())
 		x.Number = protoreflect.FieldNumber(xd.GetNumber())
@@ -256,17 +307,21 @@
 		}
 		switch x.Kind {
 		case protoreflect.EnumKind:
-			x.EnumType, err = findEnumDescriptor(xd.GetTypeName(), r)
+			x.EnumType, err = findEnumDescriptor(xd.GetTypeName(), imps, r)
 			if err != nil {
 				return nil, err
 			}
 		case protoreflect.MessageKind, protoreflect.GroupKind:
-			x.MessageType, err = findMessageDescriptor(xd.GetTypeName(), r)
+			x.MessageType, err = findMessageDescriptor(xd.GetTypeName(), imps, r)
 			if err != nil {
 				return nil, err
 			}
+		default:
+			if xd.GetTypeName() != "" {
+				return nil, errors.New("extension of kind %v has type_name set", x.Kind)
+			}
 		}
-		x.ExtendedType, err = findMessageDescriptor(xd.GetExtendee(), r)
+		x.ExtendedType, err = findMessageDescriptor(xd.GetExtendee(), imps, r)
 		if err != nil {
 			return nil, err
 		}
@@ -275,7 +330,7 @@
 	return xs, nil
 }
 
-func servicesFromDescriptorProto(sds []*descriptorpb.ServiceDescriptorProto, r *protoregistry.Files) (ss []prototype.Service, err error) {
+func servicesFromDescriptorProto(sds []*descriptorpb.ServiceDescriptorProto, imps importSet, r *protoregistry.Files) (ss []prototype.Service, err error) {
 	for _, sd := range sds {
 		var s prototype.Service
 		s.Name = protoreflect.Name(sd.GetName())
@@ -284,11 +339,11 @@
 			var m prototype.Method
 			m.Name = protoreflect.Name(md.GetName())
 			m.Options = md.GetOptions()
-			m.InputType, err = findMessageDescriptor(md.GetInputType(), r)
+			m.InputType, err = findMessageDescriptor(md.GetInputType(), imps, r)
 			if err != nil {
 				return nil, err
 			}
-			m.OutputType, err = findMessageDescriptor(md.GetOutputType(), r)
+			m.OutputType, err = findMessageDescriptor(md.GetOutputType(), imps, r)
 			if err != nil {
 				return nil, err
 			}
@@ -306,7 +361,7 @@
 // simplifies our implementation as we won't need to implement C++'s namespace
 // scoping rules.
 
-func findMessageDescriptor(s string, r *protoregistry.Files) (protoreflect.MessageDescriptor, error) {
+func findMessageDescriptor(s string, imps importSet, r *protoregistry.Files) (protoreflect.MessageDescriptor, error) {
 	if !strings.HasPrefix(s, ".") {
 		return nil, errors.New("identifier name must be fully qualified with a leading dot: %v", s)
 	}
@@ -315,10 +370,33 @@
 	if err != nil {
 		return prototype.PlaceholderMessage(name), nil
 	}
+	if err := validateFileInImports(md, imps); err != nil {
+		return nil, err
+	}
 	return md, nil
 }
 
-func findEnumDescriptor(s string, r *protoregistry.Files) (protoreflect.EnumDescriptor, error) {
+// validateFileInImports returns an error if d has no enclosing file, or if
+// that file is not in imps (i.e. d's file was not imported).
+func validateFileInImports(d protoreflect.Descriptor, imps importSet) error {
+	if fd := fileDescriptor(d); fd == nil {
+		return errors.New("%v has no parent FileDescriptor", d.FullName())
+	} else if !imps[fd] {
+		return errors.New("reference to type %v without import of %v", d.FullName(), fd.Path())
+	}
+	return nil
+}
+// fileDescriptor returns d if it is a file, else its nearest file ancestor, or nil.
+func fileDescriptor(d protoreflect.Descriptor) protoreflect.FileDescriptor {
+	for cur := d; cur != nil; cur, _ = cur.Parent() {
+		if file, isFile := cur.(protoreflect.FileDescriptor); isFile {
+			return file
+		}
+	}
+	return nil
+}
+
+func findEnumDescriptor(s string, imps importSet, r *protoregistry.Files) (protoreflect.EnumDescriptor, error) {
 	if !strings.HasPrefix(s, ".") {
 		return nil, errors.New("identifier name must be fully qualified with a leading dot: %v", s)
 	}
@@ -327,6 +405,9 @@
 	if err != nil {
 		return prototype.PlaceholderEnum(name), nil
 	}
+	if err := validateFileInImports(ed, imps); err != nil {
+		return nil, err
+	}
 	return ed, nil
 }