blob: a7a28cf53943b96f7da0e19c1ba6b2e5373d2638 [file] [log] [blame]
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4
5package impl
6
7import (
8 "fmt"
9 "reflect"
10
11 protoV1 "github.com/golang/protobuf/proto"
12 ptag "github.com/golang/protobuf/v2/internal/encoding/tag"
13 pvalue "github.com/golang/protobuf/v2/internal/value"
14 pref "github.com/golang/protobuf/v2/reflect/protoreflect"
15 ptype "github.com/golang/protobuf/v2/reflect/prototype"
16)
17
18func makeLegacyExtensionFieldsFunc(t reflect.Type) func(p *messageDataType) pref.KnownFields {
19 f := makeLegacyExtensionMapFunc(t)
20 if f == nil {
21 return nil
22 }
23 return func(p *messageDataType) pref.KnownFields {
24 return legacyExtensionFields{p.mi, f(p)}
25 }
26}
27
28type legacyExtensionFields struct {
29 mi *MessageType
30 x legacyExtensionIface
31}
32
33func (p legacyExtensionFields) Len() (n int) {
34 p.x.Range(func(num pref.FieldNumber, _ legacyExtensionEntry) bool {
35 if p.Has(num) {
36 n++
37 }
38 return true
39 })
40 return n
41}
42
43func (p legacyExtensionFields) Has(n pref.FieldNumber) bool {
44 x := p.x.Get(n)
45 if x.val == nil {
46 return false
47 }
48 t := legacyExtensionTypeOf(x.desc)
49 if t.Cardinality() == pref.Repeated {
Joe Tsaia31649d2018-11-14 17:07:59 -080050 return t.ValueOf(x.val).List().Len() > 0
Joe Tsaif0c01e42018-11-06 13:05:20 -080051 }
52 return true
53}
54
55func (p legacyExtensionFields) Get(n pref.FieldNumber) pref.Value {
56 x := p.x.Get(n)
57 if x.desc == nil {
58 return pref.Value{}
59 }
60 t := legacyExtensionTypeOf(x.desc)
61 if x.val == nil {
Joe Tsaif6d4a422018-11-19 14:26:06 -080062 // NOTE: x.val is never nil for Lists since they are always populated
63 // during ExtensionFieldTypes.Register.
Joe Tsaif0c01e42018-11-06 13:05:20 -080064 if t.Kind() == pref.MessageKind || t.Kind() == pref.GroupKind {
Joe Tsaif0c01e42018-11-06 13:05:20 -080065 return pref.Value{}
66 }
67 return t.Default()
68 }
Joe Tsaia31649d2018-11-14 17:07:59 -080069 return t.ValueOf(x.val)
Joe Tsaif0c01e42018-11-06 13:05:20 -080070}
71
72func (p legacyExtensionFields) Set(n pref.FieldNumber, v pref.Value) {
73 x := p.x.Get(n)
74 if x.desc == nil {
75 panic("no extension descriptor registered")
76 }
77 t := legacyExtensionTypeOf(x.desc)
Joe Tsaia31649d2018-11-14 17:07:59 -080078 x.val = t.InterfaceOf(v)
Joe Tsaif0c01e42018-11-06 13:05:20 -080079 p.x.Set(n, x)
80}
81
82func (p legacyExtensionFields) Clear(n pref.FieldNumber) {
83 x := p.x.Get(n)
84 if x.desc == nil {
85 return
86 }
Joe Tsaif6d4a422018-11-19 14:26:06 -080087 t := legacyExtensionTypeOf(x.desc)
88 if t.Cardinality() == pref.Repeated {
89 t.ValueOf(x.val).List().Truncate(0)
90 return
91 }
Joe Tsaif0c01e42018-11-06 13:05:20 -080092 x.val = nil
93 p.x.Set(n, x)
94}
95
96func (p legacyExtensionFields) Mutable(n pref.FieldNumber) pref.Mutable {
97 x := p.x.Get(n)
98 if x.desc == nil {
99 panic("no extension descriptor registered")
100 }
101 t := legacyExtensionTypeOf(x.desc)
102 if x.val == nil {
103 v := t.ValueOf(t.New())
Joe Tsaia31649d2018-11-14 17:07:59 -0800104 x.val = t.InterfaceOf(v)
Joe Tsaif0c01e42018-11-06 13:05:20 -0800105 p.x.Set(n, x)
106 }
Joe Tsaia31649d2018-11-14 17:07:59 -0800107 return t.ValueOf(x.val).Interface().(pref.Mutable)
Joe Tsaif0c01e42018-11-06 13:05:20 -0800108}
109
110func (p legacyExtensionFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
111 p.x.Range(func(n pref.FieldNumber, x legacyExtensionEntry) bool {
112 if p.Has(n) {
113 return f(n, p.Get(n))
114 }
115 return true
116 })
117}
118
119func (p legacyExtensionFields) ExtensionTypes() pref.ExtensionFieldTypes {
120 return legacyExtensionTypes(p)
121}
122
123type legacyExtensionTypes legacyExtensionFields
124
125func (p legacyExtensionTypes) Len() (n int) {
126 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
127 if x.desc != nil {
128 n++
129 }
130 return true
131 })
132 return n
133}
134
135func (p legacyExtensionTypes) Register(t pref.ExtensionType) {
136 if p.mi.Type.FullName() != t.ExtendedType().FullName() {
137 panic("extended type mismatch")
138 }
139 if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
140 panic("invalid extension field number")
141 }
142 x := p.x.Get(t.Number())
143 if x.desc != nil {
144 panic("extension descriptor already registered")
145 }
146 x.desc = legacyExtensionDescOf(t, p.mi.goType)
Joe Tsaif6d4a422018-11-19 14:26:06 -0800147 if t.Cardinality() == pref.Repeated {
148 // If the field is repeated, initialize the entry with an empty list
149 // so that future Get operations can return a mutable and concrete list.
150 x.val = t.InterfaceOf(t.ValueOf(t.New()))
151 }
Joe Tsaif0c01e42018-11-06 13:05:20 -0800152 p.x.Set(t.Number(), x)
153}
154
155func (p legacyExtensionTypes) Remove(t pref.ExtensionType) {
156 if !p.mi.Type.ExtensionRanges().Has(t.Number()) {
157 return
158 }
159 x := p.x.Get(t.Number())
Joe Tsaif6d4a422018-11-19 14:26:06 -0800160 if t.Cardinality() == pref.Repeated {
161 // Treat an empty repeated field as unpopulated.
162 v := reflect.ValueOf(x.val)
163 if x.val == nil || v.IsNil() || v.Elem().Len() == 0 {
164 x.val = nil
165 }
166 }
Joe Tsaif0c01e42018-11-06 13:05:20 -0800167 if x.val != nil {
168 panic("value for extension descriptor still populated")
169 }
170 x.desc = nil
171 if len(x.raw) == 0 {
172 p.x.Clear(t.Number())
173 } else {
174 p.x.Set(t.Number(), x)
175 }
176}
177
178func (p legacyExtensionTypes) ByNumber(n pref.FieldNumber) pref.ExtensionType {
179 x := p.x.Get(n)
180 if x.desc != nil {
181 return legacyExtensionTypeOf(x.desc)
182 }
183 return nil
184}
185
186func (p legacyExtensionTypes) ByName(s pref.FullName) (t pref.ExtensionType) {
187 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
188 if x.desc != nil && x.desc.Name == string(s) {
189 t = legacyExtensionTypeOf(x.desc)
190 return false
191 }
192 return true
193 })
194 return t
195}
196
197func (p legacyExtensionTypes) Range(f func(pref.ExtensionType) bool) {
198 p.x.Range(func(_ pref.FieldNumber, x legacyExtensionEntry) bool {
199 if x.desc != nil {
200 if !f(legacyExtensionTypeOf(x.desc)) {
201 return false
202 }
203 }
204 return true
205 })
206}
207
208func legacyExtensionDescOf(t pref.ExtensionType, parent reflect.Type) *protoV1.ExtensionDesc {
209 if t, ok := t.(*legacyExtensionType); ok {
210 return t.desc
211 }
212
213 // Determine the v1 extension type, which is unfortunately not the same as
214 // the v2 ExtensionType.GoType.
215 extType := t.GoType()
216 switch extType.Kind() {
217 case reflect.Bool, reflect.Int32, reflect.Int64, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.String:
218 extType = reflect.PtrTo(extType) // T -> *T for singular scalar fields
219 case reflect.Ptr:
220 if extType.Elem().Kind() == reflect.Slice {
221 extType = extType.Elem() // *[]T -> []T for repeated fields
222 }
223 }
224
225 // Reconstruct the legacy enum full name, which is an odd mixture of the
226 // proto package name with the Go type name.
227 var enumName string
228 if t.Kind() == pref.EnumKind {
229 enumName = t.GoType().Name()
230 for d, ok := pref.Descriptor(t.EnumType()), true; ok; d, ok = d.Parent() {
231 if fd, _ := d.(pref.FileDescriptor); fd != nil && fd.Package() != "" {
232 enumName = string(fd.Package()) + "." + enumName
233 }
234 }
235 }
236
237 // Construct and return a v1 ExtensionDesc.
238 return &protoV1.ExtensionDesc{
239 ExtendedType: reflect.Zero(parent).Interface().(protoV1.Message),
240 ExtensionType: reflect.Zero(extType).Interface(),
241 Field: int32(t.Number()),
242 Name: string(t.FullName()),
243 Tag: ptag.Marshal(t, enumName),
244 }
245}
246
247func legacyExtensionTypeOf(d *protoV1.ExtensionDesc) pref.ExtensionType {
248 // TODO: Add a field to protoV1.ExtensionDesc to contain a v2 descriptor.
249
250 // Derive basic field information from the struct tag.
251 t := reflect.TypeOf(d.ExtensionType)
252 isOptional := t.Kind() == reflect.Ptr && t.Elem().Kind() != reflect.Struct
253 isRepeated := t.Kind() == reflect.Slice && t.Elem().Kind() != reflect.Uint8
254 if isOptional || isRepeated {
255 t = t.Elem()
256 }
257 f := ptag.Unmarshal(d.Tag, t)
258
259 // Construct a v2 ExtensionType.
260 conv := newConverter(t, f.Kind)
261 xd, err := ptype.NewExtension(&ptype.StandaloneExtension{
262 FullName: pref.FullName(d.Name),
263 Number: pref.FieldNumber(d.Field),
264 Cardinality: f.Cardinality,
265 Kind: f.Kind,
266 Default: f.Default,
267 Options: f.Options,
268 EnumType: conv.EnumType,
269 MessageType: conv.MessageType,
270 ExtendedType: legacyLoadMessageDesc(reflect.TypeOf(d.ExtendedType)),
271 })
272 if err != nil {
273 panic(err)
274 }
275 xt := ptype.GoExtension(xd, conv.EnumType, conv.MessageType)
276
277 // Return the extension type as is if the dependencies already support v2.
278 xt2 := &legacyExtensionType{ExtensionType: xt, desc: d}
279 if !conv.IsLegacy {
280 return xt2
281 }
282
283 // If the dependency is a v1 enum or message, we need to create a custom
284 // extension type where ExtensionType.GoType continues to use the legacy
285 // v1 Go type, instead of the wrapped versions that satisfy the v2 API.
286 if xd.Cardinality() != pref.Repeated {
287 // Custom extension type for singular enums and messages.
288 // The legacy wrappers use legacyEnumWrapper and legacyMessageWrapper
289 // to implement the v2 interfaces for enums and messages.
290 // Both of those type satisfy the value.Unwrapper interface.
291 xt2.typ = t
292 xt2.new = func() interface{} {
293 return xt.New().(pvalue.Unwrapper).Unwrap()
294 }
295 xt2.valueOf = func(v interface{}) pref.Value {
296 if reflect.TypeOf(v) != xt2.typ {
297 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
298 }
299 if xd.Kind() == pref.EnumKind {
300 return xt.ValueOf(legacyWrapEnum(reflect.ValueOf(v)))
301 } else {
302 return xt.ValueOf(legacyWrapMessage(reflect.ValueOf(v)))
303 }
304 }
305 xt2.interfaceOf = func(v pref.Value) interface{} {
306 return xt.InterfaceOf(v).(pvalue.Unwrapper).Unwrap()
307 }
308 } else {
309 // Custom extension type for repeated enums and messages.
310 xt2.typ = reflect.PtrTo(reflect.SliceOf(t))
311 xt2.new = func() interface{} {
312 return reflect.New(xt2.typ.Elem()).Interface()
313 }
314 xt2.valueOf = func(v interface{}) pref.Value {
315 if reflect.TypeOf(v) != xt2.typ {
316 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
317 }
Joe Tsai4b7aff62018-11-14 14:05:19 -0800318 return pref.ValueOf(pvalue.ListOf(v, conv))
Joe Tsaif0c01e42018-11-06 13:05:20 -0800319 }
320 xt2.interfaceOf = func(pv pref.Value) interface{} {
Joe Tsai4b7aff62018-11-14 14:05:19 -0800321 v := pv.List().(pvalue.Unwrapper).Unwrap()
Joe Tsaif0c01e42018-11-06 13:05:20 -0800322 if reflect.TypeOf(v) != xt2.typ {
323 panic(fmt.Sprintf("invalid type: got %T, want %v", v, xt2.typ))
324 }
325 return v
326 }
327 }
328 return xt2
329}
330
331type legacyExtensionType struct {
332 pref.ExtensionType
333 desc *protoV1.ExtensionDesc
334 typ reflect.Type
335 new func() interface{}
336 valueOf func(interface{}) pref.Value
337 interfaceOf func(pref.Value) interface{}
338}
339
340func (x *legacyExtensionType) GoType() reflect.Type {
341 if x.typ != nil {
342 return x.typ
343 }
344 return x.ExtensionType.GoType()
345}
346func (x *legacyExtensionType) New() interface{} {
347 if x.new != nil {
348 return x.new()
349 }
350 return x.ExtensionType.New()
351}
352func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value {
353 if x.valueOf != nil {
354 return x.valueOf(v)
355 }
356 return x.ExtensionType.ValueOf(v)
357}
358func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
359 if x.interfaceOf != nil {
360 return x.interfaceOf(v)
361 }
362 return x.ExtensionType.InterfaceOf(v)
363}