reflect/protoreflect: add alternative message reflection API
Added API:
Message.Len
Message.Range
Message.Has
Message.Clear
Message.Get
Message.Set
Message.Mutable
Message.NewMessage
Message.WhichOneof
Message.GetUnknown
Message.SetUnknown
Deprecated API (to be removed in a subsequent CL):
Message.KnownFields
Message.UnknownFields
The primary difference with the new API is that the top-level
Message methods are keyed by FieldDescriptor rather than FieldNumber,
with the following semantics:
* For known fields, the FieldDescriptor must exactly match the
field descriptor known by the message.
* For extension fields, the FieldDescriptor must implement ExtensionType,
where ContainingMessage.FullName matches the message name, and
the field number is within one of the message's extension ranges.
Setting an extension field automatically stores
its extension type information.
* Extension fields are always considered nullable,
which implies that repeated extension fields are also nullable.
That is, you can distinguish between an unpopulated list and an empty list.
* Message.Get always returns a valid Value even if unpopulated.
The behavior is already well-defined for scalars; for unpopulated
composite types (messages, lists, and maps), it now returns an empty, read-only value.
Change-Id: Ia120630b4db221aeaaf743d0f64160e1a61a0f61
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/175458
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index b5053f4..1ca4e36 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -17,68 +17,46 @@
pref "google.golang.org/protobuf/reflect/protoreflect"
)
-// TestMessage runs the provided message through a series of tests
-// exercising the protobuf reflection API.
-func TestMessage(t testing.TB, message proto.Message) {
- md := message.ProtoReflect().Descriptor()
+// TODO: Test read-only properties of unpopulated composite values.
+// TODO: Test invalid field descriptors or oneof descriptors.
+// TODO: This should test the functionality that can be provided by fast-paths.
- m := message.ProtoReflect().New()
+// TestMessage runs the provided message m through a series of tests
+// exercising the protobuf reflection API.
+func TestMessage(t testing.TB, m proto.Message) {
+ md := m.ProtoReflect().Descriptor()
+ m1 := m.ProtoReflect().New()
for i := 0; i < md.Fields().Len(); i++ {
fd := md.Fields().Get(i)
switch {
case fd.IsList():
- testFieldList(t, m, fd)
+ testFieldList(t, m1, fd)
case fd.IsMap():
- testFieldMap(t, m, fd)
+ testFieldMap(t, m1, fd)
case fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind:
- testFieldFloat(t, m, fd)
+ testFieldFloat(t, m1, fd)
}
- testField(t, m, fd)
+ testField(t, m1, fd)
}
for i := 0; i < md.Oneofs().Len(); i++ {
- testOneof(t, m, md.Oneofs().Get(i))
- }
-
- // Test has/get/clear on a non-existent field.
- for num := pref.FieldNumber(1); ; num++ {
- if md.Fields().ByNumber(num) != nil {
- continue
- }
- if md.ExtensionRanges().Has(num) {
- continue
- }
- // Field num does not exist.
- if m.KnownFields().Has(num) {
- t.Errorf("non-existent field: Has(%v) = true, want false", num)
- }
- if v := m.KnownFields().Get(num); v.IsValid() {
- t.Errorf("non-existent field: Get(%v) = %v, want invalid", num, formatValue(v))
- }
- m.KnownFields().Clear(num) // noop
- break
- }
-
- // Test WhichOneof on a non-existent oneof.
- const invalidName = "invalid-name"
- if got, want := m.KnownFields().WhichOneof(invalidName), pref.FieldNumber(0); got != want {
- t.Errorf("non-existent oneof: WhichOneof(%q) = %v, want %v", invalidName, got, want)
+ testOneof(t, m1, md.Oneofs().Get(i))
}
// TODO: Extensions, unknown fields.
// Test round-trip marshal/unmarshal.
- m1 := message.ProtoReflect().New().Interface()
- populateMessage(m1.ProtoReflect(), 1, nil)
- b, err := proto.Marshal(m1)
+ m2 := m.ProtoReflect().New().Interface()
+ populateMessage(m2.ProtoReflect(), 1, nil)
+ b, err := proto.Marshal(m2)
if err != nil {
- t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m1))
+ t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
}
- m2 := message.ProtoReflect().New().Interface()
- if err := proto.Unmarshal(b, m2); err != nil {
- t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m1))
+ m3 := m.ProtoReflect().New().Interface()
+ if err := proto.Unmarshal(b, m3); err != nil {
+ t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
}
- if !proto.Equal(m1, m2) {
- t.Errorf("round-trip marshal/unmarshal did not preserve message.\nOriginal:\n%v\nNew:\n%v", marshalText(m1), marshalText(m2))
+ if !proto.Equal(m2, m3) {
+ t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", marshalText(m2), marshalText(m3))
}
}
@@ -87,16 +65,15 @@
return string(b)
}
-// testField exericises set/get/has/clear of a field.
+// testField exercises set/get/has/clear of a field.
func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
- num := fd.Number()
name := fd.FullName()
- known := m.KnownFields()
+ num := fd.Number()
// Set to a non-zero value, the zero value, different non-zero values.
for _, n := range []seed{1, 0, minVal, maxVal} {
v := newValue(m, fd, n, nil)
- known.Set(num, v)
+ m.Set(fd, v)
wantHas := true
if n == 0 {
if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
@@ -109,55 +86,55 @@
wantHas = true
}
}
- if got, want := known.Has(num), wantHas; got != want {
- t.Errorf("after setting %q to %v:\nHas(%v) = %v, want %v", name, formatValue(v), num, got, want)
+ if got, want := m.Has(fd), wantHas; got != want {
+ t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
}
- if got, want := known.Get(num), v; !valueEqual(got, want) {
- t.Errorf("after setting %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), v; !valueEqual(got, want) {
+ t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
}
}
- known.Clear(num)
- if got, want := known.Has(num), false; got != want {
- t.Errorf("after clearing %q:\nHas(%v) = %v, want %v", name, num, got, want)
+ m.Clear(fd)
+ if got, want := m.Has(fd), false; got != want {
+ t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
}
switch {
case fd.IsList():
- if got := known.Get(num); got.List().Len() != 0 {
- t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
+ if got := m.Get(fd); got.List().Len() != 0 {
+ t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
}
case fd.IsMap():
- if got := known.Get(num); got.Map().Len() != 0 {
- t.Errorf("after clearing %q:\nGet(%v) = %v, want empty list", name, num, formatValue(got))
+ if got := m.Get(fd); got.Map().Len() != 0 {
+ t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
}
- default:
- if got, want := known.Get(num), fd.Default(); !valueEqual(got, want) {
- t.Errorf("after clearing %q:\nGet(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+ case fd.Message() == nil:
+ if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+ t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
}
}
}
// testFieldMap tests set/get/has/clear of entries in a map field.
func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
- num := fd.Number()
name := fd.FullName()
- known := m.KnownFields()
- known.Clear(num) // start with an empty map
- mapv := known.Get(num).Map()
+ num := fd.Number()
+
+ m.Clear(fd) // start with an empty map; later checks expect writes through mapv to be visible via m.Get/m.Has
+ mapv := m.Mutable(fd).Map()
// Add values.
want := make(testMap)
for i, n := range []seed{1, 0, minVal, maxVal} {
- if got, want := known.Has(num), i > 0; got != want {
- t.Errorf("after inserting %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
+ if got, want := m.Has(fd), i > 0; got != want {
+ t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
}
k := newMapKey(fd, n)
v := newMapValue(fd, mapv, n, nil)
mapv.Set(k, v)
want.Set(k, v)
- if got, want := known.Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after inserting %d elements to %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
}
}
@@ -166,8 +143,8 @@
nv := newMapValue(fd, mapv, 10, nil)
mapv.Set(k, nv)
want.Set(k, nv)
- if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after setting element %v of %q:\nGet(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
}
return true
})
@@ -176,11 +153,11 @@
want.Range(func(k pref.MapKey, v pref.Value) bool {
mapv.Clear(k)
want.Clear(k)
- if got, want := known.Has(num), want.Len() > 0; got != want {
- t.Errorf("after clearing elements of %q:\nHas(%v) = %v, want %v", name, num, got, want)
+ if got, want := m.Has(fd), want.Len() > 0; got != want {
+ t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
}
- if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after clearing elements of %q:\nGet(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
}
return true
})
@@ -188,10 +165,10 @@
// Non-existent map keys.
missingKey := newMapKey(fd, 1)
if got, want := mapv.Has(missingKey), false; got != want {
- t.Errorf("non-existent map key in %q: Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+ t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
}
if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
- t.Errorf("non-existent map key in %q: Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+ t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
}
mapv.Clear(missingKey) // noop
}
@@ -214,24 +191,24 @@
// testFieldList exercises set/get/append/truncate of values in a list.
func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
- num := fd.Number()
name := fd.FullName()
- known := m.KnownFields()
- known.Clear(num) // start with an empty list
- list := known.Get(num).List()
+ num := fd.Number()
+
+ m.Clear(fd) // start with an empty list
+ list := m.Mutable(fd).List()
// Append values.
var want pref.List = &testList{}
for i, n := range []seed{1, 0, minVal, maxVal} {
- if got, want := known.Has(num), i > 0; got != want {
- t.Errorf("after appending %d elements to %q:\nHas(%v) = %v, want %v", i, name, num, got, want)
+ if got, want := m.Has(fd), i > 0; got != want {
+ t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
}
v := newListElement(fd, list, n, nil)
want.Append(v)
list.Append(v)
- if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after appending %d elements to %q:\nGet(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
}
}
@@ -240,8 +217,8 @@
v := newListElement(fd, list, seed(i+10), nil)
want.Set(i, v)
list.Set(i, v)
- if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after setting element %d of %q:\nGet(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
}
}
@@ -250,11 +227,11 @@
n := want.Len() - 1
want.Truncate(n)
list.Truncate(n)
- if got, want := known.Has(num), want.Len() > 0; got != want {
- t.Errorf("after truncating %q to %d:\nHas(%v) = %v, want %v", name, n, num, got, want)
+ if got, want := m.Has(fd), want.Len() > 0; got != want {
+ t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
}
- if got, want := m.KnownFields().Get(num), pref.ValueOf(want); !valueEqual(got, want) {
- t.Errorf("after truncating %q to %d:\nGet(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), pref.ValueOf(want); !valueEqual(got, want) {
+ t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
}
}
}
@@ -272,9 +249,9 @@
// testFieldFloat exercises some interesting floating-point scalar field values.
func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
- num := fd.Number()
name := fd.FullName()
- known := m.KnownFields()
+ num := fd.Number()
+
for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
var val pref.Value
if fd.Kind() == pref.FloatKind {
@@ -282,29 +259,28 @@
} else {
val = pref.ValueOf(v)
}
- known.Set(num, val)
+ m.Set(fd, val)
// Note that Has is true for -0.
- if got, want := known.Has(num), true; got != want {
- t.Errorf("after setting %v to %v: Get(%v) = %v, want %v", name, v, num, got, want)
+ if got, want := m.Has(fd), true; got != want {
+ t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want)
}
- if got, want := known.Get(num), val; !valueEqual(got, want) {
- t.Errorf("after setting %v: Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+ if got, want := m.Get(fd), val; !valueEqual(got, want) {
+ t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
}
}
}
// testOneof tests the behavior of fields in a oneof.
func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
- known := m.KnownFields()
for i := 0; i < od.Fields().Len(); i++ {
fda := od.Fields().Get(i)
- known.Set(fda.Number(), newValue(m, fda, 1, nil))
- if got, want := known.WhichOneof(od.Name()), fda.Number(); got != want {
- t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), fda.Name(), got, want)
+ m.Set(fda, newValue(m, fda, 1, nil))
+ if got, want := m.WhichOneof(od), fda; got != want {
+ t.Errorf("after setting oneof field %q:\nMessage.WhichOneof(%q) = %v, want %v", fda.FullName(), od.Name(), got, want)
}
for j := 0; j < od.Fields().Len(); j++ {
fdb := od.Fields().Get(j)
- if got, want := known.Has(fdb.Number()), i == j; got != want {
- t.Errorf("after setting oneof field %q:\nGet(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
+ if got, want := m.Has(fdb), i == j; got != want {
+ t.Errorf("after setting oneof field %q:\nMessage.Has(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
}
}
@@ -422,10 +398,9 @@
// The stack parameter is used to avoid infinite recursion when populating circular
// data structures.
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
- num := fd.Number()
switch {
case fd.IsList():
- list := m.New().KnownFields().Get(num).List()
+ list := m.New().Mutable(fd).List()
if n == 0 {
return pref.ValueOf(list)
}
@@ -435,7 +410,7 @@
list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list)
case fd.IsMap():
- mapv := m.New().KnownFields().Get(num).Map()
+ mapv := m.New().Mutable(fd).Map()
if n == 0 {
return pref.ValueOf(mapv)
}
@@ -445,7 +420,7 @@
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, 10*n, stack))
return pref.ValueOf(mapv)
case fd.Message() != nil:
- return populateMessage(m.KnownFields().NewMessage(num), n, stack)
+ return populateMessage(m.NewMessage(fd), n, stack) // allocate a new message rather than mutating m's field, like the list/map cases
default:
return newScalarValue(fd, n)
}
@@ -476,7 +451,7 @@
case pref.BoolKind:
return pref.ValueOf(n != 0)
case pref.EnumKind:
- // TODO use actual value
+ // TODO: use an actual declared enum value instead of the raw seed number
return pref.ValueOf(pref.EnumNumber(n))
case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
switch n {
@@ -559,13 +534,12 @@
}
}
stack = append(stack, md)
- known := m.KnownFields()
for i := 0; i < md.Fields().Len(); i++ {
fd := md.Fields().Get(i)
if fd.IsWeak() {
continue
}
- known.Set(fd.Number(), newValue(m, fd, 10*n+seed(i), stack))
+ m.Set(fd, newValue(m, fd, 10*n+seed(i), stack))
}
return pref.ValueOf(m)
}