protogen: automatic handling of imports
The GoIdent type is now a struct pairing a Go import path with a name.
Each generated file has an associated import path. Writing a GoIdent to
a generated file qualifies the name with its package name when the
identifier comes from a different package than the file's own.
All necessary imports are automatically added to generated Go files.
Change-Id: I839e0b7aa8ec967ce178aea4ffb960b62779cf74
Reviewed-on: https://go-review.googlesource.com/133635
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/protogen/protogen.go b/protogen/protogen.go
index f499a8d..2d65ee0 100644
--- a/protogen/protogen.go
+++ b/protogen/protogen.go
@@ -20,11 +20,14 @@
"io/ioutil"
"os"
"path/filepath"
+ "sort"
+ "strconv"
"strings"
"github.com/golang/protobuf/proto"
descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
+ "golang.org/x/tools/go/ast/astutil"
)
// Run executes a function as a protoc plugin.
@@ -168,7 +171,7 @@
}
}
resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
- Name: proto.String(gf.path),
+ Name: proto.String(gf.filename),
Content: proto.String(string(content)),
})
}
@@ -185,16 +188,17 @@
type File struct {
Desc *descpb.FileDescriptorProto // TODO: protoreflect.FileDescriptor
- Messages []*Message // top-level message declartions
- Generate bool // true if we should generate code for this file
+ GoImportPath GoImportPath // import path of this file's Go package
+ Messages []*Message // top-level message declarations
+ Generate bool // true if we should generate code for this file
}
func newFile(gen *Plugin, p *descpb.FileDescriptorProto) *File {
f := &File{
Desc: p,
}
- for _, d := range p.MessageType {
- f.Messages = append(f.Messages, newMessage(gen, nil, d))
+ for i, mdesc := range p.MessageType {
+ f.Messages = append(f.Messages, newMessage(gen, f, nil, mdesc, i))
}
return f
}
@@ -207,30 +211,40 @@
Messages []*Message // nested message declarations
}
-func newMessage(gen *Plugin, parent *Message, p *descpb.DescriptorProto) *Message {
+func newMessage(gen *Plugin, f *File, parent *Message, p *descpb.DescriptorProto, index int) *Message {
m := &Message{
- Desc: p,
- GoIdent: camelCase(p.GetName()),
+ Desc: p,
+ GoIdent: GoIdent{
+ GoName: camelCase(p.GetName()),
+ GoImportPath: f.GoImportPath,
+ },
}
if parent != nil {
- m.GoIdent = parent.GoIdent + "_" + m.GoIdent
+ m.GoIdent.GoName = parent.GoIdent.GoName + "_" + m.GoIdent.GoName
}
- for _, nested := range p.GetNestedType() {
- m.Messages = append(m.Messages, newMessage(gen, m, nested))
+ for i, nested := range p.GetNestedType() {
+ m.Messages = append(m.Messages, newMessage(gen, f, m, nested, i))
}
return m
}
// A GeneratedFile is a generated file.
type GeneratedFile struct {
- path string
- buf bytes.Buffer
+ filename string
+ goImportPath GoImportPath
+ buf bytes.Buffer
+ packageNames map[GoImportPath]GoPackageName
+ usedPackageNames map[GoPackageName]bool
}
-// NewGeneratedFile creates a new generated file with the given path.
-func (gen *Plugin) NewGeneratedFile(path string) *GeneratedFile {
+// NewGeneratedFile creates a new generated file with the given filename
+// and import path.
+func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
g := &GeneratedFile{
- path: path,
+ filename: filename,
+ goImportPath: goImportPath,
+ packageNames: make(map[GoImportPath]GoPackageName),
+ usedPackageNames: make(map[GoPackageName]bool),
}
gen.genFiles = append(gen.genFiles, g)
return g
@@ -243,11 +257,33 @@
// TODO: .meta file annotations.
func (g *GeneratedFile) P(v ...interface{}) {
for _, x := range v {
- fmt.Fprint(&g.buf, x)
+ switch x := x.(type) {
+ case GoIdent:
+ if x.GoImportPath != g.goImportPath {
+ fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
+ fmt.Fprint(&g.buf, ".")
+ }
+ fmt.Fprint(&g.buf, x.GoName)
+ default:
+ fmt.Fprint(&g.buf, x)
+ }
}
fmt.Fprintln(&g.buf)
}
+func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
+ if name, ok := g.packageNames[importPath]; ok {
+ return name
+ }
+ name := cleanPackageName(baseName(string(importPath)))
+ for i, orig := 1, name; g.usedPackageNames[name]; i++ {
+ name = orig + GoPackageName(strconv.Itoa(i))
+ }
+ g.packageNames[importPath] = name
+ g.usedPackageNames[name] = true
+ return name
+}
+
// Write implements io.Writer.
func (g *GeneratedFile) Write(p []byte) (n int, err error) {
return g.buf.Write(p)
@@ -255,7 +291,7 @@
// Content returns the contents of the generated file.
func (g *GeneratedFile) Content() ([]byte, error) {
- if !strings.HasSuffix(g.path, ".go") {
+ if !strings.HasSuffix(g.filename, ".go") {
return g.buf.Bytes(), nil
}
@@ -272,13 +308,24 @@
for line := 1; s.Scan(); line++ {
fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
}
- return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.path, err, src.String())
+ return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
}
+
+ // Add imports.
+ var importPaths []string
+ for importPath := range g.packageNames {
+ importPaths = append(importPaths, string(importPath))
+ }
+ sort.Strings(importPaths)
+ for _, importPath := range importPaths {
+ astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
+ }
+
var out bytes.Buffer
if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
- return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.path, err)
+ return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
}
- // TODO: Patch annotation locations.
+ // TODO: Annotations.
return out.Bytes(), nil
}