cmd/protoc-gen-go: special cases for MessageSet extensions

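Add special-case handling for extension fields named
"message_set_extension" that extend a message_set_wire_format message:
drop the final name component from the ExtensionDesc name and generate
a RegisterMessageSetType call for the extension.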

These special cases support MessageSet, a proto1 feature that was
superseded by proto2 extensions.

Change-Id: Icbdb711111c66be547bf8d6f37ab3079c320e2a1
Reviewed-on: https://go-review.googlesource.com/136536
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/main.go b/cmd/protoc-gen-go/main.go
index 28745ae..893f3ed 100644
--- a/cmd/protoc-gen-go/main.go
+++ b/cmd/protoc-gen-go/main.go
@@ -744,6 +744,19 @@
 }
 
 func genExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File, extension *protogen.Extension) {
+	// Special case for proto2 message sets: If this extension is extending
+	// proto2.bridge.MessageSet, and its final name component is "message_set_extension",
+	// then drop that last component.
+	//
+	// TODO: This should be implemented in the text formatter rather than the generator.
+	// In addition, the situation for when to apply this special case is implemented
+	// differently in other languages:
+	// https://github.com/google/protobuf/blob/aff10976/src/google/protobuf/text_format.cc#L1560
+	name := extension.Desc.FullName()
+	if isExtensionMessageSetElement(gen, extension) {
+		name = name.Parent()
+	}
+
 	g.P("var ", extensionVar(f, extension), " = &", protogen.GoIdent{
 		GoImportPath: protoPackage,
 		GoName:       "ExtensionDesc",
@@ -755,13 +768,19 @@
 	}
 	g.P("ExtensionType: (", goType, ")(nil),")
 	g.P("Field: ", extension.Desc.Number(), ",")
-	g.P("Name: ", strconv.Quote(string(extension.Desc.FullName())), ",")
+	g.P("Name: ", strconv.Quote(string(name)), ",")
 	g.P("Tag: ", strconv.Quote(fieldProtobufTag(extension)), ",")
 	g.P("Filename: ", strconv.Quote(f.Desc.Path()), ",")
 	g.P("}")
 	g.P()
 }
 
+func isExtensionMessageSetElement(gen *protogen.Plugin, extension *protogen.Extension) bool {
+	// True only if extension has a ParentMessage, its extended type sets message_set_wire_format, and it is named "message_set_extension".
+	nested := extension.ParentMessage != nil
+	return nested && messageOptions(gen, extension.ExtendedType).GetMessageSetWireFormat() && extension.Desc.Name() == "message_set_extension"
+}
+
 // extensionVar returns the var holding the ExtensionDesc for an extension.
 func extensionVar(f *File, extension *protogen.Extension) protogen.GoIdent {
 	name := "E_"
@@ -783,11 +802,22 @@
 	}
 
 	g.P("func init() {")
+	for _, enum := range f.allEnums {
+		name := enum.GoIdent.GoName
+		g.P(protogen.GoIdent{
+			GoImportPath: protoPackage,
+			GoName:       "RegisterEnum",
+		}, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
+	}
 	for _, message := range f.allMessages {
 		if message.Desc.IsMapEntry() {
 			continue
 		}
 
+		for _, extension := range message.Extensions {
+			genRegisterExtension(gen, g, f, extension)
+		}
+
 		name := message.GoIdent.GoName
 		g.P(protogen.GoIdent{
 			GoImportPath: protoPackage,
@@ -815,23 +845,30 @@
 			}, fmt.Sprintf("((%v)(nil), %q)", goType, typeName))
 		}
 	}
-	for _, enum := range f.allEnums {
-		name := enum.GoIdent.GoName
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterEnum",
-		}, fmt.Sprintf("(%q, %s_name, %s_value)", enumRegistryName(enum), name, name))
-	}
-	for _, extension := range f.allExtensions {
-		g.P(protogen.GoIdent{
-			GoImportPath: protoPackage,
-			GoName:       "RegisterExtension",
-		}, "(", extensionVar(f, extension), ")")
+	for _, extension := range f.Extensions {
+		genRegisterExtension(gen, g, f, extension)
 	}
 	g.P("}")
 	g.P()
 }
 
+func genRegisterExtension(gen *protogen.Plugin, g *protogen.GeneratedFile, f *File, extension *protogen.Extension) {
+	// Emit RegisterExtension for extension; MessageSet elements also get a RegisterMessageSetType call.
+	registerExt := protogen.GoIdent{GoImportPath: protoPackage, GoName: "RegisterExtension"}
+	g.P(registerExt, "(", extensionVar(f, extension), ")")
+	if !isExtensionMessageSetElement(gen, extension) {
+		return
+	}
+	elemType, isPtr := fieldGoType(g, extension)
+	if isPtr {
+		elemType = "*" + elemType
+	}
+	// The MessageSet type name is the extension's full name with its last component dropped.
+	typeName := strconv.Quote(string(extension.Desc.FullName().Parent()))
+	registerSet := protogen.GoIdent{GoImportPath: protoPackage, GoName: "RegisterMessageSetType"}
+	g.P(registerSet, "((", elemType, ")(nil), ", extension.Desc.Number(), ",", typeName, ")")
+}
+
 func genComment(g *protogen.GeneratedFile, f *File, path []int32) (hasComment bool) {
 	for _, loc := range f.locationMap[pathKey(path)] {
 		if loc.LeadingComments == nil {