cmd/protoc-gen-go: generate public imports by parsing the imported .pb.go

Rather than explicitly enumerating the set of symbols to import,
just parse the imported file and extract every exported symbol.

This is possibly a bit more code, but adapts much better to future
expansion.

Change-Id: I4429664f4c068a2a55949d46aefc19865b008a77
Reviewed-on: https://go-review.googlesource.com/c/155677
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/internal_gengo/main.go b/cmd/protoc-gen-go/internal_gengo/main.go
index 08a7660..5e6b7b7 100644
--- a/cmd/protoc-gen-go/internal_gengo/main.go
+++ b/cmd/protoc-gen-go/internal_gengo/main.go
@@ -11,10 +11,15 @@
 	"crypto/sha256"
 	"encoding/hex"
 	"fmt"
+	"go/ast"
+	"go/parser"
+	"go/token"
 	"math"
 	"sort"
 	"strconv"
 	"strings"
+	"unicode"
+	"unicode/utf8"
 
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/v2/internal/encoding/tag"
@@ -178,65 +183,54 @@
 	if !imp.IsPublic {
 		return
 	}
-	// TODO: An alternate approach to generating public imports might be
-	// to generate the imported file contents, parse it, and extract all
-	// exported identifiers from the AST to build a list of forwarding
-	// declarations.
-	//
-	// TODO: Consider whether this should generate recursive aliases. e.g.,
-	// if a.proto publicly imports b.proto publicly imports c.proto, should
-	// a.pb.go contain aliases for symbols defined in c.proto?
-	var enums []*protogen.Enum
-	enums = append(enums, impFile.Enums...)
-	walkMessages(impFile.Messages, func(message *protogen.Message) {
-		if message.Desc.IsMapEntry() {
+
+	// Generate public imports by generating the imported file, parsing it,
+	// and extracting every symbol that should receive a forwarding declaration.
+	impGen := gen.NewGeneratedFile("temp.go", impFile.GoImportPath)
+	impGen.Skip()
+	GenerateFile(gen, impFile, impGen)
+	b, err := impGen.Content()
+	if err != nil {
+		gen.Error(err)
+		return
+	}
+	fset := token.NewFileSet()
+	astFile, err := parser.ParseFile(fset, "", b, parser.ParseComments)
+	if err != nil {
+		gen.Error(err)
+		return
+	}
+	genForward := func(tok token.Token, name string) {
+		// Don't import unexported symbols.
+		r, _ := utf8.DecodeRuneInString(name)
+		if !unicode.IsUpper(r) {
 			return
 		}
-		enums = append(enums, message.Enums...)
-		for _, field := range message.Fields {
-			if !field.Desc.HasDefault() {
-				continue
-			}
-			defVar := protogen.GoIdent{
-				GoImportPath: message.GoIdent.GoImportPath,
-				GoName:       "Default_" + message.GoIdent.GoName + "_" + field.GoName,
-			}
-			decl := "const"
-			switch field.Desc.Kind() {
-			case protoreflect.BytesKind:
-				decl = "var"
-			case protoreflect.FloatKind, protoreflect.DoubleKind:
-				f := field.Desc.Default().Float()
-				if math.IsInf(f, -1) || math.IsInf(f, 1) || math.IsNaN(f) {
-					decl = "var"
+		// Don't import the FileDescriptor.
+		if name == impFile.GoDescriptorIdent.GoName {
+			return
+		}
+		g.P(tok, " ", name, " = ", impFile.GoImportPath.Ident(name))
+	}
+	g.P("// Symbols defined in public import of ", imp.Path())
+	g.P()
+	for _, decl := range astFile.Decls {
+		switch decl := decl.(type) {
+		case *ast.GenDecl:
+			for _, spec := range decl.Specs {
+				switch spec := spec.(type) {
+				case *ast.TypeSpec:
+					genForward(decl.Tok, spec.Name.Name)
+				case *ast.ValueSpec:
+					for _, name := range spec.Names {
+						genForward(decl.Tok, name.Name)
+					}
+				case *ast.ImportSpec:
+				default:
+					panic(fmt.Sprintf("can't generate forward for spec type %T", spec))
 				}
 			}
-			g.P(decl, " ", defVar.GoName, " = ", defVar)
 		}
-		g.P("// ", message.GoIdent.GoName, " from public import ", imp.Path())
-		g.P("type ", message.GoIdent.GoName, " = ", message.GoIdent)
-		for _, oneof := range message.Oneofs {
-			for _, field := range oneof.Fields {
-				typ := fieldOneofType(field)
-				g.P("type ", typ.GoName, " = ", typ)
-			}
-		}
-		g.P()
-	})
-	for _, enum := range enums {
-		g.P("// ", enum.GoIdent.GoName, " from public import ", imp.Path())
-		g.P("type ", enum.GoIdent.GoName, " = ", enum.GoIdent)
-		g.P("var ", enum.GoIdent.GoName, "_name = ", enum.GoIdent, "_name")
-		g.P("var ", enum.GoIdent.GoName, "_value = ", enum.GoIdent, "_value")
-		g.P()
-		for _, value := range enum.Values {
-			g.P("const ", value.GoIdent.GoName, " = ", enum.GoIdent.GoName, "(", value.GoIdent, ")")
-		}
-	}
-	for _, ext := range impFile.Extensions {
-		ident := extensionVar(impFile, ext)
-		g.P("var ", ident.GoName, " = ", ident)
-		g.P()
 	}
 	g.P()
 }
diff --git a/cmd/protoc-gen-go/testdata/import_public/a.pb.go b/cmd/protoc-gen-go/testdata/import_public/a.pb.go
index 5cb190b..738dc5e 100644
--- a/cmd/protoc-gen-go/testdata/import_public/a.pb.go
+++ b/cmd/protoc-gen-go/testdata/import_public/a.pb.go
@@ -17,45 +17,42 @@
 // proto package needs to be updated.
 const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
 
+// Symbols defined in public import of import_public/sub/a.proto
+
+type E = sub.E
+
+const E_ZERO = sub.E_ZERO
+
+var E_name = sub.E_name
+var E_value = sub.E_value
+
+type M_Subenum = sub.M_Subenum
+
+const M_M_ZERO = sub.M_M_ZERO
+
+var M_Subenum_name = sub.M_Subenum_name
+var M_Subenum_value = sub.M_Subenum_value
+
+type M_Submessage_Submessage_Subenum = sub.M_Submessage_Submessage_Subenum
+
+const M_Submessage_M_SUBMESSAGE_ZERO = sub.M_Submessage_M_SUBMESSAGE_ZERO
+
+var M_Submessage_Submessage_Subenum_name = sub.M_Submessage_Submessage_Subenum_name
+var M_Submessage_Submessage_Subenum_value = sub.M_Submessage_Submessage_Subenum_value
+
+type M = sub.M
+
 const Default_M_S = sub.Default_M_S
 
 var Default_M_B = sub.Default_M_B
 var Default_M_F = sub.Default_M_F
 
-// M from public import import_public/sub/a.proto
-type M = sub.M
 type M_OneofInt32 = sub.M_OneofInt32
 type M_OneofInt64 = sub.M_OneofInt64
-
-// M_Submessage from public import import_public/sub/a.proto
 type M_Submessage = sub.M_Submessage
 type M_Submessage_SubmessageOneofInt32 = sub.M_Submessage_SubmessageOneofInt32
 type M_Submessage_SubmessageOneofInt64 = sub.M_Submessage_SubmessageOneofInt64
 
-// E from public import import_public/sub/a.proto
-type E = sub.E
-
-var E_name = sub.E_name
-var E_value = sub.E_value
-
-const E_ZERO = E(sub.E_ZERO)
-
-// M_Subenum from public import import_public/sub/a.proto
-type M_Subenum = sub.M_Subenum
-
-var M_Subenum_name = sub.M_Subenum_name
-var M_Subenum_value = sub.M_Subenum_value
-
-const M_M_ZERO = M_Subenum(sub.M_M_ZERO)
-
-// M_Submessage_Submessage_Subenum from public import import_public/sub/a.proto
-type M_Submessage_Submessage_Subenum = sub.M_Submessage_Submessage_Subenum
-
-var M_Submessage_Submessage_Subenum_name = sub.M_Submessage_Submessage_Subenum_name
-var M_Submessage_Submessage_Subenum_value = sub.M_Submessage_Submessage_Subenum_value
-
-const M_Submessage_M_SUBMESSAGE_ZERO = M_Submessage_Submessage_Subenum(sub.M_Submessage_M_SUBMESSAGE_ZERO)
-
 var E_ExtensionField = sub.E_ExtensionField
 
 type Public struct {
diff --git a/protogen/protogen.go b/protogen/protogen.go
index a2de5ac..76964eb 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -342,7 +342,10 @@
 		return resp
 	}
 	for _, g := range gen.genFiles {
-		content, err := g.content()
+		if g.skip {
+			continue
+		}
+		content, err := g.Content()
 		if err != nil {
 			return &pluginpb.CodeGeneratorResponse{
 				Error: scalar.String(err.Error()),
@@ -761,34 +764,6 @@
 	}
 }
 
-// A GeneratedFile is a generated file.
-type GeneratedFile struct {
-	gen              *Plugin
-	filename         string
-	goImportPath     GoImportPath
-	buf              bytes.Buffer
-	packageNames     map[GoImportPath]GoPackageName
-	usedPackageNames map[GoPackageName]bool
-	manualImports    map[GoImportPath]bool
-	annotations      map[string][]Location
-}
-
-// NewGeneratedFile creates a new generated file with the given filename
-// and import path.
-func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
-	g := &GeneratedFile{
-		gen:              gen,
-		filename:         filename,
-		goImportPath:     goImportPath,
-		packageNames:     make(map[GoImportPath]GoPackageName),
-		usedPackageNames: make(map[GoPackageName]bool),
-		manualImports:    make(map[GoImportPath]bool),
-		annotations:      make(map[string][]Location),
-	}
-	gen.genFiles = append(gen.genFiles, g)
-	return g
-}
-
 // A Service describes a service.
 type Service struct {
 	Desc protoreflect.ServiceDescriptor
@@ -851,6 +826,35 @@
 	return nil
 }
 
+// A GeneratedFile is a generated file.
+type GeneratedFile struct {
+	gen              *Plugin
+	skip             bool
+	filename         string
+	goImportPath     GoImportPath
+	buf              bytes.Buffer
+	packageNames     map[GoImportPath]GoPackageName
+	usedPackageNames map[GoPackageName]bool
+	manualImports    map[GoImportPath]bool
+	annotations      map[string][]Location
+}
+
+// NewGeneratedFile creates a new generated file with the given filename
+// and import path.
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
+	g := &GeneratedFile{
+		gen:              gen,
+		filename:         filename,
+		goImportPath:     goImportPath,
+		packageNames:     make(map[GoImportPath]GoPackageName),
+		usedPackageNames: make(map[GoPackageName]bool),
+		manualImports:    make(map[GoImportPath]bool),
+		annotations:      make(map[string][]Location),
+	}
+	gen.genFiles = append(gen.genFiles, g)
+	return g
+}
+
 // P prints a line to the generated output. It converts each parameter to a
 // string following the same rules as fmt.Print. It never inserts spaces
 // between parameters.
@@ -924,6 +928,11 @@
 	return g.buf.Write(p)
 }
 
+// Skip removes the generated file from the plugin output.
+func (g *GeneratedFile) Skip() {
+	g.skip = true
+}
+
 // Annotate associates a symbol in a generated Go file with a location in a
 // source .proto file.
 //
@@ -934,8 +943,8 @@
 	g.annotations[symbol] = append(g.annotations[symbol], loc)
 }
 
-// content returns the contents of the generated file.
-func (g *GeneratedFile) content() ([]byte, error) {
+// Content returns the contents of the generated file.
+func (g *GeneratedFile) Content() ([]byte, error) {
 	if !strings.HasSuffix(g.filename, ".go") {
 		return g.buf.Bytes(), nil
 	}
diff --git a/protogen/protogen_test.go b/protogen/protogen_test.go
index 7882605..9218acd 100644
--- a/protogen/protogen_test.go
+++ b/protogen/protogen_test.go
@@ -297,9 +297,9 @@
 var _ = baz.X     // "golang.org/x/baz"
 var _ = string1.X // "golang.org/z/string"
 `
-	got, err := g.content()
+	got, err := g.Content()
 	if err != nil {
-		t.Fatalf("g.content() = %v", err)
+		t.Fatalf("g.Content() = %v", err)
 	}
 	if want != string(got) {
 		t.Fatalf(`want:
@@ -333,9 +333,9 @@
 
 var _ = bar.X
 `
-	got, err := g.content()
+	got, err := g.Content()
 	if err != nil {
-		t.Fatalf("g.content() = %v", err)
+		t.Fatalf("g.Content() = %v", err)
 	}
 	if want != string(got) {
 		t.Fatalf(`want: