cmd/protoc-gen-go: register messages and map field types

Move generation of the init function that registers all the types in a
file into a single function, and register messages (RegisterType) and
map field types (RegisterMapType) in addition to enums.

Take some care to generate the registrations in the same order as the
previous protoc-gen-go, to make it easier to catch unintended
differences in output.

For the same reason, adjust the generation order so that all enums are
emitted before all messages (this matches the previous behavior).

Change-Id: Ie0d574004d01a16f8d7b10be3882719a3c41676e
Reviewed-on: https://go-review.googlesource.com/135359
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 815f080..c7d9f8d 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -14,6 +14,7 @@
 	"flag"
 	"fmt"
 	"math"
+	"sort"
 	"strconv"
 	"strings"
 
@@ -53,7 +54,8 @@
 	*protogen.File
 	locationMap   map[string][]*descpb.SourceCodeInfo_Location
 	descriptorVar string // var containing the gzipped FileDescriptorProto
-	init          []string
+	allEnums      []*protogen.Enum
+	allMessages   []*protogen.Message
 }
 
 func genFile(gen *protogen.Plugin, file *protogen.File) {
@@ -66,6 +68,11 @@
 		f.locationMap[key] = append(f.locationMap[key], loc)
 	}
 
+	f.allEnums = append(f.allEnums, f.File.Enums...)
+	for _, message := range f.Messages {
+		f.initMessage(message)
+	}
+
 	// Determine the name of the var holding the file descriptor:
 	//
 	//     fileDescriptor_<hash of filename>
@@ -91,25 +98,27 @@
 	}, "// please upgrade the proto package")
 	g.P()
 
-	for _, enum := range f.Enums {
+	for _, enum := range f.allEnums {
 		genEnum(gen, g, f, enum)
 	}
-	for _, message := range f.Messages {
+	for _, message := range f.allMessages {
 		genMessage(gen, g, f, message)
 	}
 
-	if len(f.init) != 0 {
-		g.P("func init() {")
-		for _, s := range f.init {
-			g.P(s)
-		}
-		g.P("}")
-		g.P()
-	}
+	genInitFunction(gen, g, f)
 
 	genFileDescriptor(gen, g, f)
 }
 
+// initMessage collects message and its nested types in depth-first pre-order.
+func (f *File) initMessage(message *protogen.Message) {
+	f.allMessages = append(f.allMessages, message)
+	f.allEnums = append(f.allEnums, message.Enums...)
+	for _, m := range message.Messages {
+		f.initMessage(m)
+	}
+}
+
 func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
 	// Trim the source_code_info from the descriptor.
 	// Marshal and gzip it.
@@ -215,14 +224,6 @@
 	g.P()
 
 	genWellKnownType(g, enum.GoIdent, enum.Desc)
-
-	f.init = append(f.init, fmt.Sprintf("%s(%q, %s, %s)",
-		g.QualifiedGoIdent(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterEnum",
-		}),
-		enumRegistryName(enum), nameMap, valueMap,
-	))
 }
 
 // enumRegistryName returns the name used to register an enum with the proto
@@ -250,10 +251,6 @@
 		return
 	}
 
-	for _, e := range message.Enums {
-		genEnum(gen, g, f, e)
-	}
-
 	genComment(g, f, message.Path)
 	// TODO: deprecation
 	g.P("type ", message.GoIdent, " struct {")
@@ -415,9 +412,7 @@
 		g.P()
 	}
 
-	for _, nested := range message.Messages {
-		genMessage(gen, g, f, nested)
-	}
+	genWellKnownType(g, message.GoIdent, message.Desc)
 }
 
 // fieldGoType returns the Go type used for a field.
@@ -590,6 +585,57 @@
 	return string(field.Desc.Name()) + ",omitempty"
 }
 
+// genInitFunction emits an init function that registers, with the proto package,
+// each non-map-entry message (plus its map field types) and then each enum in f.
+func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
+	if len(f.allMessages) == 0 && len(f.allEnums) == 0 { // nothing to register: emit no init
+		return
+	}
+
+	g.P("func init() {")
+	for _, message := range f.allMessages {
+		if message.Desc.IsMapEntry() { // entry types are registered via RegisterMapType below
+			continue
+		}
+
+		name := message.GoIdent.GoName
+		g.P(protogen.GoIdent{
+			GoImportPath: protoPackage,
+			GoName:       "RegisterType",
+		}, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
+
+		// Register the Go type of each map field, sorted by the full name of the field's entry message.
+		var mapFields []*protogen.Field
+		for _, field := range message.Fields {
+			if field.Desc.IsMap() {
+				mapFields = append(mapFields, field)
+			}
+		}
+		sort.Slice(mapFields, func(i, j int) bool {
+			ni := mapFields[i].MessageType.Desc.FullName()
+			nj := mapFields[j].MessageType.Desc.FullName()
+			return ni < nj
+		})
+		for _, field := range mapFields {
+			typeName := string(field.MessageType.Desc.FullName())
+			goType, _ := fieldGoType(g, field) // NOTE(review): second result ignored; confirm irrelevant for maps
+			g.P(protogen.GoIdent{
+				GoImportPath: protoPackage,
+				GoName:       "RegisterMapType",
+			}, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
+		}
+	}
+	for _, enum := range f.allEnums { // enums are registered after all messages
+		name := enum.GoIdent.GoName
+		g.P(protogen.GoIdent{
+			GoImportPath: protoPackage,
+			GoName:       "RegisterEnum",
+		}, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+	}
+	g.P("}")
+	g.P()
+}
+
 func genComment(g *protogen.GeneratedFile, f *File, path []int32) {
 	for _, loc := range f.locationMap[pathKey(path)] {
 		if loc.LeadingComments == nil {