cmd/protoc-gen-go: special cases for MessageSet extensions
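Add special-case handling for extension fields named "message_set_extension" that are
declared inside a message and extend a message_set_wire_format message: their registered
name drops the trailing "message_set_extension" component, and they are also registered
with RegisterMessageSetType. MessageSet is a proto1 feature superseded by proto2 extensions.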
Change-Id: Icbdb711111c66be547bf8d6f37ab3079c320e2a1
Reviewed-on: https://go-review.googlesource.com/136536
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 28745ae..893f3ed 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -744,6 +744,19 @@
}
func genExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File, extension *protogen.Extension) {
+	// Special case for proto1 message sets: if this extension is declared inside a message,
+	// extends a message with the message_set_wire_format option, and its final name
+	// component is "message_set_extension", then drop that last component.
+ //
+ // TODO: This should be implemented in the text formatter rather than the generator.
+ // In addition, the situation for when to apply this special case is implemented
+ // differently in other languages:
+ // https://github.com/google/protobuf/blob/aff10976/src/google/protobuf/text_format.cc#L1560
+ name := extension.Desc.FullName()
+ if isExtensionMessageSetElement(gen, extension) {
+ name = name.Parent()
+ }
+
g.P("var ", extensionVar(f, extension), " = &", protogen.GoIdent{
GoImportPath: protoPackage,
GoName: "ExtensionDesc",
@@ -755,13 +768,19 @@
}
g.P("ExtensionType: (", goType, ")(nil),")
g.P("Field: ", extension.Desc.Number(), ",")
- g.P("Name: ", strconv.Quote(string(extension.Desc.FullName())), ",")
+ g.P("Name: ", strconv.Quote(string(name)), ",")
g.P("Tag: ", strconv.Quote(fieldProtobufTag(extension)), ",")
g.P("Filename: ", strconv.Quote(f.Desc.Path()), ",")
g.P("}")
g.P()
}
+// isExtensionMessageSetElement reports whether extension is nested in a message, extends a message_set_wire_format message, and is named "message_set_extension".
+func isExtensionMessageSetElement(gen *protogen.Plugin, extension *protogen.Extension) bool {
+	nested := extension.ParentMessage != nil
+	return nested && messageOptions(gen, extension.ExtendedType).GetMessageSetWireFormat() && extension.Desc.Name() == "message_set_extension"
+}
+
// extensionVar returns the var holding the ExtensionDesc for an extension.
func extensionVar(f *File, extension *protogen.Extension) protogen.GoIdent {
name := "E_"
@@ -783,11 +802,22 @@
}
g.P("func init() {")
+ for _, enum := range f.allEnums {
+ name := enum.GoIdent.GoName
+ g.P(protogen.GoIdent{
+ GoImportPath: protoPackage,
+ GoName: "RegisterEnum",
+ }, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+ }
for _, message := range f.allMessages {
if message.Desc.IsMapEntry() {
continue
}
+ for _, extension := range message.Extensions {
+ genRegisterExtension(gen, g, f, extension)
+ }
+
name := message.GoIdent.GoName
g.P(protogen.GoIdent{
GoImportPath: protoPackage,
@@ -815,23 +845,30 @@
}, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
}
}
- for _, enum := range f.allEnums {
- name := enum.GoIdent.GoName
- g.P(protogen.GoIdent{
- GoImportPath: protoPackage,
- GoName: "RegisterEnum",
- }, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
- }
- for _, extension := range f.allExtensions {
- g.P(protogen.GoIdent{
- GoImportPath: protoPackage,
- GoName: "RegisterExtension",
- }, "(", extensionVar(f, extension), ")")
+ for _, extension := range f.Extensions {
+ genRegisterExtension(gen, g, f, extension)
}
g.P("}")
g.P()
}
+// genRegisterExtension emits the init-time registration call(s) for a single extension.
+func genRegisterExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File, extension *protogen.Extension) {
+	registerExtension := protogen.GoIdent{GoImportPath: protoPackage, GoName: "RegisterExtension"}
+	g.P(registerExtension, "(", extensionVar(f, extension), ")")
+	if !isExtensionMessageSetElement(gen, extension) {
+		return
+	}
+	// MessageSet elements are also registered under the parent's full name (last component dropped).
+	goType, pointer := fieldGoType(g, extension)
+	if pointer {
+		goType = "*" + goType
+	}
+	registerMessageSetType := protogen.GoIdent{GoImportPath: protoPackage, GoName: "RegisterMessageSetType"}
+	msetName := strconv.Quote(string(extension.Desc.FullName().Parent()))
+	g.P(registerMessageSetType, "((", goType, ")(nil), ", extension.Desc.Number(), ",", msetName, ")")
+}
+
func genComment(g *protogen.GeneratedFile, f *File, path []int32) (hasComment bool) {
for _, loc := range f.locationMap[pathKey(path)] {
if loc.LeadingComments == nil {