blob: 29b97ac62a143f37b43c4de5b67d78890c798806 [file] [log] [blame]
Damien Neil220c2022018-08-15 11:24:18 -07001package protogen
2
3import (
Damien Neild9016772018-08-23 14:39:30 -07004 "fmt"
Damien Neil220c2022018-08-15 11:24:18 -07005 "go/token"
6 "strconv"
7 "strings"
8 "unicode"
9 "unicode/utf8"
Damien Neilabc6fc12018-08-23 14:39:30 -070010
Joe Tsai01ab2962018-09-21 17:44:00 -070011 "github.com/golang/protobuf/v2/reflect/protoreflect"
Damien Neil220c2022018-08-15 11:24:18 -070012)
13
Damien Neild9016772018-08-23 14:39:30 -070014// A GoIdent is a Go identifier, consisting of a name and import path.
15type GoIdent struct {
16 GoName string
17 GoImportPath GoImportPath
18}
19
20func (id GoIdent) String() string { return fmt.Sprintf("%q.%v", id.GoImportPath, id.GoName) }
Damien Neilc7d07d92018-08-22 13:46:02 -070021
Damien Neilabc6fc12018-08-23 14:39:30 -070022// newGoIdent returns the Go identifier for a descriptor.
23func newGoIdent(f *File, d protoreflect.Descriptor) GoIdent {
24 name := strings.TrimPrefix(string(d.FullName()), string(f.Desc.Package())+".")
25 return GoIdent{
26 GoName: camelCase(name),
27 GoImportPath: f.GoImportPath,
28 }
29}
30
// A GoImportPath is the import path of a Go package.
// For example: "google.golang.org/genproto/protobuf".
type GoImportPath string

// String returns p as a double-quoted Go string literal, escaped the way
// strconv.Quote escapes it.
func (p GoImportPath) String() string {
	path := string(p)
	return strconv.Quote(path)
}
35
Joe Tsaic1c17aa2018-11-16 11:14:14 -080036// Ident returns a GoIdent with s as the GoName and p as the GoImportPath.
37func (p GoImportPath) Ident(s string) GoIdent {
38 return GoIdent{GoName: s, GoImportPath: p}
39}
40
Damien Neil220c2022018-08-15 11:24:18 -070041// A GoPackageName is the name of a Go package. e.g., "protobuf".
42type GoPackageName string
43
44// cleanPacakgeName converts a string to a valid Go package name.
45func cleanPackageName(name string) GoPackageName {
46 name = strings.Map(badToUnderscore, name)
47 // Identifier must not be keyword: insert _.
48 if token.Lookup(name).IsKeyword() {
49 name = "_" + name
50 }
51 // Identifier must not begin with digit: insert _.
52 if r, _ := utf8.DecodeRuneInString(name); unicode.IsDigit(r) {
53 name = "_" + name
54 }
55 return GoPackageName(name)
56}
57
// isGoPredeclaredIdentifier is the set of identifiers predeclared in Go's
// universe block: the built-in types, the constants true/false/iota, the
// zero value nil, and the built-in functions.
//
// NOTE(review): identifiers predeclared by later Go releases (any,
// comparable, clear, min, max) are not in this set.
var isGoPredeclaredIdentifier = map[string]bool{
	// Built-in types.
	"bool":       true,
	"byte":       true,
	"complex64":  true,
	"complex128": true,
	"error":      true,
	"float32":    true,
	"float64":    true,
	"int":        true,
	"int8":       true,
	"int16":      true,
	"int32":      true,
	"int64":      true,
	"rune":       true,
	"string":     true,
	"uint":       true,
	"uint8":      true,
	"uint16":     true,
	"uint32":     true,
	"uint64":     true,
	"uintptr":    true,

	// Built-in constants.
	"true":  true,
	"false": true,
	"iota":  true,

	// Zero value.
	"nil": true,

	// Built-in functions.
	"append":  true,
	"cap":     true,
	"close":   true,
	"complex": true,
	"copy":    true,
	"delete":  true,
	"imag":    true,
	"len":     true,
	"make":    true,
	"new":     true,
	"panic":   true,
	"print":   true,
	"println": true,
	"real":    true,
	"recover": true,
}
99
// badToUnderscore is a strings.Map mapping function used to derive Go names
// from package names, which may be dotted in the input .proto file. Letters,
// digits, and '_' pass through unchanged; every other rune (such as '.' or
// '-') becomes '_'.
func badToUnderscore(r rune) rune {
	switch {
	case r == '_', unicode.IsLetter(r), unicode.IsDigit(r):
		return r
	default:
		return '_'
	}
}
109
// baseName returns the final '/'-separated element of name with its last
// '.'-suffix (if any) removed. For example, "a/b/c.proto" yields "c".
func baseName(name string) string {
	base := name
	if slash := strings.LastIndexByte(base, '/'); slash != -1 {
		base = base[slash+1:]
	}
	if dot := strings.LastIndexByte(base, '.'); dot != -1 {
		base = base[:dot]
	}
	return base
}
Damien Neilc7d07d92018-08-22 13:46:02 -0700122
123// camelCase converts a name to CamelCase.
124//
125// If there is an interior underscore followed by a lower case letter,
126// drop the underscore and convert the letter to upper case.
127// There is a remote possibility of this rewrite causing a name collision,
128// but it's so remote we're prepared to pretend it's nonexistent - since the
129// C++ generator lowercases names, it's extremely unlikely to have two fields
130// with different capitalizations.
Damien Neild9016772018-08-23 14:39:30 -0700131func camelCase(s string) string {
Damien Neilc7d07d92018-08-22 13:46:02 -0700132 if s == "" {
133 return ""
134 }
135 var t []byte
136 i := 0
137 // Invariant: if the next letter is lower case, it must be converted
138 // to upper case.
139 // That is, we process a word at a time, where words are marked by _ or
140 // upper case letter. Digits are treated as words.
141 for ; i < len(s); i++ {
142 c := s[i]
143 switch {
Damien Neil3863ee52018-10-09 13:24:04 -0700144 case c == '.' && i+1 < len(s) && isASCIILower(s[i+1]):
145 // Skip over .<lowercase>, to match historic behavior.
Damien Neilc7d07d92018-08-22 13:46:02 -0700146 case c == '.':
147 t = append(t, '_') // Convert . to _.
148 case c == '_' && (i == 0 || s[i-1] == '.'):
149 // Convert initial _ to X so we start with a capital letter.
150 // Do the same for _ after .; not strictly necessary, but matches
151 // historic behavior.
152 t = append(t, 'X')
153 case c == '_' && i+1 < len(s) && isASCIILower(s[i+1]):
154 // Skip the underscore in s.
155 case isASCIIDigit(c):
156 t = append(t, c)
157 default:
158 // Assume we have a letter now - if not, it's a bogus identifier.
159 // The next word is a sequence of characters that must start upper case.
160 if isASCIILower(c) {
161 c ^= ' ' // Make it a capital letter.
162 }
163 t = append(t, c) // Guaranteed not lower case.
164 // Accept lower case sequence that follows.
165 for i+1 < len(s) && isASCIILower(s[i+1]) {
166 i++
167 t = append(t, s[i])
168 }
169 }
170 }
Damien Neild9016772018-08-23 14:39:30 -0700171 return string(t)
Damien Neilc7d07d92018-08-22 13:46:02 -0700172}
173
// isASCIILower reports whether c is an ASCII lower-case letter ('a' through 'z').
func isASCIILower(c byte) bool {
	return c >= 'a' && c <= 'z'
}
178
// isASCIIDigit reports whether c is an ASCII decimal digit ('0' through '9').
func isASCIIDigit(c byte) bool {
	return c >= '0' && c <= '9'
}