all: refactor extensions, add proto.GetExtension etc.
Change protoiface.ExtensionDescV1 to implement protoreflect.ExtensionType.
ExtensionDescV1's Name field conflicts with the Descriptor Name method,
so change the protoreflect.{Message,Enum,Extension}Type types to no
longer implement the corresponding Descriptor interface; callers now obtain
the descriptor through the type's Descriptor method. This also draws a
clearer distinction between a Type and its Descriptor.
Introduce a protoreflect.ExtensionTypeDescriptor interface that bridges
ExtensionType and ExtensionDescriptor: it is an ExtensionDescriptor that
also reports its associated ExtensionType through a Type method.
Add extension accessor functions to the proto package:
proto.{Has,Clear,Get,Set}Extension. These functions take a
protoreflect.ExtensionType parameter. Since ExtensionDescV1 now implements
ExtensionType, the same call works whether the extension value comes from
the old or the new API:
	proto.GetExtension(message, somepb.E_ExtensionFoo)
Fixes golang/protobuf#908
Change-Id: Ibc65d12a46666297849114fd3aefbc4a597d9f08
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189199
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index 3a23bb0..7e430ee 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -29,25 +29,26 @@
return e
}
+ xd := xt.Descriptor()
var wiretag uint64
- if !xt.IsPacked() {
- wiretag = wire.EncodeTag(xt.Number(), wireTypes[xt.Kind()])
+ if !xd.IsPacked() {
+ wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
} else {
- wiretag = wire.EncodeTag(xt.Number(), wire.BytesType)
+ wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
}
e = &extensionFieldInfo{
wiretag: wiretag,
tagsize: wire.SizeVarint(wiretag),
- funcs: encoderFuncsForValue(xt, xt.GoType()),
+ funcs: encoderFuncsForValue(xd, xt.GoType()),
}
// Does the unmarshal function need a value passed to it?
// This is true for composite types, where we pass in a message, list, or map to fill in,
// and for enums, where we pass in a prototype value to specify the concrete enum type.
- switch xt.Kind() {
+ switch xd.Kind() {
case pref.MessageKind, pref.GroupKind, pref.EnumKind:
e.unmarshalNeedsValue = true
default:
- if xt.Cardinality() == pref.Repeated {
+ if xd.Cardinality() == pref.Repeated {
e.unmarshalNeedsValue = true
}
}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index e43c812..7014861 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -44,8 +44,9 @@
mi.extensionOffset = si.extensionOffset
mi.coderFields = make(map[wire.Number]*coderFieldInfo)
- for i := 0; i < mi.PBType.Fields().Len(); i++ {
- fd := mi.PBType.Fields().Get(i)
+ fields := mi.PBType.Descriptor().Fields()
+ for i := 0; i < fields.Len(); i++ {
+ fd := fields.Get(i)
fs := si.fieldsByNumber[fd.Number()]
if fd.ContainingOneof() != nil {
@@ -81,7 +82,7 @@
}
if messageset.IsMessageSet(mi.PBType.Descriptor()) {
if !mi.extensionOffset.IsValid() {
- panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName()))
+ panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.Descriptor().FullName()))
}
cf := &coderFieldInfo{
num: messageset.FieldItem,
@@ -113,7 +114,7 @@
mi.denseCoderFields[cf.num] = cf
}
- mi.needsInitCheck = needsInitCheck(mi.PBType)
+ mi.needsInitCheck = needsInitCheck(mi.PBType.Descriptor())
mi.methods = piface.Methods{
Flags: piface.SupportMarshalDeterministic | piface.SupportUnmarshalDiscardUnknown,
MarshalAppend: mi.marshalAppend,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 4e4c7f3..4821852 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -138,7 +138,7 @@
xt := x.GetType()
if xt == nil {
var err error
- xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num)
+ xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.Descriptor().FullName(), num)
if err != nil {
if err == preg.NotFound {
return 0, errUnknown
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index ca00012..079afe0 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -29,7 +29,7 @@
if p.IsNil() {
for _, f := range mi.orderedCoderFields {
if f.isRequired {
- return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+ return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
}
}
return nil
@@ -47,7 +47,7 @@
fptr := p.Apply(f.offset)
if f.isPointer && fptr.Elem().IsNil() {
if f.isRequired {
- return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+ return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
}
continue
}
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
index aaf8fcc..2da4d71 100644
--- a/internal/impl/legacy_extension.go
+++ b/internal/impl/legacy_extension.go
@@ -5,11 +5,9 @@
package impl
import (
- "fmt"
"reflect"
"sync"
- "google.golang.org/protobuf/internal/descfmt"
ptag "google.golang.org/protobuf/internal/encoding/tag"
"google.golang.org/protobuf/internal/filedesc"
pref "google.golang.org/protobuf/reflect/protoreflect"
@@ -62,8 +60,9 @@
}
// Determine the parent type if possible.
+ xd := xt.Descriptor()
var parent piface.MessageV1
- messageName := xt.ContainingMessage().FullName()
+ messageName := xd.ContainingMessage().FullName()
if mt, _ := preg.GlobalTypes.FindMessageByName(messageName); mt != nil {
// Create a new parent message and unwrap it if possible.
mv := mt.New().Interface()
@@ -94,7 +93,7 @@
// Reconstruct the legacy enum full name, which is an odd mixture of the
// proto package name with the Go type name.
var enumName string
- if xt.Kind() == pref.EnumKind {
+ if xd.Kind() == pref.EnumKind {
// Derive Go type name.
t := extType
if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
@@ -105,7 +104,7 @@
// Derive the proto package name.
// For legacy enums, obtain the proto package from the raw descriptor.
var protoPkg string
- if fd := xt.Enum().ParentFile(); fd != nil {
+ if fd := xd.Enum().ParentFile(); fd != nil {
protoPkg = string(fd.Package())
}
if ed, ok := reflect.Zero(t).Interface().(enumV1); ok && protoPkg == "" {
@@ -120,7 +119,7 @@
// Derive the proto file that the extension was declared within.
var filename string
- if fd := xt.ParentFile(); fd != nil {
+ if fd := xd.ParentFile(); fd != nil {
filename = fd.Path()
}
@@ -129,9 +128,9 @@
Type: xt,
ExtendedType: parent,
ExtensionType: reflect.Zero(extType).Interface(),
- Field: int32(xt.Number()),
- Name: string(xt.FullName()),
- Tag: ptag.Marshal(xt, enumName),
+ Field: int32(xd.Number()),
+ Name: string(xd.FullName()),
+ Tag: ptag.Marshal(xd, enumName),
Filename: filename,
}
if d, ok := legacyExtensionDescCache.LoadOrStore(xt, d); ok {
@@ -217,15 +216,16 @@
//
// This is exported for testing purposes.
func LegacyExtensionTypeOf(xd pref.ExtensionDescriptor, t reflect.Type) pref.ExtensionType {
- return &legacyExtensionType{
- ExtensionDescriptor: xd,
- typ: t,
- conv: NewConverter(t, xd),
+ xt := &legacyExtensionType{
+ typ: t,
+ conv: NewConverter(t, xd),
}
+ xt.desc = &extDesc{xd, xt}
+ return xt
}
type legacyExtensionType struct {
- pref.ExtensionDescriptor
+ desc pref.ExtensionTypeDescriptor
typ reflect.Type
conv Converter
}
@@ -239,5 +239,12 @@
func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
return x.conv.GoValueOf(v).Interface()
}
-func (x *legacyExtensionType) Descriptor() pref.ExtensionDescriptor { return x.ExtensionDescriptor }
-func (x *legacyExtensionType) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, x) }
+func (x *legacyExtensionType) Descriptor() pref.ExtensionTypeDescriptor { return x.desc }
+
+type extDesc struct {
+ pref.ExtensionDescriptor
+ xt *legacyExtensionType
+}
+
+func (t *extDesc) Type() pref.ExtensionType { return t.xt }
+func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor }
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index 70c5603..89cd0bc 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -52,7 +52,7 @@
func init() {
mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
- preg.GlobalFiles.Register(mt.ParentFile())
+ preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
preg.GlobalTypes.Register(mt)
}
@@ -357,19 +357,21 @@
}
for i, xt := range extensionTypes {
var got interface{}
- if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) {
- got = xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ if !(xd.IsList() || xd.IsMap() || xd.Message() != nil) {
+ got = xt.InterfaceOf(m.Get(xd))
}
want := defaultValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// All fields should be unpopulated.
for _, xt := range extensionTypes {
- if m.Has(xt) {
- t.Errorf("Message.Has(%d) = true, want false", xt.Number())
+ xd := xt.Descriptor()
+ if m.Has(xd) {
+ t.Errorf("Message.Has(%d) = true, want false", xd.Number())
}
}
@@ -401,11 +403,11 @@
19: &[]*EnumMessages{m2b},
}
for i, xt := range extensionTypes {
- m.Set(xt, xt.ValueOf(setValues[i]))
+ m.Set(xt.Descriptor(), xt.ValueOf(setValues[i]))
}
for i, xt := range extensionTypes[len(extensionTypes)/2:] {
v := extensionTypes[i].ValueOf(setValues[i])
- m.Get(xt).List().Append(v)
+ m.Get(xt.Descriptor()).List().Append(v)
}
// Get the values and check for equality.
@@ -432,24 +434,25 @@
19: &[]*EnumMessages{m2b, m2a},
}
for i, xt := range extensionTypes {
- got := xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ got := xt.InterfaceOf(m.Get(xd))
want := getValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// Clear all singular fields and truncate all repeated fields.
for _, xt := range extensionTypes[:len(extensionTypes)/2] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Get(xt).List().Truncate(0)
+ m.Get(xt.Descriptor()).List().Truncate(0)
}
// Clear all repeated fields.
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
}
@@ -491,8 +494,6 @@
switch name {
case "ParentFile", "Parent":
// Ignore parents to avoid recursive cycle.
- case "New", "Zero":
- // Ignore constructors.
case "Options":
// Ignore descriptor options since protos are not cmperable.
case "ContainingOneof", "ContainingMessage", "Enum", "Message":
@@ -504,6 +505,8 @@
if !v.IsNil() {
out[name] = v.Interface().(pref.Descriptor).FullName()
}
+ case "Type":
+ // Ignore ExtensionTypeDescriptor.Type method to avoid cycle.
default:
out[name] = m.Call(nil)[0].Interface()
}
@@ -511,6 +514,12 @@
}
return out
}),
+ cmp.Transformer("", func(xt pref.ExtensionType) map[string]interface{} {
+ return map[string]interface{}{
+ "Descriptor": xt.Descriptor(),
+ "GoType": xt.GoType(),
+ }
+ }),
cmp.Transformer("", func(v pref.Value) interface{} {
return v.Interface()
}),
@@ -605,23 +614,23 @@
var (
wantMTA = messageATypes[0]
- wantMDA = messageATypes[0].Fields().ByNumber(1).Message()
+ wantMDA = messageATypes[0].Descriptor().Fields().ByNumber(1).Message()
wantMTB = messageBTypes[0]
- wantMDB = messageBTypes[0].Fields().ByNumber(2).Message()
- wantED = messageATypes[0].Fields().ByNumber(3).Enum()
+ wantMDB = messageBTypes[0].Descriptor().Fields().ByNumber(2).Message()
+ wantED = messageATypes[0].Descriptor().Fields().ByNumber(3).Enum()
)
for _, gotMT := range messageATypes[1:] {
if gotMT != wantMTA {
t.Error("MessageType(MessageA) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}
@@ -629,13 +638,13 @@
if gotMT != wantMTB {
t.Error("MessageType(MessageB) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 305e17d..6100663 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -222,8 +222,9 @@
// any discrepancies.
func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
mi.fields = map[pref.FieldNumber]*fieldInfo{}
- for i := 0; i < mi.PBType.Fields().Len(); i++ {
- fd := mi.PBType.Fields().Get(i)
+ md := mi.PBType.Descriptor()
+ for i := 0; i < md.Fields().Len(); i++ {
+ fd := md.Fields().Get(i)
fs := si.fieldsByNumber[fd.Number()]
var fi fieldInfo
switch {
@@ -244,8 +245,8 @@
}
mi.oneofs = map[pref.Name]*oneofInfo{}
- for i := 0; i < mi.PBType.Oneofs().Len(); i++ {
- od := mi.PBType.Oneofs().Get(i)
+ for i := 0; i < md.Oneofs().Len(); i++ {
+ od := md.Oneofs().Get(i)
mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType)
}
}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 699ac2c..b0f1778 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -121,7 +121,7 @@
if m != nil {
for _, x := range *m {
xt := x.GetType()
- if !f(xt, xt.ValueOf(x.GetValue())) {
+ if !f(xt.Descriptor(), xt.ValueOf(x.GetValue())) {
return
}
}
@@ -129,16 +129,17 @@
}
func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
if m != nil {
- _, ok = (*m)[int32(xt.Number())]
+ _, ok = (*m)[int32(xt.Descriptor().Number())]
}
return ok
}
func (m *extensionMap) Clear(xt pref.ExtensionType) {
- delete(*m, int32(xt.Number()))
+ delete(*m, int32(xt.Descriptor().Number()))
}
func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
+ xd := xt.Descriptor()
if m != nil {
- if x, ok := (*m)[int32(xt.Number())]; ok {
+ if x, ok := (*m)[int32(xd.Number())]; ok {
return xt.ValueOf(x.GetValue())
}
}
@@ -151,13 +152,14 @@
var x ExtensionField
x.SetType(xt)
x.SetEagerValue(xt.InterfaceOf(v))
- (*m)[int32(xt.Number())] = x
+ (*m)[int32(xt.Descriptor().Number())] = x
}
func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
- if !isComposite(xt) {
+ xd := xt.Descriptor()
+ if !isComposite(xd) {
panic("invalid Mutable on field with non-composite type")
}
- if x, ok := (*m)[int32(xt.Number())]; ok {
+ if x, ok := (*m)[int32(xd.Number())]; ok {
return xt.ValueOf(x.GetValue())
}
v := xt.New()
@@ -179,14 +181,18 @@
return fi, nil
}
if fd.IsExtension() {
- if fd.ContainingMessage().FullName() != mi.PBType.FullName() {
+ if fd.ContainingMessage().FullName() != mi.PBType.Descriptor().FullName() {
// TODO: Should this be exact containing message descriptor match?
panic("mismatching containing message")
}
- if !mi.PBType.ExtensionRanges().Has(fd.Number()) {
+ if !mi.PBType.Descriptor().ExtensionRanges().Has(fd.Number()) {
panic("invalid extension field")
}
- return nil, fd.(pref.ExtensionType)
+ xtd, ok := fd.(pref.ExtensionTypeDescriptor)
+ if !ok {
+ panic("extension descriptor does not implement ExtensionTypeDescriptor")
+ }
+ return nil, xtd.Type()
}
panic("invalid field descriptor")
}