cmd/protoc-gen-go: register messages and map field types
Move generation of the init function that registers all the types in a
file into a single function.
Take some care to generate the registrations in the same order as the
previous protoc-gen-go, to make it easier to catch unintended
differences in output.
For the same reason, adjust the order of generation to generate all
enums before all messages (matches previous behavior).
Change-Id: Ie0d574004d01a16f8d7b10be3882719a3c41676e
Reviewed-on: https://go-review.googlesource.com/135359
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 815f080..c7d9f8d 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -14,6 +14,7 @@
"flag"
"fmt"
"math"
+ "sort"
"strconv"
"strings"
@@ -53,7 +54,8 @@
*protogen.File
locationMap map[string][]*descpb.SourceCodeInfo_Location
descriptorVar string // var containing the gzipped FileDescriptorProto
- init []string
+ allEnums []*protogen.Enum
+ allMessages []*protogen.Message
}
func genFile(gen *protogen.Plugin, file *protogen.File) {
@@ -66,6 +68,12 @@
f.locationMap[key] = append(f.locationMap[key], loc)
}
+ f.allEnums = append(f.allEnums, f.File.Enums...)
+ f.allMessages = append(f.allMessages, f.File.Messages...)
+ for _, message := range f.Messages {
+ f.initMessage(message)
+ }
+
// Determine the name of the var holding the file descriptor:
//
// fileDescriptor_<hash of filename>
@@ -91,25 +99,26 @@
}, "// please upgrade the proto package")
g.P()
- for _, enum := range f.Enums {
+ for _, enum := range f.allEnums {
genEnum(gen, g, f, enum)
}
- for _, message := range f.Messages {
+ for _, message := range f.allMessages {
genMessage(gen, g, f, message)
}
- if len(f.init) != 0 {
- g.P("func init() {")
- for _, s := range f.init {
- g.P(s)
- }
- g.P("}")
- g.P()
- }
+ genInitFunction(gen, g, f)
genFileDescriptor(gen, g, f)
}
+// initMessage appends the enums and messages declared directly inside
+// message to f.allEnums and f.allMessages, then recurses into each nested
+// message. All of a message's direct children are therefore appended before
+// any of their own descendants.
+func (f *File) initMessage(message *protogen.Message) {
+ f.allEnums = append(f.allEnums, message.Enums...)
+ f.allMessages = append(f.allMessages, message.Messages...)
+ for _, m := range message.Messages {
+ f.initMessage(m)
+ }
+}
+
func genFileDescriptor(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
// Trim the source_code_info from the descriptor.
// Marshal and gzip it.
@@ -215,14 +224,6 @@
g.P()
genWellKnownType(g, enum.GoIdent, enum.Desc)
-
- f.init = append(f.init, fmt.Sprintf("%s(%q, %s, %s)",
- g.QualifiedGoIdent(protogen.GoIdent{
- GoImportPath: protoPackage,
- GoName: "RegisterEnum",
- }),
- enumRegistryName(enum), nameMap, valueMap,
- ))
}
// enumRegistryName returns the name used to register an enum with the proto
@@ -250,10 +251,6 @@
return
}
- for _, e := range message.Enums {
- genEnum(gen, g, f, e)
- }
-
genComment(g, f, message.Path)
// TODO: deprecation
g.P("type ", message.GoIdent, " struct {")
@@ -415,9 +412,7 @@
g.P()
}
- for _, nested := range message.Messages {
- genMessage(gen, g, f, nested)
- }
+ genWellKnownType(g, message.GoIdent, message.Desc)
}
// fieldGoType returns the Go type used for a field.
@@ -590,6 +585,57 @@
return string(field.Desc.Name()) + ",omitempty"
}
+// genInitFunction generates an init function that registers the types in the
+// generated file with the proto package.
+//
+// Nothing is emitted when f has neither messages nor enums. Otherwise the
+// generated init body contains, in this order:
+//   - for every message in f.allMessages that is not a map entry, a
+//     RegisterType call followed by one RegisterMapType call per map field;
+//   - for every enum in f.allEnums, a RegisterEnum call.
+//
+// The gen parameter is not used by this function.
+func genInitFunction(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File) {
+ if len(f.allMessages) == 0 && len(f.allEnums) == 0 {
+ return
+ }
+
+ g.P("func init() {")
+ for _, message := range f.allMessages {
+ // Map entry messages get no RegisterType call of their own; map fields
+ // are registered separately with RegisterMapType below.
+ if message.Desc.IsMapEntry() {
+ continue
+ }
+
+ // Emits: proto.RegisterType((*<GoName>)(nil), "<full proto name>")
+ name := message.GoIdent.GoName
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterType",
+ }, fmt.Sprintf("((*%s)(nil), %q)", name, message.Desc.FullName()))
+
+ // Types of map fields, sorted by the name of the field message type.
+ // Sorting fixes the emission order independently of field declaration
+ // order.
+ var mapFields []*protogen.Field
+ for _, field := range message.Fields {
+ if field.Desc.IsMap() {
+ mapFields = append(mapFields, field)
+ }
+ }
+ sort.Slice(mapFields, func(i, j int) bool {
+ ni := mapFields[i].MessageType.Desc.FullName()
+ nj := mapFields[j].MessageType.Desc.FullName()
+ return ni < nj
+ })
+ for _, field := range mapFields {
+ typeName := string(field.MessageType.Desc.FullName())
+ // Only the first result of fieldGoType (the Go type) is used here;
+ // the second result is deliberately discarded.
+ goType, _ := fieldGoType(g, field)
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterMapType",
+ }, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
+ }
+ }
+ // Enums are registered after all messages.
+ for _, enum := range f.allEnums {
+ name := enum.GoIdent.GoName
+ // <name>_name and <name>_value are presumably the name/value map vars
+ // emitted for this enum by genEnum -- TODO confirm against genEnum.
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterEnum",
+ }, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+ }
+ g.P("}")
+ g.P()
+}
+
func genComment(g *protogen.GeneratedFile, f *File, path []int32) {
for _, loc := range f.locationMap[pathKey(path)] {
if loc.LeadingComments == nil {