internal/fileinit: prevent map entry descriptors from implementing MessageType

The protobuf type system represents each map entry as a synthetic
pseudo-message descriptor.

Previously, every message descriptor implemented MessageType,
and for map entries the GoType method simply returned nil.
Unfortunately, this violates a useful invariant of the Go protobuf API:
a successful type assertion to MessageType should guarantee that Go type
information is actually associated with that descriptor.

This CL changes message descriptors for map entries so that they
no longer implement MessageType.

Change-Id: I23873cb71fe0ab3c0befd8052830ea6e53c97ca9
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/168399
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/fileinit/desc_init.go b/internal/fileinit/desc_init.go
index 7d14998..d1310a4 100644
--- a/internal/fileinit/desc_init.go
+++ b/internal/fileinit/desc_init.go
@@ -19,6 +19,13 @@
 	file.initDecls(len(fb.EnumOutputTypes), len(fb.MessageOutputTypes), len(fb.ExtensionOutputTypes))
 	file.unmarshalSeed(fb.RawDescriptor)
 
+	// Determine which message descriptors represent map entries based on the
+	// lack of an associated Go type.
+	messageDecls := file.GoTypes[len(file.allEnums):]
+	for i := range file.allMessages {
+		file.allMessages[i].isMapEntry = messageDecls[i] == nil
+	}
+
 	// Extended message dependencies are eagerly handled since registration
 	// needs this information at program init time.
 	for i := range file.allExtensions {
@@ -31,7 +38,7 @@
 }
 
 // initDecls pre-allocates slices for the exact number of enums, messages
-// (excluding map entries), and extensions declared in the proto file.
+// (including map entries), and extensions declared in the proto file.
 // This is done to avoid regrowing the slice, which would change the address
 // for any previously seen declaration.
 //
@@ -279,7 +286,7 @@
 		for i := range md.enums.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.enums.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.enums.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}
@@ -288,7 +295,7 @@
 		for i := range md.messages.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.messages.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.messages.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}
@@ -297,7 +304,7 @@
 		for i := range md.extensions.list {
 			_, n := wire.ConsumeVarint(b)
 			v, m := wire.ConsumeBytes(b[n:])
-			md.extensions.list[i].unmarshalSeed(v, nb, pf, md, i)
+			md.extensions.list[i].unmarshalSeed(v, nb, pf, md.asDesc(), i)
 			b = b[n+m:]
 		}
 	}