blob: d54da3460dfb01208c7242724f56024eda9ad35b [file] [log] [blame]
Joe Tsaif0c01e42018-11-06 13:05:20 -08001// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 protoV1 "github.com/golang/protobuf/proto"
12 ptag "github.com/golang/protobuf/v2/internal/encoding/tag"
13 pvalue "github.com/golang/protobuf/v2/internal/value"
14 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
15 ptype "github.com/golang/protobuf/v2/reflect/prototype"
16)
17
18func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref.KnownFields {
19 f := makeLegacyExtensionMapFunc(t)
20 if f == nil {
21 return nil
22 }
23 return func(p *messageDataType) pref.KnownFields {
24 return legacyExtensionFields{p.mi, f(p)}
25 }
26}
27
28type legacyExtensionFields struct {
29 mi *MessageType
30 x legacyExtensionIface
31}
32
33func (p legacyExtensionFields) Len() (n int) {
34 p.x.Range(func(num pref.FieldNumber, _ legacyExtensionEntry) bool {
35 if p.Has(num) {
36 n++
37 }
38 return true
39 })
40 return n
41}
42
43func (p legacyExtensionFields) Has(n pref.FieldNumber) bool {
44 x := p.x.Get(n)
45 if x.val == nil {
46 return false
47 }
48 t := legacyExtensionTypeOf(x.desc)
49 if t.Cardinality() == pref.Repeated {
Joe Tsaia31649d2018-11-14 17:07:59 -080050 return t.ValueOf(x.val).List().Len() > 0
Joe Tsaif0c01e42018-11-06 13:05:20 -080051 }
52 return true
53}
54
55func (p legacyExtensionFields) Get(n pref.FieldNumber) pref.Value {
56 x := p.x.Get(n)
57 if x.desc == nil {
58 return pref.Value{}
59 }
60 t := legacyExtensionTypeOf(x.desc)
61 if x.val == nil {
62 if t.Cardinality() == pref.Repeated {
Joe Tsai4b7aff62018-11-14 14:05:19 -080063 // TODO: What is the zero value for Lists?
Joe Tsaif0c01e42018-11-06 13:05:20 -080064 // TODO: This logic is racy.
65 v := t.ValueOf(t.New())
Joe Tsaia31649d2018-11-14 17:07:59 -080066 x.val = t.InterfaceOf(v)
Joe Tsaif0c01e42018-11-06 13:05:20 -080067 p.x.Set(n, x)
68 return v
69 }
70 if t.Kind() == pref.MessageKind || t.Kind() == pref.GroupKind {
71 // TODO: What is the zero value for Messages?
72 return pref.Value{}
73 }
74 return t.Default()
75 }
Joe Tsaia31649d2018-11-14 17:07:59 -080076 return t.ValueOf(x.val)
Joe Tsaif0c01e42018-11-06 13:05:20 -080077}
78
79func (p legacyExtensionFields) Set(n pref.FieldNumber, v pref.Value) {
80 x := p.x.Get(n)
81 if x.desc == nil {
82 panic("no extension descriptor registered")
83 }
84 t := legacyExtensionTypeOf(x.desc)
Joe Tsaia31649d2018-11-14 17:07:59 -080085 x.val = t.InterfaceOf(v)
Joe Tsaif0c01e42018-11-06 13:05:20 -080086 p.x.Set(n, x)
87}
88
89func (p legacyExtensionFields) Clear(n pref.FieldNumber) {
90 x := p.x.Get(n)
91 if x.desc == nil {
92 return
93 }
94 x.val = nil
95 p.x.Set(n, x)
96}
97
98func (p legacyExtensionFields) Mutable(n pref.FieldNumber) pref.Mutable {
99 x := p.x.Get(n)
100 if x.desc == nil {
101 panic("no extension descriptor registered")
102 }
103 t := legacyExtensionTypeOf(x.desc)
104 if x.val == nil {
105 v := t.ValueOf(t.New())
Joe Tsaia31649d2018-11-14 17:07:59 -0800106 x.val = t.InterfaceOf(v)
Joe Tsaif0c01e42018-11-06 13:05:20 -0800107 p.x.Set(n, x)
108 }
Joe Tsaia31649d2018-11-14 17:07:59 -0800109 return t.ValueOf(x.val).Interface().(pref.Mutable)
Joe Tsaif0c01e42018-11-06 13:05:20 -0800110}
111
112func (p legacyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
113 p.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool {
114 if p.Has(n) {
115 return f(n, p.Get(n))
116 }
117 return true
118 })
119}
120
121func (p legacyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes {
122 return legacyExtensionTypes(p)
123}
124
125type legacyExtensionTypes legacyExtensionFields
126
127func (p legacyExtensionTypes) Len() (n int) {
128 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
129 if x.desc != nil {
130 n++
131 }
132 return true
133 })
134 return n
135}
136
137func (p legacyExtensionTypes) Register(t pref.ExtensionType) {
138 if p.mi.Type.FullName() != t.ExtendedType().FullName() {
139 panic("extended type mismatch")
140 }
141 if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
142 panic("invalid extension field number")
143 }
144 x := p.x.Get(t.Number())
145 if x.desc != nil {
146 panic("extension descriptor already registered")
147 }
148 x.desc = legacyExtensionDescOf(t, p.mi.goType)
149 p.x.Set(t.Number(), x)
150}
151
152func (p legacyExtensionTypes) Remove(t pref.ExtensionType) {
153 if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
154 return
155 }
156 x := p.x.Get(t.Number())
157 if x.val != nil {
158 panic("value for extension descriptor still populated")
159 }
160 x.desc = nil
161 if len(x.raw) == 0 {
162 p.x.Clear(t.Number())
163 } else {
164 p.x.Set(t.Number(), x)
165 }
166}
167
168func (p legacyExtensionTypes) ByNumber(n pref.FieldNumber) pref.ExtensionType {
169 x := p.x.Get(n)
170 if x.desc != nil {
171 return legacyExtensionTypeOf(x.desc)
172 }
173 return nil
174}
175
176func (p legacyExtensionTypes) ByName(s pref.FullName) (t pref.ExtensionType) {
177 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
178 if x.desc != nil && x.desc.Name == string(s) {
179 t = legacyExtensionTypeOf(x.desc)
180 return false
181 }
182 return true
183 })
184 return t
185}
186
187func (p legacyExtensionTypes) Range(f func(pref.ExtensionType) bool) {
188 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
189 if x.desc != nil {
190 if !f(legacyExtensionTypeOf(x.desc)) {
191 return false
192 }
193 }
194 return true
195 })
196}
197
198func legacyExtensionDescOf(t pref.ExtensionType, parent reflect.Type) *protoV1.ExtensionDesc {
199 if t, ok := t.(*legacyExtensionType); ok {
200 return t.desc
201 }
202
203 // Determine the v1 extension type, which is unfortunately not the same as
204 // the v2 ExtensionType.GoType.
205 extType := t.GoType()
206 switch extType.Kind() {
207 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
208 extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields
209 case reflect.Ptr:
210 if extType.Elem().Kind() == reflect.Slice {
211 extType = extType.Elem() // *[]T -> []T for repeated fields
212 }
213 }
214
215 // Reconstruct the legacy enum full name, which is an odd mixture of the
216 // proto package name with the Go type name.
217 var enumName string
218 if t.Kind() == pref.EnumKind {
219 enumName = t.GoType().Name()
220 for d, ok := pref.Descriptor(t.EnumType()), true; ok; d, ok = d.Parent() {
221 if fd, _ := d.(pref.FileDescriptor); fd != nil && fd.Package() != "" {
222 enumName = string(fd.Package()) + "." + enumName
223 }
224 }
225 }
226
227 // Construct and return a v1 ExtensionDesc.
228 return &protoV1.ExtensionDesc{
229 ExtendedType: reflect.Zero(parent).Interface().(protoV1.Message),
230 ExtensionType: reflect.Zero(extType).Interface(),
231 Field: int32(t.Number()),
232 Name: string(t.FullName()),
233 Tag: ptag.Marshal(t, enumName),
234 }
235}
236
237func legacyExtensionTypeOf(d *protoV1.ExtensionDesc) pref.ExtensionType {
238 // TODO: Add a field to protoV1.ExtensionDesc to contain a v2 descriptor.
239
240 // Derive basic field information from the struct tag.
241 t := reflect.TypeOf(d.ExtensionType)
242 isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
243 isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
244 if isOptional || isRepeated {
245 t = t.Elem()
246 }
247 f := ptag.Unmarshal(d.Tag, t)
248
249 // Construct a v2 ExtensionType.
250 conv := newConverter(t, f.Kind)
251 xd, err := ptype.NewExtension(&ptype.StandaloneExtension{
252 FullName: pref.FullName(d.Name),
253 Number: pref.FieldNumber(d.Field),
254 Cardinality: f.Cardinality,
255 Kind: f.Kind,
256 Default: f.Default,
257 Options: f.Options,
258 EnumType: conv.EnumType,
259 MessageType: conv.MessageType,
260 ExtendedType: legacyLoadMessageDesc(reflect.TypeOf(d.ExtendedType)),
261 })
262 if err != nil {
263 panic(err)
264 }
265 xt := ptype.GoExtension(xd, conv.EnumType, conv.MessageType)
266
267 // Return the extension type as is if the dependencies already support v2.
268 xt2 := &legacyExtensionType{ExtensionType: xt, desc: d}
269 if !conv.IsLegacy {
270 return xt2
271 }
272
273 // If the dependency is a v1 enum or message, we need to create a custom
274 // extension type where ExtensionType.GoType continues to use the legacy
275 // v1 Go type, instead of the wrapped versions that satisfy the v2 API.
276 if xd.Cardinality() != pref.Repeated {
277 // Custom extension type for singular enums and messages.
278 // The legacy wrappers use legacyEnumWrapper and legacyMessageWrapper
279 // to implement the v2 interfaces for enums and messages.
280 // Both of those type satisfy the value.Unwrapper interface.
281 xt2.typ = t
282 xt2.new = func() interface{} {
283 return xt.New().(pvalue.Unwrapper).Unwrap()
284 }
285 xt2.valueOf = func(v interface{}) pref.Value {
286 if reflect.TypeOf(v) != xt2.typ {
287 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
288 }
289 if xd.Kind() == pref.EnumKind {
290 return xt.ValueOf(legacyWrapEnum(reflect.ValueOf(v)))
291 } else {
292 return xt.ValueOf(legacyWrapMessage(reflect.ValueOf(v)))
293 }
294 }
295 xt2.interfaceOf = func(v pref.Value) interface{} {
296 return xt.InterfaceOf(v).(pvalue.Unwrapper).Unwrap()
297 }
298 } else {
299 // Custom extension type for repeated enums and messages.
300 xt2.typ = reflect.PtrTo(reflect.SliceOf(t))
301 xt2.new = func() interface{} {
302 return reflect.New(xt2.typ.Elem()).Interface()
303 }
304 xt2.valueOf = func(v interface{}) pref.Value {
305 if reflect.TypeOf(v) != xt2.typ {
306 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
307 }
Joe Tsai4b7aff62018-11-14 14:05:19 -0800308 return pref.ValueOf(pvalue.ListOf(v, conv))
Joe Tsaif0c01e42018-11-06 13:05:20 -0800309 }
310 xt2.interfaceOf = func(pv pref.Value) interface{} {
Joe Tsai4b7aff62018-11-14 14:05:19 -0800311 v := pv.List().(pvalue.Unwrapper).Unwrap()
Joe Tsaif0c01e42018-11-06 13:05:20 -0800312 if reflect.TypeOf(v) != xt2.typ {
313 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
314 }
315 return v
316 }
317 }
318 return xt2
319}
320
321type legacyExtensionType struct {
322 pref.ExtensionType
323 desc *protoV1.ExtensionDesc
324 typ reflect.Type
325 new func() interface{}
326 valueOf func(interface{}) pref.Value
327 interfaceOf func(pref.Value) interface{}
328}
329
330func (x *legacyExtensionType) GoType() reflect.Type {
331 if x.typ != nil {
332 return x.typ
333 }
334 return x.ExtensionType.GoType()
335}
336func (x *legacyExtensionType) New() interface{} {
337 if x.new != nil {
338 return x.new()
339 }
340 return x.ExtensionType.New()
341}
342func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value {
343 if x.valueOf != nil {
344 return x.valueOf(v)
345 }
346 return x.ExtensionType.ValueOf(v)
347}
348func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
349 if x.interfaceOf != nil {
350 return x.interfaceOf(v)
351 }
352 return x.ExtensionType.InterfaceOf(v)
353}