blob: b30c667cb0d7632a54a107677e56e18c8fe32eb9 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "math"
10 "reflect"
11 "strconv"
12 "strings"
13 "sync"
14 "unicode"
15
16 "github.com/golang/protobuf/v2/internal/encoding/text"
17 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
18 ptype "github.com/golang/protobuf/v2/reflect/prototype"
19)
20
// messageDescCache memoizes the MessageDescriptor derived for each legacy Go
// message type. It is consulted and populated by messageDescSet.Load, and is
// also consulted by parseField when resolving message-typed fields.
var messageDescCache sync.Map // map[reflect.Type]protoreflect.MessageDescriptor
22
23// loadMessageDesc returns an MessageDescriptor derived from the Go type,
24// which must be an *struct kind and not implement the v2 API already.
25func loadMessageDesc(t reflect.Type) pref.MessageDescriptor {
26 return messageDescSet{}.Load(t)
27}
28
// messageDescSet accumulates the message descriptors discovered while walking
// a Go message type and the message types that it refers to.
type messageDescSet struct {
	// visited maps each processed Go type to its in-progress descriptor.
	// Map-entry pseudo-messages are never recorded here (see visit).
	visited map[reflect.Type]*ptype.StandaloneMessage
	// descs and types are parallel slices appended to by visit:
	// descs[i] is the descriptor derived for the Go type types[i].
	descs []*ptype.StandaloneMessage
	types []reflect.Type
}
34
35func (ms messageDescSet) Load(t reflect.Type) pref.MessageDescriptor {
36 // Fast-path: check if a MessageDescriptor is cached for this concrete type.
37 if mi, ok := messageDescCache.Load(t); ok {
38 return mi.(pref.MessageDescriptor)
39 }
40
41 // Slow-path: initialize MessageDescriptor from the Go type.
42
43 // Processing t recursively populates descs and types with all sub-messages.
44 // The descriptor for the first type is guaranteed to be at the front.
45 ms.processMessage(t)
46
47 // Within a proto file it is possible for cyclic dependencies to exist
48 // between multiple message types. When these cases arise, the set of
49 // message descriptors must be created together.
50 mds, err := ptype.NewMessages(ms.descs)
51 if err != nil {
52 panic(err)
53 }
54 for i, md := range mds {
55 // Protobuf semantics represents map entries under-the-hood as
56 // pseudo-messages (has a descriptor, but no generated Go type).
57 // Avoid caching these fake messages.
58 if t := ms.types[i]; t.Kind() != reflect.Map {
59 messageDescCache.Store(t, md)
60 }
61 }
62 return mds[0]
63}
64
65func (ms *messageDescSet) processMessage(t reflect.Type) pref.MessageDescriptor {
66 // Fast-path: Obtain a placeholder if the message is already processed.
67 if m, ok := ms.visited[t]; ok {
68 return ptype.PlaceholderMessage(m.FullName)
69 }
70
71 // Slow-path: Walk over the struct fields to derive the message descriptor.
72 if t.Kind() != reflect.Ptr && t.Elem().Kind() != reflect.Struct {
73 panic(fmt.Sprintf("got %v, want *struct kind", t))
74 }
75
76 // Derive name and syntax from the raw descriptor.
77 m := new(ptype.StandaloneMessage)
78 mv := reflect.New(t.Elem()).Interface()
79 if _, ok := mv.(pref.ProtoMessage); ok {
80 panic(fmt.Sprintf("%v already implements proto.Message", t))
81 }
82 if md, ok := mv.(legacyMessage); ok {
83 b, idxs := md.Descriptor()
84 fd := loadFileDesc(b)
85
86 // Derive syntax.
87 switch fd.GetSyntax() {
88 case "proto2", "":
89 m.Syntax = pref.Proto2
90 case "proto3":
91 m.Syntax = pref.Proto3
92 }
93
94 // Derive full name.
95 md := fd.MessageType[idxs[0]]
96 m.FullName = pref.FullName(fd.GetPackage()).Append(pref.Name(md.GetName()))
97 for _, i := range idxs[1:] {
98 md = md.NestedType[i]
99 m.FullName = m.FullName.Append(pref.Name(md.GetName()))
100 }
101 } else {
102 // If the type does not implement legacyMessage, then the only way to
103 // obtain the full name is through the registry. However, this is
104 // unreliable as some generated messages register with a fork of
105 // golang/protobuf, so the registry may not have this information.
106 m.FullName = deriveFullName(t.Elem())
107 m.Syntax = pref.Proto2
108
109 // Try to determine if the message is using proto3 by checking scalars.
110 for i := 0; i < t.Elem().NumField(); i++ {
111 f := t.Elem().Field(i)
112 if tag := f.Tag.Get("protobuf"); tag != "" {
113 switch f.Type.Kind() {
114 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
115 m.Syntax = pref.Proto3
116 }
117 for _, s := range strings.Split(tag, ",") {
118 if s == "proto3" {
119 m.Syntax = pref.Proto3
120 }
121 }
122 }
123 }
124 }
125 ms.visit(m, t)
126
127 // Obtain a list of oneof wrapper types.
128 var oneofWrappers []reflect.Type
129 if fn, ok := t.MethodByName("XXX_OneofFuncs"); ok {
130 vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[3]
131 for _, v := range vs.Interface().([]interface{}) {
132 oneofWrappers = append(oneofWrappers, reflect.TypeOf(v))
133 }
134 }
135
136 // Obtain a list of the extension ranges.
137 if fn, ok := t.MethodByName("ExtensionRangeArray"); ok {
138 vs := fn.Func.Call([]reflect.Value{reflect.Zero(fn.Type.In(0))})[0]
139 for i := 0; i < vs.Len(); i++ {
140 v := vs.Index(i)
141 m.ExtensionRanges = append(m.ExtensionRanges, [2]pref.FieldNumber{
142 pref.FieldNumber(v.FieldByName("Start").Int()),
143 pref.FieldNumber(v.FieldByName("End").Int() + 1),
144 })
145 }
146 }
147
148 // Derive the message fields by inspecting the struct fields.
149 for i := 0; i < t.Elem().NumField(); i++ {
150 f := t.Elem().Field(i)
151 if tag := f.Tag.Get("protobuf"); tag != "" {
152 tagKey := f.Tag.Get("protobuf_key")
153 tagVal := f.Tag.Get("protobuf_val")
154 m.Fields = append(m.Fields, ms.parseField(tag, tagKey, tagVal, f.Type, m))
155 }
156 if tag := f.Tag.Get("protobuf_oneof"); tag != "" {
157 name := pref.Name(tag)
158 m.Oneofs = append(m.Oneofs, ptype.Oneof{Name: name})
159 for _, t := range oneofWrappers {
160 if t.Implements(f.Type) {
161 f := t.Elem().Field(0)
162 if tag := f.Tag.Get("protobuf"); tag != "" {
163 ft := ms.parseField(tag, "", "", f.Type, m)
164 ft.OneofName = name
165 m.Fields = append(m.Fields, ft)
166 }
167 }
168 }
169 }
170 }
171
172 return ptype.PlaceholderMessage(m.FullName)
173}
174
175func (ms *messageDescSet) parseField(tag, tagKey, tagVal string, t reflect.Type, parent *ptype.StandaloneMessage) (f ptype.Field) {
176 isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
177 isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
178 if isOptional || isRepeated {
179 t = t.Elem()
180 }
181
182 for len(tag) > 0 {
183 i := strings.IndexByte(tag, ',')
184 if i < 0 {
185 i = len(tag)
186 }
187 switch s := tag[:i]; {
188 case strings.HasPrefix(s, "name="):
189 f.Name = pref.Name(s[len("name="):])
190 case strings.Trim(s, "0123456789") == "":
191 n, _ := strconv.ParseUint(s, 10, 32)
192 f.Number = pref.FieldNumber(n)
193 case s == "opt":
194 f.Cardinality = pref.Optional
195 case s == "req":
196 f.Cardinality = pref.Required
197 case s == "rep":
198 f.Cardinality = pref.Repeated
199 case s == "varint":
200 switch t.Kind() {
201 case reflect.Bool:
202 f.Kind = pref.BoolKind
203 case reflect.Int32:
204 f.Kind = pref.Int32Kind
205 case reflect.Int64:
206 f.Kind = pref.Int64Kind
207 case reflect.Uint32:
208 f.Kind = pref.Uint32Kind
209 case reflect.Uint64:
210 f.Kind = pref.Uint64Kind
211 }
212 case s == "zigzag32":
213 if t.Kind() == reflect.Int32 {
214 f.Kind = pref.Sint32Kind
215 }
216 case s == "zigzag64":
217 if t.Kind() == reflect.Int64 {
218 f.Kind = pref.Sint64Kind
219 }
220 case s == "fixed32":
221 switch t.Kind() {
222 case reflect.Int32:
223 f.Kind = pref.Sfixed32Kind
224 case reflect.Uint32:
225 f.Kind = pref.Fixed32Kind
226 case reflect.Float32:
227 f.Kind = pref.FloatKind
228 }
229 case s == "fixed64":
230 switch t.Kind() {
231 case reflect.Int64:
232 f.Kind = pref.Sfixed64Kind
233 case reflect.Uint64:
234 f.Kind = pref.Fixed64Kind
235 case reflect.Float64:
236 f.Kind = pref.DoubleKind
237 }
238 case s == "bytes":
239 switch {
240 case t.Kind() == reflect.String:
241 f.Kind = pref.StringKind
242 case t.Kind() == reflect.Slice && t.Elem() == byteType:
243 f.Kind = pref.BytesKind
244 default:
245 f.Kind = pref.MessageKind
246 }
247 case s == "group":
248 f.Kind = pref.GroupKind
249 case strings.HasPrefix(s, "enum="):
250 f.Kind = pref.EnumKind
251 case strings.HasPrefix(s, "json="):
252 f.JSONName = s[len("json="):]
253 case s == "packed":
254 f.IsPacked = true
255 case strings.HasPrefix(s, "weak="):
256 f.IsWeak = true
257 f.MessageType = ptype.PlaceholderMessage(pref.FullName(s[len("weak="):]))
258 case strings.HasPrefix(s, "def="):
259 // The default tag is special in that everything afterwards is the
260 // default regardless of the presence of commas.
261 s, i = tag[len("def="):], len(tag)
262
263 // Defaults are parsed last, so Kind is populated.
264 switch f.Kind {
265 case pref.BoolKind:
266 switch s {
267 case "1":
268 f.Default = pref.ValueOf(true)
269 case "0":
270 f.Default = pref.ValueOf(false)
271 }
272 case pref.EnumKind:
273 n, _ := strconv.ParseInt(s, 10, 32)
274 f.Default = pref.ValueOf(pref.EnumNumber(n))
275 case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
276 n, _ := strconv.ParseInt(s, 10, 32)
277 f.Default = pref.ValueOf(int32(n))
278 case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
279 n, _ := strconv.ParseInt(s, 10, 64)
280 f.Default = pref.ValueOf(int64(n))
281 case pref.Uint32Kind, pref.Fixed32Kind:
282 n, _ := strconv.ParseUint(s, 10, 32)
283 f.Default = pref.ValueOf(uint32(n))
284 case pref.Uint64Kind, pref.Fixed64Kind:
285 n, _ := strconv.ParseUint(s, 10, 64)
286 f.Default = pref.ValueOf(uint64(n))
287 case pref.FloatKind, pref.DoubleKind:
288 n, _ := strconv.ParseFloat(s, 64)
289 switch s {
290 case "nan":
291 n = math.NaN()
292 case "inf":
293 n = math.Inf(+1)
294 case "-inf":
295 n = math.Inf(-1)
296 }
297 if f.Kind == pref.FloatKind {
298 f.Default = pref.ValueOf(float32(n))
299 } else {
300 f.Default = pref.ValueOf(float64(n))
301 }
302 case pref.StringKind:
303 f.Default = pref.ValueOf(string(s))
304 case pref.BytesKind:
305 // The default value is in escaped form (C-style).
306 // TODO: Export unmarshalString in the text package to avoid this hack.
307 v, err := text.Unmarshal([]byte(`["` + s + `"]:0`))
308 if err == nil && len(v.Message()) == 1 {
309 s := v.Message()[0][0].String()
310 f.Default = pref.ValueOf([]byte(s))
311 }
312 }
313 }
314 tag = strings.TrimPrefix(tag[i:], ",")
315 }
316
317 // The generator uses the group message name instead of the field name.
318 // We obtain the real field name by lowercasing the group name.
319 if f.Kind == pref.GroupKind {
320 f.Name = pref.Name(strings.ToLower(string(f.Name)))
321 }
322
323 // Populate EnumType and MessageType.
324 if f.EnumType == nil && f.Kind == pref.EnumKind {
325 if ev, ok := reflect.Zero(t).Interface().(pref.ProtoEnum); ok {
326 f.EnumType = ev.ProtoReflect().Type()
327 } else {
328 f.EnumType = loadEnumDesc(t)
329 }
330 }
331 if f.MessageType == nil && (f.Kind == pref.MessageKind || f.Kind == pref.GroupKind) {
332 if mv, ok := reflect.Zero(t).Interface().(pref.ProtoMessage); ok {
333 f.MessageType = mv.ProtoReflect().Type()
334 } else if t.Kind() == reflect.Map {
335 m := &ptype.StandaloneMessage{
336 Syntax: parent.Syntax,
337 FullName: parent.FullName.Append(mapEntryName(f.Name)),
338 IsMapEntry: true,
339 Fields: []ptype.Field{
340 ms.parseField(tagKey, "", "", t.Key(), nil),
341 ms.parseField(tagVal, "", "", t.Elem(), nil),
342 },
343 }
344 ms.visit(m, t)
345 f.MessageType = ptype.PlaceholderMessage(m.FullName)
346 } else if mv, ok := messageDescCache.Load(t); ok {
347 f.MessageType = mv.(pref.MessageDescriptor)
348 } else {
349 f.MessageType = ms.processMessage(t)
350 }
351 }
352 return f
353}
354
355func (ms *messageDescSet) visit(m *ptype.StandaloneMessage, t reflect.Type) {
356 if ms.visited == nil {
357 ms.visited = make(map[reflect.Type]*ptype.StandaloneMessage)
358 }
359 if t.Kind() != reflect.Map {
360 ms.visited[t] = m
361 }
362 ms.descs = append(ms.descs, m)
363 ms.types = append(ms.types, t)
364}
365
366// deriveFullName derives a fully qualified protobuf name for the given Go type
367// The provided name is not guaranteed to be stable nor universally unique.
368// It should be sufficiently unique within a program.
369func deriveFullName(t reflect.Type) pref.FullName {
370 sanitize := func(r rune) rune {
371 switch {
372 case r == '/':
373 return '.'
374 case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
375 return r
376 default:
377 return '_'
378 }
379 }
380 prefix := strings.Map(sanitize, t.PkgPath())
381 suffix := strings.Map(sanitize, t.Name())
382 if suffix == "" {
383 suffix = fmt.Sprintf("UnknownX%X", reflect.ValueOf(t).Pointer())
384 }
385
386 ss := append(strings.Split(prefix, "."), suffix)
387 for i, s := range ss {
388 if s == "" || ('0' <= s[0] && s[0] <= '9') {
389 ss[i] = "x" + s
390 }
391 }
392 return pref.FullName(strings.Join(ss, "."))
393}
394
395// mapEntryName derives the message name for a map field of a given name.
396// This is identical to MapEntryName from parser.cc in the protoc source.
397func mapEntryName(s pref.Name) pref.Name {
398 var b []byte
399 nextUpper := true
400 for i := 0; i < len(s); i++ {
401 if c := s[i]; c == '_' {
402 nextUpper = true
403 } else {
404 if nextUpper {
405 c = byte(unicode.ToUpper(rune(c)))
406 nextUpper = false
407 }
408 b = append(b, c)
409 }
410 }
411 return pref.Name(append(b, "Entry"...))
412}