cmd/protoc-gen-go: generate public imports by parsing the imported .pb.go
Rather than explicitly enumerating the set of symbols to import,
just parse the imported file and extract every exported symbol.
This is possibly a bit more code, but it adapts much better to future
expansion.
Change-Id: I4429664f4c068a2a55949d46aefc19865b008a77
Reviewed-on: https://go-review.googlesource.com/c/155677
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/cmd/protoc-gen-go/internal_gengo/main.go b/cmd/protoc-gen-go/internal_gengo/main.go
index 08a7660..5e6b7b7 100644
--- a/cmd/protoc-gen-go/internal_gengo/main.go
+++ b/cmd/protoc-gen-go/internal_gengo/main.go
@@ -11,10 +11,15 @@
"crypto/sha256"
"encoding/hex"
"fmt"
+ "go/ast"
+ "go/parser"
+ "go/token"
"math"
"sort"
"strconv"
"strings"
+ "unicode"
+ "unicode/utf8"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/v2/internal/encoding/tag"
@@ -178,65 +183,54 @@
if !imp.IsPublic {
return
}
- // TODO: An alternate approach to generating public imports might be
- // to generate the imported file contents, parse it, and extract all
- // exported identifiers from the AST to build a list of forwarding
- // declarations.
- //
- // TODO: Consider whether this should generate recursive aliases. e.g.,
- // if a.proto publicly imports b.proto publicly imports c.proto, should
- // a.pb.go contain aliases for symbols defined in c.proto?
- var enums []*protogen.Enum
- enums = append(enums, impFile.Enums...)
- walkMessages(impFile.Messages, func(message *protogen.Message) {
- if message.Desc.IsMapEntry() {
+
+ // Generate public imports by generating the imported file, parsing it,
+ // and extracting every symbol that should receive a forwarding declaration.
+ impGen := gen.NewGeneratedFile("temp.go", impFile.GoImportPath)
+ impGen.Skip()
+ GenerateFile(gen, impFile, impGen)
+ b, err := impGen.Content()
+ if err != nil {
+ gen.Error(err)
+ return
+ }
+ fset := token.NewFileSet()
+ astFile, err := parser.ParseFile(fset, "", b, parser.ParseComments)
+ if err != nil {
+ gen.Error(err)
+ return
+ }
+ genForward := func(tok token.Token, name string) {
+ // Don't import unexported symbols.
+ r, _ := utf8.DecodeRuneInString(name)
+ if !unicode.IsUpper(r) {
return
}
- enums = append(enums, message.Enums...)
- for _, field := range message.Fields {
- if !field.Desc.HasDefault() {
- continue
- }
- defVar := protogen.GoIdent{
- GoImportPath: message.GoIdent.GoImportPath,
- GoName: "Default_" + message.GoIdent.GoName + "_" + field.GoName,
- }
- decl := "const"
- switch field.Desc.Kind() {
- case protoreflect.BytesKind:
- decl = "var"
- case protoreflect.FloatKind, protoreflect.DoubleKind:
- f := field.Desc.Default().Float()
- if math.IsInf(f, -1) || math.IsInf(f, 1) || math.IsNaN(f) {
- decl = "var"
+ // Don't import the FileDescriptor.
+ if name == impFile.GoDescriptorIdent.GoName {
+ return
+ }
+ g.P(tok, " ", name, " = ", impFile.GoImportPath.Ident(name))
+ }
+ g.P("// Symbols defined in public import of ", imp.Path())
+ g.P()
+ for _, decl := range astFile.Decls {
+ switch decl := decl.(type) {
+ case *ast.GenDecl:
+ for _, spec := range decl.Specs {
+ switch spec := spec.(type) {
+ case *ast.TypeSpec:
+ genForward(decl.Tok, spec.Name.Name)
+ case *ast.ValueSpec:
+ for _, name := range spec.Names {
+ genForward(decl.Tok, name.Name)
+ }
+ case *ast.ImportSpec:
+ default:
+ panic(fmt.Sprintf("can't generate forward for spec type %T", spec))
}
}
- g.P(decl, " ", defVar.GoName, " = ", defVar)
}
- g.P("// ", message.GoIdent.GoName, " from public import ", imp.Path())
- g.P("type ", message.GoIdent.GoName, " = ", message.GoIdent)
- for _, oneof := range message.Oneofs {
- for _, field := range oneof.Fields {
- typ := fieldOneofType(field)
- g.P("type ", typ.GoName, " = ", typ)
- }
- }
- g.P()
- })
- for _, enum := range enums {
- g.P("// ", enum.GoIdent.GoName, " from public import ", imp.Path())
- g.P("type ", enum.GoIdent.GoName, " = ", enum.GoIdent)
- g.P("var ", enum.GoIdent.GoName, "_name = ", enum.GoIdent, "_name")
- g.P("var ", enum.GoIdent.GoName, "_value = ", enum.GoIdent, "_value")
- g.P()
- for _, value := range enum.Values {
- g.P("const ", value.GoIdent.GoName, " = ", enum.GoIdent.GoName, "(", value.GoIdent, ")")
- }
- }
- for _, ext := range impFile.Extensions {
- ident := extensionVar(impFile, ext)
- g.P("var ", ident.GoName, " = ", ident)
- g.P()
}
g.P()
}