internal/fileinit: prevent map entry descriptors from implementing MessageType

The protobuf type system represents map entries as pseudo-message
descriptors, even though map entries have no associated Go type.

Previously, all message descriptors implemented MessageType, and the
descriptors for map entries had a GoType method that simply returned nil.
Unfortunately, this violates a useful invariant: being able to
type-assert a descriptor to MessageType should guarantee that Go type
information is truly associated with that descriptor.

This CL changes the message descriptors for map entries so that they
no longer implement MessageType.

Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/fileinit/desc_lazy.go b/internal/fileinit/desc_lazy.go
index 4eef546..7356b32 100644
--- a/internal/fileinit/desc_lazy.go
+++ b/internal/fileinit/desc_lazy.go
@@ -64,15 +64,12 @@
 		md := &file.allMessages[i]
 
 		// Associate the MessageType with a concrete Go type.
-		//
-		// Note that descriptors for map entries, which have no associated
-		// Go type, also implement the protoreflect.MessageType interface,
-		// but have a GoType accessor that reports nil. Calling New results
-		// in a panic, which is sensible behavior.
-		md.lazy.typ = reflect.TypeOf(messageDecls[i])
-		md.lazy.new = func() pref.Message {
-			t := md.lazy.typ.Elem()
-			return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
+		if !md.isMapEntry {
+			md.lazy.typ = reflect.TypeOf(messageDecls[i])
+			md.lazy.new = func() pref.Message {
+				t := md.lazy.typ.Elem()
+				return reflect.New(t).Interface().(pref.ProtoMessage).ProtoReflect()
+			}
 		}
 
 		// Resolve message field dependencies.
@@ -173,9 +170,9 @@
 		// Resolve extension field dependency.
 		switch xd.lazy.kind {
 		case pref.EnumKind:
-			xd.lazy.enumType = file.popEnumDependency()
+			xd.lazy.enumType = file.popEnumDependency().(pref.EnumType)
 		case pref.MessageKind, pref.GroupKind:
-			xd.lazy.messageType = file.popMessageDependency()
+			xd.lazy.messageType = file.popMessageDependency().(pref.MessageType)
 		}
 		xd.lazy.defVal.lazyInit(xd.lazy.kind, file.enumValuesOf(xd.lazy.enumType))
 	}
@@ -219,8 +216,8 @@
 	if md == nil {
 		return false
 	}
-	if md, ok := md.(*messageDesc); ok && md.parentFile == fd {
-		return md.lazy.isMapEntry
+	if md, ok := md.(*messageDescriptor); ok && md.parentFile == fd {
+		return md.isMapEntry
 	}
 	return md.IsMapEntry()
 }
@@ -238,7 +235,7 @@
 	return ed.Values()
 }
 
-func (fd *fileDesc) popEnumDependency() pref.EnumType {
+func (fd *fileDesc) popEnumDependency() pref.EnumDescriptor {
 	depIdx := fd.popDependencyIndex()
 	if depIdx < len(fd.allEnums)+len(fd.allMessages) {
 		return &fd.allEnums[depIdx]
@@ -247,10 +244,10 @@
 	}
 }
 
-func (fd *fileDesc) popMessageDependency() pref.MessageType {
+func (fd *fileDesc) popMessageDependency() pref.MessageDescriptor {
 	depIdx := fd.popDependencyIndex()
 	if depIdx < len(fd.allEnums)+len(fd.allMessages) {
-		return &fd.allMessages[depIdx-len(fd.allEnums)]
+		return fd.allMessages[depIdx-len(fd.allEnums)].asDesc()
 	} else {
 		return pimpl.Export{}.MessageTypeOf(fd.GoTypes[depIdx])
 	}
@@ -490,6 +487,7 @@
 func (md *messageDesc) unmarshalFull(b []byte, nb *nameBuilder) {
 	var rawFields, rawOneofs [][]byte
 	var enumIdx, messageIdx, extensionIdx int
+	var isMapEntry bool
 	md.lazy = new(messageLazy)
 	for len(b) > 0 {
 		num, typ, n := wire.ConsumeTag(b)
@@ -521,7 +519,7 @@
 				md.extensions.list[extensionIdx].unmarshalFull(v, nb)
 				extensionIdx++
 			case fieldnum.DescriptorProto_Options:
-				md.unmarshalOptions(v)
+				md.unmarshalOptions(v, &isMapEntry)
 			}
 		default:
 			m := wire.ConsumeFieldValue(num, typ, b)
@@ -534,21 +532,25 @@
 		md.lazy.oneofs.list = make([]oneofDesc, len(rawOneofs))
 		for i, b := range rawFields {
 			fd := &md.lazy.fields.list[i]
-			fd.unmarshalFull(b, nb, md.parentFile, md, i)
+			fd.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
 			if fd.cardinality == pref.Required {
 				md.lazy.reqNumbers.list = append(md.lazy.reqNumbers.list, fd.number)
 			}
 		}
 		for i, b := range rawOneofs {
 			od := &md.lazy.oneofs.list[i]
-			od.unmarshalFull(b, nb, md.parentFile, md, i)
+			od.unmarshalFull(b, nb, md.parentFile, md.asDesc(), i)
 		}
 	}
 
-	md.parentFile.lazy.byName[md.FullName()] = md
+	if isMapEntry != md.isMapEntry {
+		panic("mismatching map entry property")
+	}
+
+	md.parentFile.lazy.byName[md.FullName()] = md.asDesc()
 }
 
-func (md *messageDesc) unmarshalOptions(b []byte) {
+func (md *messageDesc) unmarshalOptions(b []byte, isMapEntry *bool) {
 	md.lazy.options = append(md.lazy.options, b...)
 	for len(b) > 0 {
 		num, typ, n := wire.ConsumeTag(b)
@@ -559,7 +561,7 @@
 			b = b[m:]
 			switch num {
 			case fieldnum.MessageOptions_MapEntry:
-				md.lazy.isMapEntry = wire.DecodeBool(v)
+				*isMapEntry = wire.DecodeBool(v)
 			case fieldnum.MessageOptions_MessageSetWireFormat:
 				md.lazy.isMessageSet = wire.DecodeBool(v)
 			}
@@ -646,7 +648,7 @@
 				// In messageDesc.UnmarshalFull, we allocate slices for both
 				// the field and oneof descriptors before unmarshaling either
 				// of them. This ensures pointers to slice elements are stable.
-				od := &pd.(*messageDesc).lazy.oneofs.list[v]
+				od := &pd.(messageType).lazy.oneofs.list[v]
 				od.fields.list = append(od.fields.list, fd)
 				if fd.oneofType != nil {
 					panic("oneof type already set")