blob: ffba17224ea1cd946afbd8eda07f16bfeff0ab73 [file] [log] [blame]
Damien Neil220c2022018-08-15 11:24:18 -07001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package protogen provides support for writing protoc plugins.
6//
7// Plugins for protoc, the Protocol Buffers Compiler, are programs which read
8// a CodeGeneratorRequest protocol buffer from standard input and write a
9// CodeGeneratorResponse protocol buffer to standard output. This package
10// provides support for writing plugins which generate Go code.
11package protogen
12
13import (
Damien Neilc7d07d92018-08-22 13:46:02 -070014 "bufio"
Damien Neil220c2022018-08-15 11:24:18 -070015 "bytes"
16 "fmt"
Damien Neilc7d07d92018-08-22 13:46:02 -070017 "go/parser"
18 "go/printer"
19 "go/token"
Damien Neil220c2022018-08-15 11:24:18 -070020 "io/ioutil"
21 "os"
22 "path/filepath"
Damien Neild9016772018-08-23 14:39:30 -070023 "sort"
24 "strconv"
Damien Neil220c2022018-08-15 11:24:18 -070025 "strings"
26
27 "github.com/golang/protobuf/proto"
28 descpb "github.com/golang/protobuf/protoc-gen-go/descriptor"
29 pluginpb "github.com/golang/protobuf/protoc-gen-go/plugin"
Damien Neild9016772018-08-23 14:39:30 -070030 "golang.org/x/tools/go/ast/astutil"
Damien Neilabc6fc12018-08-23 14:39:30 -070031 "google.golang.org/proto/reflect/protoreflect"
32 "google.golang.org/proto/reflect/protoregistry"
33 "google.golang.org/proto/reflect/prototype"
Damien Neil220c2022018-08-15 11:24:18 -070034)
35
36// Run executes a function as a protoc plugin.
37//
38// It reads a CodeGeneratorRequest message from os.Stdin, invokes the plugin
39// function, and writes a CodeGeneratorResponse message to os.Stdout.
40//
41// If a failure occurs while reading or writing, Run prints an error to
42// os.Stderr and calls os.Exit(1).
43func Run(f func(*Plugin) error) {
44 if err := run(f); err != nil {
45 fmt.Fprintf(os.Stderr, "%s: %v\n", filepath.Base(os.Args[0]), err)
46 os.Exit(1)
47 }
48}
49
50func run(f func(*Plugin) error) error {
51 in, err := ioutil.ReadAll(os.Stdin)
52 if err != nil {
53 return err
54 }
55 req := &pluginpb.CodeGeneratorRequest{}
56 if err := proto.Unmarshal(in, req); err != nil {
57 return err
58 }
59 gen, err := New(req)
60 if err != nil {
61 return err
62 }
63 if err := f(gen); err != nil {
64 // Errors from the plugin function are reported by setting the
65 // error field in the CodeGeneratorResponse.
66 //
67 // In contrast, errors that indicate a problem in protoc
68 // itself (unparsable input, I/O errors, etc.) are reported
69 // to stderr.
70 gen.Error(err)
71 }
72 resp := gen.Response()
73 out, err := proto.Marshal(resp)
74 if err != nil {
75 return err
76 }
77 if _, err := os.Stdout.Write(out); err != nil {
78 return err
79 }
80 return nil
81}
82
83// A Plugin is a protoc plugin invocation.
84type Plugin struct {
85 // Request is the CodeGeneratorRequest provided by protoc.
86 Request *pluginpb.CodeGeneratorRequest
87
88 // Files is the set of files to generate and everything they import.
89 // Files appear in topological order, so each file appears before any
90 // file that imports it.
91 Files []*File
92 filesByName map[string]*File
93
Damien Neilabc6fc12018-08-23 14:39:30 -070094 fileReg *protoregistry.Files
95
Damien Neil220c2022018-08-15 11:24:18 -070096 packageImportPath string // Go import path of the package we're generating code for.
97
98 genFiles []*GeneratedFile
99 err error
100}
101
102// New returns a new Plugin.
103func New(req *pluginpb.CodeGeneratorRequest) (*Plugin, error) {
104 gen := &Plugin{
105 Request: req,
106 filesByName: make(map[string]*File),
Damien Neilabc6fc12018-08-23 14:39:30 -0700107 fileReg: protoregistry.NewFiles(),
Damien Neil220c2022018-08-15 11:24:18 -0700108 }
109
110 // TODO: Figure out how to pass parameters to the generator.
111 for _, param := range strings.Split(req.GetParameter(), ",") {
112 var value string
113 if i := strings.Index(param, "="); i >= 0 {
114 value = param[i+1:]
115 param = param[0:i]
116 }
117 switch param {
118 case "":
119 // Ignore.
120 case "import_prefix":
121 // TODO
122 case "import_path":
123 gen.packageImportPath = value
124 case "paths":
125 // TODO
126 case "plugins":
127 // TODO
128 case "annotate_code":
129 // TODO
130 default:
131 if param[0] != 'M' {
132 return nil, fmt.Errorf("unknown parameter %q", param)
133 }
134 // TODO
135 }
136 }
137
138 for _, fdesc := range gen.Request.ProtoFile {
Damien Neilabc6fc12018-08-23 14:39:30 -0700139 f, err := newFile(gen, fdesc)
140 if err != nil {
141 return nil, err
142 }
143 name := f.Desc.Path()
Damien Neil220c2022018-08-15 11:24:18 -0700144 if gen.filesByName[name] != nil {
145 return nil, fmt.Errorf("duplicate file name: %q", name)
146 }
147 gen.Files = append(gen.Files, f)
148 gen.filesByName[name] = f
149 }
150 for _, name := range gen.Request.FileToGenerate {
151 f, ok := gen.FileByName(name)
152 if !ok {
153 return nil, fmt.Errorf("no descriptor for generated file: %v", name)
154 }
155 f.Generate = true
156 }
157 return gen, nil
158}
159
160// Error records an error in code generation. The generator will report the
161// error back to protoc and will not produce output.
162func (gen *Plugin) Error(err error) {
163 if gen.err == nil {
164 gen.err = err
165 }
166}
167
168// Response returns the generator output.
169func (gen *Plugin) Response() *pluginpb.CodeGeneratorResponse {
170 resp := &pluginpb.CodeGeneratorResponse{}
171 if gen.err != nil {
172 resp.Error = proto.String(gen.err.Error())
173 return resp
174 }
175 for _, gf := range gen.genFiles {
Damien Neilc7d07d92018-08-22 13:46:02 -0700176 content, err := gf.Content()
177 if err != nil {
178 return &pluginpb.CodeGeneratorResponse{
179 Error: proto.String(err.Error()),
180 }
181 }
Damien Neil220c2022018-08-15 11:24:18 -0700182 resp.File = append(resp.File, &pluginpb.CodeGeneratorResponse_File{
Damien Neild9016772018-08-23 14:39:30 -0700183 Name: proto.String(gf.filename),
Damien Neilc7d07d92018-08-22 13:46:02 -0700184 Content: proto.String(string(content)),
Damien Neil220c2022018-08-15 11:24:18 -0700185 })
186 }
187 return resp
188}
189
190// FileByName returns the file with the given name.
191func (gen *Plugin) FileByName(name string) (f *File, ok bool) {
192 f, ok = gen.filesByName[name]
193 return f, ok
194}
195
Damien Neilc7d07d92018-08-22 13:46:02 -0700196// A File describes a .proto source file.
Damien Neil220c2022018-08-15 11:24:18 -0700197type File struct {
Damien Neilabc6fc12018-08-23 14:39:30 -0700198 Desc protoreflect.FileDescriptor
Damien Neil220c2022018-08-15 11:24:18 -0700199
Damien Neild9016772018-08-23 14:39:30 -0700200 GoImportPath GoImportPath // import path of this file's Go package
201 Messages []*Message // top-level message declarations
202 Generate bool // true if we should generate code for this file
Damien Neil220c2022018-08-15 11:24:18 -0700203}
204
Damien Neilabc6fc12018-08-23 14:39:30 -0700205func newFile(gen *Plugin, p *descpb.FileDescriptorProto) (*File, error) {
206 desc, err := prototype.NewFileFromDescriptorProto(p, gen.fileReg)
207 if err != nil {
208 return nil, fmt.Errorf("invalid FileDescriptorProto %q: %v", p.GetName(), err)
209 }
210 if err := gen.fileReg.Register(desc); err != nil {
211 return nil, fmt.Errorf("cannot register descriptor %q: %v", p.GetName(), err)
212 }
Damien Neilc7d07d92018-08-22 13:46:02 -0700213 f := &File{
Damien Neilabc6fc12018-08-23 14:39:30 -0700214 Desc: desc,
Damien Neil220c2022018-08-15 11:24:18 -0700215 }
Damien Neilabc6fc12018-08-23 14:39:30 -0700216 for i, mdescs := 0, desc.Messages(); i < mdescs.Len(); i++ {
217 f.Messages = append(f.Messages, newMessage(gen, f, nil, mdescs.Get(i), i))
Damien Neilc7d07d92018-08-22 13:46:02 -0700218 }
Damien Neilabc6fc12018-08-23 14:39:30 -0700219 return f, nil
Damien Neilc7d07d92018-08-22 13:46:02 -0700220}
221
222// A Message describes a message.
223type Message struct {
Damien Neilabc6fc12018-08-23 14:39:30 -0700224 Desc protoreflect.MessageDescriptor
Damien Neilc7d07d92018-08-22 13:46:02 -0700225
226 GoIdent GoIdent // name of the generated Go type
227 Messages []*Message // nested message declarations
228}
229
Damien Neilabc6fc12018-08-23 14:39:30 -0700230func newMessage(gen *Plugin, f *File, parent *Message, desc protoreflect.MessageDescriptor, index int) *Message {
Damien Neilc7d07d92018-08-22 13:46:02 -0700231 m := &Message{
Damien Neilabc6fc12018-08-23 14:39:30 -0700232 Desc: desc,
233 GoIdent: newGoIdent(f, desc),
Damien Neilc7d07d92018-08-22 13:46:02 -0700234 }
Damien Neilabc6fc12018-08-23 14:39:30 -0700235 for i, mdescs := 0, desc.Messages(); i < mdescs.Len(); i++ {
236 m.Messages = append(m.Messages, newMessage(gen, f, m, mdescs.Get(i), i))
Damien Neilc7d07d92018-08-22 13:46:02 -0700237 }
238 return m
Damien Neil220c2022018-08-15 11:24:18 -0700239}
240
241// A GeneratedFile is a generated file.
242type GeneratedFile struct {
Damien Neild9016772018-08-23 14:39:30 -0700243 filename string
244 goImportPath GoImportPath
245 buf bytes.Buffer
246 packageNames map[GoImportPath]GoPackageName
247 usedPackageNames map[GoPackageName]bool
Damien Neil220c2022018-08-15 11:24:18 -0700248}
249
Damien Neild9016772018-08-23 14:39:30 -0700250// NewGeneratedFile creates a new generated file with the given filename
251// and import path.
252func (gen *Plugin) NewGeneratedFile(filename string, goImportPath GoImportPath) *GeneratedFile {
Damien Neil220c2022018-08-15 11:24:18 -0700253 g := &GeneratedFile{
Damien Neild9016772018-08-23 14:39:30 -0700254 filename: filename,
255 goImportPath: goImportPath,
256 packageNames: make(map[GoImportPath]GoPackageName),
257 usedPackageNames: make(map[GoPackageName]bool),
Damien Neil220c2022018-08-15 11:24:18 -0700258 }
259 gen.genFiles = append(gen.genFiles, g)
260 return g
261}
262
263// P prints a line to the generated output. It converts each parameter to a
264// string following the same rules as fmt.Print. It never inserts spaces
265// between parameters.
266//
267// TODO: .meta file annotations.
268func (g *GeneratedFile) P(v ...interface{}) {
269 for _, x := range v {
Damien Neild9016772018-08-23 14:39:30 -0700270 switch x := x.(type) {
271 case GoIdent:
272 if x.GoImportPath != g.goImportPath {
273 fmt.Fprint(&g.buf, g.goPackageName(x.GoImportPath))
274 fmt.Fprint(&g.buf, ".")
275 }
276 fmt.Fprint(&g.buf, x.GoName)
277 default:
278 fmt.Fprint(&g.buf, x)
279 }
Damien Neil220c2022018-08-15 11:24:18 -0700280 }
281 fmt.Fprintln(&g.buf)
282}
283
Damien Neild9016772018-08-23 14:39:30 -0700284func (g *GeneratedFile) goPackageName(importPath GoImportPath) GoPackageName {
285 if name, ok := g.packageNames[importPath]; ok {
286 return name
287 }
288 name := cleanPackageName(baseName(string(importPath)))
289 for i, orig := 1, name; g.usedPackageNames[name]; i++ {
290 name = orig + GoPackageName(strconv.Itoa(i))
291 }
292 g.packageNames[importPath] = name
293 g.usedPackageNames[name] = true
294 return name
295}
296
Damien Neil220c2022018-08-15 11:24:18 -0700297// Write implements io.Writer.
298func (g *GeneratedFile) Write(p []byte) (n int, err error) {
299 return g.buf.Write(p)
300}
301
302// Content returns the contents of the generated file.
Damien Neilc7d07d92018-08-22 13:46:02 -0700303func (g *GeneratedFile) Content() ([]byte, error) {
Damien Neild9016772018-08-23 14:39:30 -0700304 if !strings.HasSuffix(g.filename, ".go") {
Damien Neilc7d07d92018-08-22 13:46:02 -0700305 return g.buf.Bytes(), nil
306 }
307
308 // Reformat generated code.
309 original := g.buf.Bytes()
310 fset := token.NewFileSet()
311 ast, err := parser.ParseFile(fset, "", original, parser.ParseComments)
312 if err != nil {
313 // Print out the bad code with line numbers.
314 // This should never happen in practice, but it can while changing generated code
315 // so consider this a debugging aid.
316 var src bytes.Buffer
317 s := bufio.NewScanner(bytes.NewReader(original))
318 for line := 1; s.Scan(); line++ {
319 fmt.Fprintf(&src, "%5d\t%s\n", line, s.Bytes())
320 }
Damien Neild9016772018-08-23 14:39:30 -0700321 return nil, fmt.Errorf("%v: unparsable Go source: %v\n%v", g.filename, err, src.String())
Damien Neilc7d07d92018-08-22 13:46:02 -0700322 }
Damien Neild9016772018-08-23 14:39:30 -0700323
324 // Add imports.
325 var importPaths []string
326 for importPath := range g.packageNames {
327 importPaths = append(importPaths, string(importPath))
328 }
329 sort.Strings(importPaths)
330 for _, importPath := range importPaths {
331 astutil.AddNamedImport(fset, ast, string(g.packageNames[GoImportPath(importPath)]), importPath)
332 }
333
Damien Neilc7d07d92018-08-22 13:46:02 -0700334 var out bytes.Buffer
335 if err = (&printer.Config{Mode: printer.TabIndent | printer.UseSpaces, Tabwidth: 8}).Fprint(&out, fset, ast); err != nil {
Damien Neild9016772018-08-23 14:39:30 -0700336 return nil, fmt.Errorf("%v: can not reformat Go source: %v", g.filename, err)
Damien Neilc7d07d92018-08-22 13:46:02 -0700337 }
Damien Neild9016772018-08-23 14:39:30 -0700338 // TODO: Annotations.
Damien Neilc7d07d92018-08-22 13:46:02 -0700339 return out.Bytes(), nil
340
Damien Neil220c2022018-08-15 11:24:18 -0700341}