blob: 12c1832ecc0c4e6c6357d0b0399459a822576cde [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "reflect"
11 "strconv"
12 "strings"
13 "sync"
14 "unicode"
15
Damien Neil204f1c02018-10-23 15:03:38 -070016 protoV1 "github.com/golang/protobuf/proto"
17 descriptorV1 "github.com/golang/protobuf/protoc-gen-go/descriptor"
Joe Tsai90fe9962018-10-18 11:06:29 -070018 "github.com/golang/protobuf/v2/internal/encoding/text"
19 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
20 ptype "github.com/golang/protobuf/v2/reflect/prototype"
21)
22
// messageDescCache holds the message descriptors that messageDescSet.Load has
// already derived, keyed by the Go type each descriptor was derived from.
var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
24
25// loadMessageDesc returns an MessageDescriptor derived from the Go type,
26// which must be an *struct kind and not implement the v2 API already.
27func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
28 return messageDescSet{}.Load(t)
29}
30
// messageDescSet accumulates the messages discovered while deriving the
// descriptor for a single Go type, so that the whole group can be turned into
// descriptors together.
type messageDescSet struct {
	// visited maps each processed Go type to the message derived from it.
	// Map types are never recorded here (see visit).
	visited map[reflect.Type]*ptype.StandaloneMessage
	// descs lists every derived message in the order it was visited.
	descs []*ptype.StandaloneMessage
	// types is parallel to descs: types[i] is the Go type that descs[i]
	// was derived from.
	types []reflect.Type
}
36
37func (ms messageDescSet) Load(t reflect.Type) pref.MessageDescriptor {
38 // Fast-path: check if a MessageDescriptor is cached for this concrete type.
39 if mi, ok := messageDescCache.Load(t); ok {
40 return mi.(pref.MessageDescriptor)
41 }
42
43 // Slow-path: initialize MessageDescriptor from the Go type.
44
45 // Processing t recursively populates descs and types with all sub-messages.
46 // The descriptor for the first type is guaranteed to be at the front.
47 ms.processMessage(t)
48
49 // Within a proto file it is possible for cyclic dependencies to exist
50 // between multiple message types. When these cases arise, the set of
51 // message descriptors must be created together.
52 mds, err := ptype.NewMessages(ms.descs)
53 if err != nil {
54 panic(err)
55 }
56 for i, md := range mds {
57 // Protobuf semantics represents map entries under-the-hood as
58 // pseudo-messages (has a descriptor, but no generated Go type).
59 // Avoid caching these fake messages.
60 if t := ms.types[i]; t.Kind() != reflect.Map {
61 messageDescCache.Store(t, md)
62 }
63 }
64 return mds[0]
65}
66
67func (ms *messageDescSet) processMessage(t reflect.Type) pref.MessageDescriptor {
68 // Fast-path: Obtain a placeholder if the message is already processed.
69 if m, ok := ms.visited[t]; ok {
70 return ptype.PlaceholderMessage(m.FullName)
71 }
72
73 // Slow-path: Walk over the struct fields to derive the message descriptor.
74 if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
75 panic(fmt.Sprintf("got %v, want *struct kind", t))
76 }
77
78 // Derive name and syntax from the raw descriptor.
79 m := new(ptype.StandaloneMessage)
80 mv := reflect.New(t.Elem()).Interface()
81 if _, ok := mv.(pref.ProtoMessage); ok {
82 panic(fmt.Sprintf("%v already implements proto.Message", t))
83 }
84 if md, ok := mv.(legacyMessage); ok {
85 b, idxs := md.Descriptor()
86 fd := loadFileDesc(b)
87
88 // Derive syntax.
89 switch fd.GetSyntax() {
90 case "proto2", "":
91 m.Syntax = pref.Proto2
92 case "proto3":
93 m.Syntax = pref.Proto3
94 }
95
96 // Derive full name.
97 md := fd.MessageType[idxs[0]]
98 m.FullName = pref.FullName(fd.GetPackage()).Append(pref.Name(md.GetName()))
99 for _, i := range idxs[1:] {
100 md = md.NestedType[i]
101 m.FullName = m.FullName.Append(pref.Name(md.GetName()))
102 }
103 } else {
104 // If the type does not implement legacyMessage, then the only way to
105 // obtain the full name is through the registry. However, this is
106 // unreliable as some generated messages register with a fork of
107 // golang/protobuf, so the registry may not have this information.
108 m.FullName = deriveFullName(t.Elem())
109 m.Syntax = pref.Proto2
110
111 // Try to determine if the message is using proto3 by checking scalars.
112 for i := 0; i < t.Elem().NumField(); i++ {
113 f := t.Elem().Field(i)
114 if tag := f.Tag.Get("protobuf"); tag != "" {
115 switch f.Type.Kind() {
116 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
117 m.Syntax = pref.Proto3
118 }
119 for _, s := range strings.Split(tag, ",") {
120 if s == "proto3" {
121 m.Syntax = pref.Proto3
122 }
123 }
124 }
125 }
126 }
127 ms.visit(m, t)
128
129 // Obtain a list of oneof wrapper types.
130 var oneofWrappers []reflect.Type
131 if fn, ok := t.MethodByName("XXX_OneofFuncs"); ok {
132 vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[3]
133 for _, v := range vs.Interface().([]interface{}) {
134 oneofWrappers = append(oneofWrappers, reflect.TypeOf(v))
135 }
136 }
137
138 // Obtain a list of the extension ranges.
139 if fn, ok := t.MethodByName("ExtensionRangeArray"); ok {
140 vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0]
141 for i := 0; i < vs.Len(); i++ {
142 v := vs.Index(i)
143 m.ExtensionRanges = append(m.ExtensionRanges, [2]pref.FieldNumber{
144 pref.FieldNumber(v.FieldByName("Start").Int()),
145 pref.FieldNumber(v.FieldByName("End").Int() + 1),
146 })
147 }
148 }
149
150 // Derive the message fields by inspecting the struct fields.
151 for i := 0; i < t.Elem().NumField(); i++ {
152 f := t.Elem().Field(i)
153 if tag := f.Tag.Get("protobuf"); tag != "" {
154 tagKey := f.Tag.Get("protobuf_key")
155 tagVal := f.Tag.Get("protobuf_val")
156 m.Fields = append(m.Fields, ms.parseField(tag, tagKey, tagVal, f.Type, m))
157 }
158 if tag := f.Tag.Get("protobuf_oneof"); tag != "" {
159 name := pref.Name(tag)
160 m.Oneofs = append(m.Oneofs, ptype.Oneof{Name: name})
161 for _, t := range oneofWrappers {
162 if t.Implements(f.Type) {
163 f := t.Elem().Field(0)
164 if tag := f.Tag.Get("protobuf"); tag != "" {
165 ft := ms.parseField(tag, "", "", f.Type, m)
166 ft.OneofName = name
167 m.Fields = append(m.Fields, ft)
168 }
169 }
170 }
171 }
172 }
173
174 return ptype.PlaceholderMessage(m.FullName)
175}
176
// parseField builds a ptype.Field from the legacy "protobuf" struct tag and
// the Go type t of the tagged struct field.
//
// tagKey and tagVal are the "protobuf_key" and "protobuf_val" tags; they are
// only used when t is a map, to describe the key and value fields of the
// synthesized map-entry message. parent is the message that contains the
// field; it is only consulted for map fields, where it supplies the syntax and
// the name prefix of the map-entry message.
//
// Message-typed fields may cause further messages to be added to ms.
func (ms *messageDescSet) parseField(tag, tagKey, tagVal string, t reflect.Type, parent *ptype.StandaloneMessage) (f ptype.Field) {
	// Strip one level of pointer (except pointers to structs) or slice
	// (except byte slices), so that t is the type of a single element.
	isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
	isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
	if isOptional || isRepeated {
		t = t.Elem()
	}

	// Packed is always set explicitly; it starts as false and is flipped
	// below if the tag contains "packed".
	f.Options = &descriptorV1.FieldOptions{
		Packed: protoV1.Bool(false),
	}
	// Consume the tag one comma-separated element at a time.
	for len(tag) > 0 {
		i := strings.IndexByte(tag, ',')
		if i < 0 {
			i = len(tag)
		}
		switch s := tag[:i]; {
		case strings.HasPrefix(s, "name="):
			f.Name = pref.Name(s[len("name="):])
		case strings.Trim(s, "0123456789") == "":
			// An element made up solely of digits is the field number.
			// The parse error is discarded.
			n, _ := strconv.ParseUint(s, 10, 32)
			f.Number = pref.FieldNumber(n)
		case s == "opt":
			f.Cardinality = pref.Optional
		case s == "req":
			f.Cardinality = pref.Required
		case s == "rep":
			f.Cardinality = pref.Repeated
		case s == "varint":
			// The wire type alone is ambiguous; the Go kind selects the
			// protobuf kind.
			switch t.Kind() {
			case reflect.Bool:
				f.Kind = pref.BoolKind
			case reflect.Int32:
				f.Kind = pref.Int32Kind
			case reflect.Int64:
				f.Kind = pref.Int64Kind
			case reflect.Uint32:
				f.Kind = pref.Uint32Kind
			case reflect.Uint64:
				f.Kind = pref.Uint64Kind
			}
		case s == "zigzag32":
			if t.Kind() == reflect.Int32 {
				f.Kind = pref.Sint32Kind
			}
		case s == "zigzag64":
			if t.Kind() == reflect.Int64 {
				f.Kind = pref.Sint64Kind
			}
		case s == "fixed32":
			switch t.Kind() {
			case reflect.Int32:
				f.Kind = pref.Sfixed32Kind
			case reflect.Uint32:
				f.Kind = pref.Fixed32Kind
			case reflect.Float32:
				f.Kind = pref.FloatKind
			}
		case s == "fixed64":
			switch t.Kind() {
			case reflect.Int64:
				f.Kind = pref.Sfixed64Kind
			case reflect.Uint64:
				f.Kind = pref.Fixed64Kind
			case reflect.Float64:
				f.Kind = pref.DoubleKind
			}
		case s == "bytes":
			// "bytes" covers strings and byte slices; any other Go type
			// carrying this element is treated as a message.
			switch {
			case t.Kind() == reflect.String:
				f.Kind = pref.StringKind
			case t.Kind() == reflect.Slice && t.Elem() == byteType:
				f.Kind = pref.BytesKind
			default:
				f.Kind = pref.MessageKind
			}
		case s == "group":
			f.Kind = pref.GroupKind
		case strings.HasPrefix(s, "enum="):
			f.Kind = pref.EnumKind
		case strings.HasPrefix(s, "json="):
			f.JSONName = s[len("json="):]
		case s == "packed":
			*f.Options.Packed = true
		case strings.HasPrefix(s, "weak="):
			// A weak field names its message type directly in the tag, so
			// a placeholder for that name is used as the message type.
			f.Options.Weak = protoV1.Bool(true)
			f.MessageType = ptype.PlaceholderMessage(pref.FullName(s[len("weak="):]))
		case strings.HasPrefix(s, "def="):
			// The default tag is special in that everything afterwards is the
			// default regardless of the presence of commas.
			s, i = tag[len("def="):], len(tag)

			// Defaults are parsed last, so Kind is populated.
			// Numeric parse errors below are discarded.
			switch f.Kind {
			case pref.BoolKind:
				switch s {
				case "1":
					f.Default = pref.ValueOf(true)
				case "0":
					f.Default = pref.ValueOf(false)
				}
			case pref.EnumKind:
				n, _ := strconv.ParseInt(s, 10, 32)
				f.Default = pref.ValueOf(pref.EnumNumber(n))
			case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
				n, _ := strconv.ParseInt(s, 10, 32)
				f.Default = pref.ValueOf(int32(n))
			case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
				n, _ := strconv.ParseInt(s, 10, 64)
				f.Default = pref.ValueOf(int64(n))
			case pref.Uint32Kind, pref.Fixed32Kind:
				n, _ := strconv.ParseUint(s, 10, 32)
				f.Default = pref.ValueOf(uint32(n))
			case pref.Uint64Kind, pref.Fixed64Kind:
				n, _ := strconv.ParseUint(s, 10, 64)
				f.Default = pref.ValueOf(uint64(n))
			case pref.FloatKind, pref.DoubleKind:
				n, _ := strconv.ParseFloat(s, 64)
				// The special spellings are mapped explicitly.
				switch s {
				case "nan":
					n = math.NaN()
				case "inf":
					n = math.Inf(+1)
				case "-inf":
					n = math.Inf(-1)
				}
				if f.Kind == pref.FloatKind {
					f.Default = pref.ValueOf(float32(n))
				} else {
					f.Default = pref.ValueOf(float64(n))
				}
			case pref.StringKind:
				f.Default = pref.ValueOf(string(s))
			case pref.BytesKind:
				// The default value is in escaped form (C-style).
				// TODO: Export unmarshalString in the text package to avoid this hack.
				// The escaped text is wrapped so that it parses as a one-field
				// text message whose key is the quoted string; if parsing
				// fails, no default is set.
				v, err := text.Unmarshal([]byte(`["` + s + `"]:0`))
				if err == nil && len(v.Message()) == 1 {
					s := v.Message()[0][0].String()
					f.Default = pref.ValueOf([]byte(s))
				}
			}
		}
		// Advance past the element just handled and its trailing comma.
		tag = strings.TrimPrefix(tag[i:], ",")
	}

	// The generator uses the group message name instead of the field name.
	// We obtain the real field name by lowercasing the group name.
	if f.Kind == pref.GroupKind {
		f.Name = pref.Name(strings.ToLower(string(f.Name)))
	}

	// Populate EnumType and MessageType.
	if f.EnumType == nil && f.Kind == pref.EnumKind {
		// Prefer the type's own v2 reflection; otherwise derive a descriptor
		// from the legacy enum type.
		if ev, ok := reflect.Zero(t).Interface().(pref.ProtoEnum); ok {
			f.EnumType = ev.ProtoReflect().Type()
		} else {
			f.EnumType = loadEnumDesc(t)
		}
	}
	if f.MessageType == nil && (f.Kind == pref.MessageKind || f.Kind == pref.GroupKind) {
		if mv, ok := reflect.Zero(t).Interface().(pref.ProtoMessage); ok {
			f.MessageType = mv.ProtoReflect().Type()
		} else if t.Kind() == reflect.Map {
			// Synthesize the map-entry message: it is nested under parent and
			// has two fields parsed from the key and value tags.
			m := &ptype.StandaloneMessage{
				Syntax:   parent.Syntax,
				FullName: parent.FullName.Append(mapEntryName(f.Name)),
				Options:  &descriptorV1.MessageOptions{MapEntry: protoV1.Bool(true)},
				Fields: []ptype.Field{
					ms.parseField(tagKey, "", "", t.Key(), nil),
					ms.parseField(tagVal, "", "", t.Elem(), nil),
				},
			}
			ms.visit(m, t)
			f.MessageType = ptype.PlaceholderMessage(m.FullName)
		} else if mv, ok := messageDescCache.Load(t); ok {
			f.MessageType = mv.(pref.MessageDescriptor)
		} else {
			// Not cached yet: process the message type as part of this set.
			f.MessageType = ms.processMessage(t)
		}
	}
	return f
}
359
360func (ms *messageDescSet) visit(m *ptype.StandaloneMessage, t reflect.Type) {
361 if ms.visited == nil {
362 ms.visited = make(map[reflect.Type]*ptype.StandaloneMessage)
363 }
364 if t.Kind() != reflect.Map {
365 ms.visited[t] = m
366 }
367 ms.descs = append(ms.descs, m)
368 ms.types = append(ms.types, t)
369}
370
371// deriveFullName derives a fully qualified protobuf name for the given Go type
372// The provided name is not guaranteed to be stable nor universally unique.
373// It should be sufficiently unique within a program.
374func deriveFullName(t reflect.Type) pref.FullName {
375 sanitize := func(r rune) rune {
376 switch {
377 case r == '/':
378 return '.'
379 case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
380 return r
381 default:
382 return '_'
383 }
384 }
385 prefix := strings.Map(sanitize, t.PkgPath())
386 suffix := strings.Map(sanitize, t.Name())
387 if suffix == "" {
388 suffix = fmt.Sprintf("UnknownX%X", reflect.ValueOf(t).Pointer())
389 }
390
391 ss := append(strings.Split(prefix, "."), suffix)
392 for i, s := range ss {
393 if s == "" || ('0' <= s[0] && s[0] <= '9') {
394 ss[i] = "x" + s
395 }
396 }
397 return pref.FullName(strings.Join(ss, "."))
398}
399
400// mapEntryName derives the message name for a map field of a given name.
401// This is identical to MapEntryName from parser.cc in the protoc source.
402func mapEntryName(s pref.Name) pref.Name {
403 var b []byte
404 nextUpper := true
405 for i := 0; i < len(s); i++ {
406 if c := s[i]; c == '_' {
407 nextUpper = true
408 } else {
409 if nextUpper {
410 c = byte(unicode.ToUpper(rune(c)))
411 nextUpper = false
412 }
413 b = append(b, c)
414 }
415 }
416 return pref.Name(append(b, "Entry"...))
417}