all: refactor extensions, add proto.GetExtension etc.
Change protoiface.ExtensionDescV1 to implement protoreflect.ExtensionType.
ExtensionDescV1's Name field conflicts with the Descriptor.Name method,
so change the protoreflect.{Message,Enum,Extension}Type types to no
longer implement the corresponding Descriptor interface. This also draws
a clearer distinction between each Type and its Descriptor.
Introduce a protoreflect.ExtensionTypeDescriptor type, which bridges
ExtensionType and ExtensionDescriptor.
Add extension accessor functions to the proto package:
proto.{Has,Clear,Get,Set}Extension. These functions take a
protoreflect.ExtensionType parameter, so the same call can be written
with an extension from either the old or the new API, e.g.:
proto.GetExtension(message, somepb.E_ExtensionFoo)
Fixes golang/protobuf#908
Change-Id: Ibc65d12a46666297849114fd3aefbc4a597d9f08
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189199
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index 70c5603..89cd0bc 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -52,7 +52,7 @@
func init() {
mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
- preg.GlobalFiles.Register(mt.ParentFile())
+ preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
preg.GlobalTypes.Register(mt)
}
@@ -357,19 +357,21 @@
}
for i, xt := range extensionTypes {
var got interface{}
- if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) {
- got = xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ if !(xd.IsList() || xd.IsMap() || xd.Message() != nil) {
+ got = xt.InterfaceOf(m.Get(xd))
}
want := defaultValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// All fields should be unpopulated.
for _, xt := range extensionTypes {
- if m.Has(xt) {
- t.Errorf("Message.Has(%d) = true, want false", xt.Number())
+ xd := xt.Descriptor()
+ if m.Has(xd) {
+ t.Errorf("Message.Has(%d) = true, want false", xd.Number())
}
}
@@ -401,11 +403,11 @@
19: &[]*EnumMessages{m2b},
}
for i, xt := range extensionTypes {
- m.Set(xt, xt.ValueOf(setValues[i]))
+ m.Set(xt.Descriptor(), xt.ValueOf(setValues[i]))
}
for i, xt := range extensionTypes[len(extensionTypes)/2:] {
v := extensionTypes[i].ValueOf(setValues[i])
- m.Get(xt).List().Append(v)
+ m.Get(xt.Descriptor()).List().Append(v)
}
// Get the values and check for equality.
@@ -432,24 +434,25 @@
19: &[]*EnumMessages{m2b, m2a},
}
for i, xt := range extensionTypes {
- got := xt.InterfaceOf(m.Get(xt))
+ xd := xt.Descriptor()
+ got := xt.InterfaceOf(m.Get(xd))
want := getValues[i]
if diff := cmp.Diff(want, got, opts); diff != "" {
- t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+ t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
}
}
// Clear all singular fields and truncate all repeated fields.
for _, xt := range extensionTypes[:len(extensionTypes)/2] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Get(xt).List().Truncate(0)
+ m.Get(xt.Descriptor()).List().Truncate(0)
}
// Clear all repeated fields.
for _, xt := range extensionTypes[len(extensionTypes)/2:] {
- m.Clear(xt)
+ m.Clear(xt.Descriptor())
}
}
@@ -491,8 +494,6 @@
switch name {
case "ParentFile", "Parent":
// Ignore parents to avoid recursive cycle.
- case "New", "Zero":
- // Ignore constructors.
case "Options":
// Ignore descriptor options since protos are not cmperable.
case "ContainingOneof", "ContainingMessage", "Enum", "Message":
@@ -504,6 +505,8 @@
if !v.IsNil() {
out[name] = v.Interface().(pref.Descriptor).FullName()
}
+ case "Type":
+ // Ignore ExtensionTypeDescriptor.Type method to avoid cycle.
default:
out[name] = m.Call(nil)[0].Interface()
}
@@ -511,6 +514,12 @@
}
return out
}),
+ cmp.Transformer("", func(xt pref.ExtensionType) map[string]interface{} {
+ return map[string]interface{}{
+ "Descriptor": xt.Descriptor(),
+ "GoType": xt.GoType(),
+ }
+ }),
cmp.Transformer("", func(v pref.Value) interface{} {
return v.Interface()
}),
@@ -605,23 +614,23 @@
var (
wantMTA = messageATypes[0]
- wantMDA = messageATypes[0].Fields().ByNumber(1).Message()
+ wantMDA = messageATypes[0].Descriptor().Fields().ByNumber(1).Message()
wantMTB = messageBTypes[0]
- wantMDB = messageBTypes[0].Fields().ByNumber(2).Message()
- wantED = messageATypes[0].Fields().ByNumber(3).Enum()
+ wantMDB = messageBTypes[0].Descriptor().Fields().ByNumber(2).Message()
+ wantED = messageATypes[0].Descriptor().Fields().ByNumber(3).Enum()
)
for _, gotMT := range messageATypes[1:] {
if gotMT != wantMTA {
t.Error("MessageType(MessageA) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}
@@ -629,13 +638,13 @@
if gotMT != wantMTB {
t.Error("MessageType(MessageB) mismatch")
}
- if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+ if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
t.Error("MessageDescriptor(MessageA) mismatch")
}
- if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+ if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
t.Error("MessageDescriptor(MessageB) mismatch")
}
- if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+ if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
t.Error("EnumDescriptor(Enum) mismatch")
}
}