reflect/protoreflect: add MessageFieldTypes
The MessageFieldTypes interface (if implemented by a MessageType)
provides Go type information about those of its fields that are
of enum or message type.
Change-Id: I68b20f5726377f6b0f2c20a8b6e45f9802b43f67
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/236777
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/testing/prototest/message.go b/testing/prototest/message.go
new file mode 100644
index 0000000..e495628
--- /dev/null
+++ b/testing/prototest/message.go
@@ -0,0 +1,894 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package prototest exercises protobuf reflection.
+package prototest
+
+import (
+ "bytes"
+ "fmt"
+ "math"
+ "reflect"
+ "sort"
+ "strings"
+ "testing"
+
+ "google.golang.org/protobuf/encoding/prototext"
+ "google.golang.org/protobuf/encoding/protowire"
+ "google.golang.org/protobuf/proto"
+ "google.golang.org/protobuf/reflect/protoreflect"
+ pref "google.golang.org/protobuf/reflect/protoreflect"
+ "google.golang.org/protobuf/reflect/protoregistry"
+)
+
+// TODO: Test invalid field descriptors or oneof descriptors.
+// TODO: This should test the functionality that can be provided by fast-paths.
+
+// Message tests a message implementation.
+type Message struct {
+	// Resolver is used to determine the list of extension fields to test with.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	//
+	// The same resolver is passed to proto.UnmarshalOptions when a populated
+	// message is round-tripped, so that its extension fields can be decoded.
+	Resolver interface {
+		FindExtensionByName(field pref.FullName) (pref.ExtensionType, error)
+		FindExtensionByNumber(message pref.FullName, field pref.FieldNumber) (pref.ExtensionType, error)
+		RangeExtensionsByMessage(message pref.FullName, f func(pref.ExtensionType) bool)
+	}
+}
+
+// Test performs tests on a MessageType implementation.
+//
+// It checks the Go types produced by mt, exercises every field of the
+// message (plus every extension the Resolver reports for it) and every
+// oneof, checks unknown-field handling, and finally verifies that a
+// populated message survives a marshal/unmarshal round trip.
+func (test Message) Test(t testing.TB, mt pref.MessageType) {
+	testType(t, mt)
+
+	md := mt.Descriptor()
+	m1 := mt.New()
+	for i := 0; i < md.Fields().Len(); i++ {
+		fd := md.Fields().Get(i)
+		testField(t, m1, fd)
+	}
+	if test.Resolver == nil {
+		test.Resolver = protoregistry.GlobalTypes
+	}
+	var extTypes []pref.ExtensionType
+	test.Resolver.RangeExtensionsByMessage(md.FullName(), func(e pref.ExtensionType) bool {
+		extTypes = append(extTypes, e)
+		return true
+	})
+	for _, xt := range extTypes {
+		testField(t, m1, xt.TypeDescriptor())
+	}
+	for i := 0; i < md.Oneofs().Len(); i++ {
+		testOneof(t, m1, md.Oneofs().Get(i))
+	}
+	testUnknown(t, m1)
+
+	// Test round-trip marshal/unmarshal.
+	m2 := mt.New().Interface()
+	m2r := m2.ProtoReflect()
+	populateMessage(m2r, 1, nil)
+	for _, xt := range extTypes {
+		xd := xt.TypeDescriptor()
+		m2r.Set(xd, newValue(m2r, xd, 1, nil))
+	}
+	b, err := proto.MarshalOptions{
+		AllowPartial: true,
+	}.Marshal(m2)
+	if err != nil {
+		t.Errorf("Marshal() = %v, want nil\n%v", err, prototext.Format(m2))
+		// Without the encoded bytes the remaining checks could only
+		// report misleading follow-on failures.
+		return
+	}
+	m3 := mt.New().Interface()
+	if err := (proto.UnmarshalOptions{
+		AllowPartial: true,
+		Resolver:     test.Resolver,
+	}.Unmarshal(b, m3)); err != nil {
+		t.Errorf("Unmarshal() = %v, want nil\n%v", err, prototext.Format(m2))
+		// A partially decoded m3 would only produce a redundant diff below.
+		return
+	}
+	if !proto.Equal(m2, m3) {
+		t.Errorf("round-trip marshal/unmarshal did not preserve message\nOriginal:\n%v\nNew:\n%v", prototext.Format(m2), prototext.Format(m3))
+	}
+}
+
+// testType checks that a message created by mt has the same Go type as the
+// messages obtained through its reflective view (Interface, New, and
+// Type().Zero()). If mt also implements pref.MessageFieldTypes, the
+// per-field type information is checked as well.
+func testType(t testing.TB, mt pref.MessageType) {
+	m := mt.New().Interface()
+	want := reflect.TypeOf(m)
+	checks := []struct {
+		format string
+		got    reflect.Type
+	}{
+		{
+			format: "type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Interface()): %v != %v",
+			got:    reflect.TypeOf(m.ProtoReflect().Interface()),
+		},
+		{
+			format: "type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().New().Interface()): %v != %v",
+			got:    reflect.TypeOf(m.ProtoReflect().New().Interface()),
+		},
+		{
+			format: "type mismatch: reflect.TypeOf(m) != reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()): %v != %v",
+			got:    reflect.TypeOf(m.ProtoReflect().Type().Zero().Interface()),
+		},
+	}
+	for _, c := range checks {
+		if c.got != want {
+			t.Errorf(c.format, c.got, want)
+		}
+	}
+	if fieldTypes, ok := mt.(pref.MessageFieldTypes); ok {
+		testFieldTypes(t, fieldTypes)
+	}
+}
+
+// testFieldTypes checks that the Go type information reported by a
+// MessageFieldTypes implementation agrees with the message descriptor.
+//
+// For every field of enum or message type (list elements, map values and
+// singular fields alike), mt.Enum(i) / mt.Message(i) must carry the same
+// descriptor as the field, and mt.Message(i) must be the same MessageType as
+// that of a freshly created field value. For map fields, mt.Message(i)
+// describes the map entry message (fd.Message()).
+func testFieldTypes(t testing.TB, mt pref.MessageFieldTypes) {
+	descName := func(d pref.Descriptor) pref.FullName {
+		if d == nil {
+			return "<nil>"
+		}
+		return d.FullName()
+	}
+	typeName := func(mt pref.MessageType) pref.FullName {
+		if mt == nil {
+			return "<nil>"
+		}
+		return mt.Descriptor().FullName()
+	}
+	// adjustExpr expands the shorthand used in the expressions below
+	// (fd, mti, and the index i) into explicit expressions in terms of md
+	// and the concrete field index, so failure messages are self-contained.
+	adjustExpr := func(idx int, expr string) string {
+		expr = strings.Replace(expr, "fd.", "md.Fields().Get(i).", -1)
+		expr = strings.Replace(expr, "(fd)", "(md.Fields().Get(i))", -1)
+		expr = strings.Replace(expr, "mti.", "mt.Message(i).", -1)
+		expr = strings.Replace(expr, "(i)", fmt.Sprintf("(%d)", idx), -1)
+		return expr
+	}
+	checkEnumDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.EnumDescriptor) {
+		if got != want {
+			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
+		}
+	}
+	checkMessageDesc := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageDescriptor) {
+		if got != want {
+			t.Errorf("descriptor mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), descName(got), descName(want))
+		}
+	}
+	checkMessageType := func(idx int, gotExpr, wantExpr string, got, want protoreflect.MessageType) {
+		if got != want {
+			t.Errorf("type mismatch: %v != %v: %v != %v", adjustExpr(idx, gotExpr), adjustExpr(idx, wantExpr), typeName(got), typeName(want))
+		}
+	}
+
+	fds := mt.Descriptor().Fields()
+	m := mt.New()
+	for i := 0; i < fds.Len(); i++ {
+		fd := fds.Get(i)
+		switch {
+		case fd.IsList():
+			if fd.Enum() != nil {
+				checkEnumDesc(i,
+					"mt.Enum(i).Descriptor()", "fd.Enum()",
+					mt.Enum(i).Descriptor(), fd.Enum())
+			}
+			if fd.Message() != nil {
+				checkMessageDesc(i,
+					"mt.Message(i).Descriptor()", "fd.Message()",
+					mt.Message(i).Descriptor(), fd.Message())
+				checkMessageType(i,
+					"mt.Message(i)", "m.NewField(fd).List().NewElement().Message().Type()",
+					mt.Message(i), m.NewField(fd).List().NewElement().Message().Type())
+			}
+		case fd.IsMap():
+			mti := mt.Message(i)
+			if entry := mti.New(); entry != nil {
+				checkMessageDesc(i,
+					"m.Descriptor()", "fd.Message()",
+					entry.Descriptor(), fd.Message())
+			}
+			if entry := mti.Zero(); entry != nil {
+				checkMessageDesc(i,
+					"m.Descriptor()", "fd.Message()",
+					entry.Descriptor(), fd.Message())
+			}
+			checkMessageDesc(i,
+				"mti.Descriptor()", "fd.Message()",
+				mti.Descriptor(), fd.Message())
+			// MessageFieldTypes is an optional interface. Use the checked
+			// form of the assertion so that a map entry type which does not
+			// implement it is skipped rather than panicking the test.
+			if mti, ok := mti.(pref.MessageFieldTypes); ok {
+				if fd.MapValue().Enum() != nil {
+					checkEnumDesc(i,
+						"mti.Enum(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Enum()",
+						mti.Enum(fd.MapValue().Index()).Descriptor(), fd.MapValue().Enum())
+				}
+				if fd.MapValue().Message() != nil {
+					checkMessageDesc(i,
+						"mti.Message(fd.MapValue().Index()).Descriptor()", "fd.MapValue().Message()",
+						mti.Message(fd.MapValue().Index()).Descriptor(), fd.MapValue().Message())
+					checkMessageType(i,
+						"mti.Message(fd.MapValue().Index())", "m.NewField(fd).Map().NewValue().Message().Type()",
+						mti.Message(fd.MapValue().Index()), m.NewField(fd).Map().NewValue().Message().Type())
+				}
+			}
+		default:
+			if fd.Enum() != nil {
+				checkEnumDesc(i,
+					"mt.Enum(i).Descriptor()", "fd.Enum()",
+					mt.Enum(i).Descriptor(), fd.Enum())
+			}
+			if fd.Message() != nil {
+				checkMessageDesc(i,
+					"mt.Message(i).Descriptor()", "fd.Message()",
+					mt.Message(i).Descriptor(), fd.Message())
+				checkMessageType(i,
+					"mt.Message(i)", "m.NewField(fd).Message().Type()",
+					mt.Message(i), m.NewField(fd).Message().Type())
+			}
+		}
+	}
+}
+
+// testField exercises set/get/has/clear of a field.
+//
+// The field is assigned several seeded values (non-zero, zero, minimum and
+// maximum), checking Has, Get and Range after each assignment. It is then
+// cleared, set to its default, and finally assigned a value of the wrong Go
+// type, which must panic. List and map fields are first exercised by
+// testFieldList / testFieldMap, and float/double fields by testFieldFloat.
+func testField(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
+	name := fd.FullName()
+	num := fd.Number()
+
+	switch {
+	case fd.IsList():
+		testFieldList(t, m, fd)
+	case fd.IsMap():
+		testFieldMap(t, m, fd)
+	case fd.Message() != nil:
+		// Singular message fields: no NewField/default comparison is made.
+	default:
+		if got, want := m.NewField(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("Message.NewField(%v) = %v, want default value %v", name, formatValue(got), formatValue(want))
+		}
+		if fd.Kind() == pref.FloatKind || fd.Kind() == pref.DoubleKind {
+			testFieldFloat(t, m, fd)
+		}
+	}
+
+	// Set to a non-zero value, the zero value, different non-zero values.
+	for _, n := range []seed{1, 0, minVal, maxVal} {
+		v := newValue(m, fd, n, nil)
+		m.Set(fd, v)
+		wantHas := true
+		if n == 0 {
+			// A zero value clears presence for proto3 scalars and for
+			// repeated fields, but singular extensions and oneof members
+			// remain present. The order of these checks matters.
+			if fd.Syntax() == pref.Proto3 && fd.Message() == nil {
+				wantHas = false
+			}
+			if fd.IsExtension() {
+				wantHas = true
+			}
+			if fd.Cardinality() == pref.Repeated {
+				wantHas = false
+			}
+			if fd.ContainingOneof() != nil {
+				wantHas = true
+			}
+		}
+		// A singular proto3 enum outside a oneof holding the number 0 is
+		// not present, whichever seed produced it.
+		if fd.Syntax() == pref.Proto3 && fd.Cardinality() != pref.Repeated && fd.ContainingOneof() == nil && fd.Kind() == pref.EnumKind && v.Enum() == 0 {
+			wantHas = false
+		}
+		if got, want := m.Has(fd), wantHas; got != want {
+			t.Errorf("after setting %q to %v:\nMessage.Has(%v) = %v, want %v", name, formatValue(v), num, got, want)
+		}
+		if got, want := m.Get(fd), v; !valueEqual(got, want) {
+			t.Errorf("after setting %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+		}
+		// Range must visit the field exactly when Has reports it present.
+		found := false
+		m.Range(func(d pref.FieldDescriptor, got pref.Value) bool {
+			if fd != d {
+				return true
+			}
+			found = true
+			if want := v; !valueEqual(got, want) {
+				t.Errorf("after setting %q:\nMessage.Range got value %v, want %v", name, formatValue(got), formatValue(want))
+			}
+			return true
+		})
+		if got, want := wantHas, found; got != want {
+			t.Errorf("after setting %q:\nMessage.Range saw field: %v, want %v", name, got, want)
+		}
+	}
+
+	m.Clear(fd)
+	if got, want := m.Has(fd), false; got != want {
+		t.Errorf("after clearing %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+	}
+	switch {
+	case fd.IsList():
+		if got := m.Get(fd); got.List().Len() != 0 {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
+		}
+	case fd.IsMap():
+		if got := m.Get(fd); got.Map().Len() != 0 {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
+		}
+	case fd.Message() == nil:
+		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+		}
+	}
+
+	// Set to the default value.
+	switch {
+	case fd.IsList() || fd.IsMap():
+		// Storing an empty mutable container is only visible via Has for
+		// singular extensions or oneof members.
+		m.Set(fd, m.Mutable(fd))
+		if got, want := m.Has(fd), (fd.IsExtension() && fd.Cardinality() != pref.Repeated) || fd.ContainingOneof() != nil; got != want {
+			t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+		}
+	case fd.Message() == nil:
+		m.Set(fd, m.Get(fd))
+		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+		}
+	}
+	m.Clear(fd)
+
+	// Set to the wrong type: a string for every kind except string fields,
+	// which get an int32 instead.
+	v := pref.ValueOfString("")
+	if fd.Kind() == pref.StringKind {
+		v = pref.ValueOfInt32(0)
+	}
+	if !panics(func() {
+		m.Set(fd, v)
+	}) {
+		t.Errorf("setting %v to %T succeeds, want panic", name, v.Interface())
+	}
+}
+
+// testFieldMap tests set/get/has/clear of entries in a map field.
+//
+// The map returned by Get on a cleared field must be invalid: storing it
+// with Set or inserting into it must panic. A mutable map is then obtained
+// with Mutable and compared against a reference testMap after each insert,
+// overwrite and removal. Finally, lookups of a missing key and Mutable on
+// map entries (which must panic for non-message values) are checked.
+func testFieldMap(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
+ name := fd.FullName()
+ num := fd.Number()
+
+ // New values.
+ m.Clear(fd) // start with an empty map
+ mapv := m.Get(fd).Map()
+ if mapv.IsValid() {
+ t.Errorf("after clearing field: message.Get(%v).IsValid() = true, want false", name)
+ }
+ if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
+ t.Errorf("message.Get(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
+ }
+ // The invalid map from Get must be rejected both by Set and by direct mutation.
+ if !panics(func() {
+ m.Set(fd, pref.ValueOfMap(mapv))
+ }) {
+ t.Errorf("message.Set(%v, <invalid>) does not panic", name)
+ }
+ if !panics(func() {
+ mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, nil))
+ }) {
+ t.Errorf("message.Get(%v).Set(...) of invalid map does not panic", name)
+ }
+ mapv = m.Mutable(fd).Map() // mutable map
+ if !mapv.IsValid() {
+ t.Errorf("message.Mutable(%v).IsValid() = false, want true", name)
+ }
+ if got, want := mapv.NewValue(), newMapValue(fd, mapv, 0, nil); !valueEqual(got, want) {
+ t.Errorf("message.Mutable(%v).NewValue() = %v, want %v", name, formatValue(got), formatValue(want))
+ }
+
+ // Add values.
+ // want is the reference copy; Has must be false until the first insert.
+ want := make(testMap)
+ for i, n := range []seed{1, 0, minVal, maxVal} {
+ if got, want := m.Has(fd), i > 0; got != want {
+ t.Errorf("after inserting %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
+ }
+
+ k := newMapKey(fd, n)
+ v := newMapValue(fd, mapv, n, nil)
+ mapv.Set(k, v)
+ want.Set(k, v)
+ if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
+ t.Errorf("after inserting %d elements to %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+ }
+ }
+
+ // Set values.
+ // Overwrite every existing entry with the same new value (seed 10).
+ want.Range(func(k pref.MapKey, v pref.Value) bool {
+ nv := newMapValue(fd, mapv, 10, nil)
+ mapv.Set(k, nv)
+ want.Set(k, nv)
+ if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
+ t.Errorf("after setting element %v of %q:\nMessage.Get(%v) = %v, want %v", formatValue(k.Value()), name, num, formatValue(got), formatValue(want))
+ }
+ return true
+ })
+
+ // Clear values.
+ // Entries are removed one at a time; Has stays true until the map is empty.
+ want.Range(func(k pref.MapKey, v pref.Value) bool {
+ mapv.Clear(k)
+ want.Clear(k)
+ if got, want := m.Has(fd), want.Len() > 0; got != want {
+ t.Errorf("after clearing elements of %q:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+ }
+ if got, want := m.Get(fd), pref.ValueOfMap(want); !valueEqual(got, want) {
+ t.Errorf("after clearing elements of %q:\nMessage.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+ }
+ return true
+ })
+ // Once emptied, Get should again return an invalid map. (This mapv
+ // shadows the mutable map only inside the if statement.)
+ if mapv := m.Get(fd).Map(); mapv.IsValid() {
+ t.Errorf("after clearing all elements: message.Get(%v).IsValid() = true, want false %v", name, formatValue(pref.ValueOfMap(mapv)))
+ }
+
+ // Non-existent map keys.
+ // mapv here is still the (now empty) mutable map obtained above.
+ missingKey := newMapKey(fd, 1)
+ if got, want := mapv.Has(missingKey), false; got != want {
+ t.Errorf("non-existent map key in %q: Map.Has(%v) = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+ }
+ if got, want := mapv.Get(missingKey).IsValid(), false; got != want {
+ t.Errorf("non-existent map key in %q: Map.Get(%v).IsValid() = %v, want %v", name, formatValue(missingKey.Value()), got, want)
+ }
+ mapv.Clear(missingKey) // noop
+
+ // Mutable.
+ // Map.Mutable must panic for scalar values; for message values it inserts
+ // an entry whose later modification is visible through the map.
+ if fd.MapValue().Message() == nil {
+ if !panics(func() {
+ mapv.Mutable(newMapKey(fd, 1))
+ }) {
+ t.Errorf("Mutable on %q succeeds, want panic", name)
+ }
+ } else {
+ k := newMapKey(fd, 1)
+ v := mapv.Mutable(k)
+ if got, want := mapv.Len(), 1; got != want {
+ t.Errorf("after Mutable on %q, Map.Len() = %v, want %v", name, got, want)
+ }
+ populateMessage(v.Message(), 1, nil)
+ if !valueEqual(mapv.Get(k), v) {
+ t.Errorf("after Mutable on %q, changing new mutable value does not change map entry", name)
+ }
+ // Leave the field empty again.
+ mapv.Clear(k)
+ }
+}
+
+// testMap is a minimal in-memory pref.Map that records the expected contents
+// of a map field. Entries are keyed by the Go value of the map key.
+// Mutable and NewValue are not needed here and panic if called.
+type testMap map[interface{}]pref.Value
+
+func (tm testMap) Get(k pref.MapKey) pref.Value {
+	return tm[k.Interface()]
+}
+
+func (tm testMap) Set(k pref.MapKey, v pref.Value) {
+	tm[k.Interface()] = v
+}
+
+func (tm testMap) Has(k pref.MapKey) bool {
+	v := tm.Get(k)
+	return v.IsValid()
+}
+
+func (tm testMap) Clear(k pref.MapKey) {
+	delete(tm, k.Interface())
+}
+
+func (tm testMap) Mutable(k pref.MapKey) pref.Value {
+	panic("unimplemented")
+}
+
+func (tm testMap) Len() int {
+	return len(tm)
+}
+
+func (tm testMap) NewValue() pref.Value {
+	panic("unimplemented")
+}
+
+func (tm testMap) Range(f func(pref.MapKey, pref.Value) bool) {
+	for key, val := range tm {
+		if keepGoing := f(pref.ValueOf(key).MapKey(), val); !keepGoing {
+			break
+		}
+	}
+}
+
+func (tm testMap) IsValid() bool {
+	return true
+}
+
+// testFieldList exercises set/get/append/truncate of values in a list.
+//
+// Starting from a cleared field, it checks that the list returned by Get is
+// invalid and can be neither stored nor appended to. It then obtains a
+// mutable list and compares it against a reference testList after each
+// Append, Set and Truncate. Finally it checks AppendMutable, which must panic
+// for non-message element types.
+func testFieldList(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
+	name := fd.FullName()
+	num := fd.Number()
+
+	m.Clear(fd) // start with an empty list
+	list := m.Get(fd).List()
+	if list.IsValid() {
+		t.Errorf("message.Get(%v).IsValid() = true, want false", name)
+	}
+	if !panics(func() {
+		m.Set(fd, pref.ValueOfList(list))
+	}) {
+		t.Errorf("message.Set(%v, <invalid>) does not panic", name)
+	}
+	if !panics(func() {
+		list.Append(newListElement(fd, list, 0, nil))
+	}) {
+		t.Errorf("message.Get(%v).Append(...) of invalid list does not panic", name)
+	}
+	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
+		t.Errorf("message.Get(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
+	}
+	list = m.Mutable(fd).List() // mutable list
+	if !list.IsValid() {
+		t.Errorf("message.Mutable(%v).IsValid() = false, want true", name)
+	}
+	if got, want := list.NewElement(), newListElement(fd, list, 0, nil); !valueEqual(got, want) {
+		t.Errorf("message.Mutable(%v).NewElement() = %v, want %v", name, formatValue(got), formatValue(want))
+	}
+
+	// Append values. want is the reference copy of the expected contents;
+	// Has must be false until the first element is appended.
+	var want pref.List = &testList{}
+	for i, n := range []seed{1, 0, minVal, maxVal} {
+		if got, want := m.Has(fd), i > 0; got != want {
+			t.Errorf("after appending %d elements to %q:\nMessage.Has(%v) = %v, want %v", i, name, num, got, want)
+		}
+		v := newListElement(fd, list, n, nil)
+		want.Append(v)
+		list.Append(v)
+
+		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
+			t.Errorf("after appending %d elements to %q:\nMessage.Get(%v) = %v, want %v", i+1, name, num, formatValue(got), formatValue(want))
+		}
+	}
+
+	// Set values.
+	for i := 0; i < want.Len(); i++ {
+		v := newListElement(fd, list, seed(i+10), nil)
+		want.Set(i, v)
+		list.Set(i, v)
+		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
+			t.Errorf("after setting element %d of %q:\nMessage.Get(%v) = %v, want %v", i, name, num, formatValue(got), formatValue(want))
+		}
+	}
+
+	// Truncate one element at a time; Has stays true until the list is empty.
+	for want.Len() > 0 {
+		n := want.Len() - 1
+		want.Truncate(n)
+		list.Truncate(n)
+		if got, want := m.Has(fd), want.Len() > 0; got != want {
+			t.Errorf("after truncating %q to %d:\nMessage.Has(%v) = %v, want %v", name, n, num, got, want)
+		}
+		if got, want := m.Get(fd), pref.ValueOfList(want); !valueEqual(got, want) {
+			t.Errorf("after truncating %q to %d:\nMessage.Get(%v) = %v, want %v", name, n, num, formatValue(got), formatValue(want))
+		}
+	}
+
+	// AppendMutable.
+	if fd.Message() == nil {
+		if !panics(func() {
+			list.AppendMutable()
+		}) {
+			t.Errorf("AppendMutable on %q succeeds, want panic", name)
+		}
+	} else {
+		v := list.AppendMutable()
+		if got, want := list.Len(), 1; got != want {
+			t.Errorf("after AppendMutable on %q, list.Len() = %v, want %v", name, got, want)
+		}
+		populateMessage(v.Message(), 1, nil)
+		if !valueEqual(list.Get(0), v) {
+			t.Errorf("after AppendMutable on %q, changing new mutable value does not change list item 0", name)
+		}
+		// Remove the element added by AppendMutable so the field is left
+		// empty again; the reference list want is already empty here.
+		list.Truncate(0)
+	}
+}
+
+// testList is a minimal slice-backed pref.List that records the expected
+// contents of a list field. AppendMutable and NewElement are not needed
+// here and panic if called.
+type testList struct {
+	a []pref.Value
+}
+
+func (tl *testList) Len() int {
+	return len(tl.a)
+}
+
+func (tl *testList) Get(i int) pref.Value {
+	return tl.a[i]
+}
+
+func (tl *testList) Set(i int, v pref.Value) {
+	tl.a[i] = v
+}
+
+func (tl *testList) Append(v pref.Value) {
+	tl.a = append(tl.a, v)
+}
+
+func (tl *testList) AppendMutable() pref.Value {
+	panic("unimplemented")
+}
+
+func (tl *testList) Truncate(n int) {
+	tl.a = tl.a[:n]
+}
+
+func (tl *testList) NewElement() pref.Value {
+	panic("unimplemented")
+}
+
+func (tl *testList) IsValid() bool {
+	return true
+}
+
+// testFieldFloat exercises some interesting floating-point scalar field values.
+func testFieldFloat(t testing.TB, m pref.Message, fd pref.FieldDescriptor) {
+ name := fd.FullName()
+ num := fd.Number()
+
+ for _, v := range []float64{math.Inf(-1), math.Inf(1), math.NaN(), math.Copysign(0, -1)} {
+ var val pref.Value
+ if fd.Kind() == pref.FloatKind {
+ val = pref.ValueOfFloat32(float32(v))
+ } else {
+ val = pref.ValueOfFloat64(float64(v))
+ }
+ m.Set(fd, val)
+ // Note that Has is true for -0.
+ if got, want := m.Has(fd), true; got != want {
+ t.Errorf("after setting %v to %v: Message.Has(%v) = %v, want %v", name, v, num, got, want)
+ }
+ if got, want := m.Get(fd), val; !valueEqual(got, want) {
+ t.Errorf("after setting %v: Message.Get(%v) = %v, want %v", name, num, formatValue(got), formatValue(want))
+ }
+ }
+}
+
+// testOneof tests the behavior of fields in a oneof.
+//
+// Each member of od is set in turn, first explicitly with Set and then (for
+// list, map and message members only) by calling Mutable. After every
+// assignment WhichOneof must report that member and Has must be true for it
+// and false for all other members.
+func testOneof(t testing.TB, m pref.Message, od pref.OneofDescriptor) {
+	for _, mutable := range []bool{false, true} {
+		for i := 0; i < od.Fields().Len(); i++ {
+			fda := od.Fields().Get(i)
+			if mutable {
+				// Set fields by requesting a mutable reference; this is only
+				// exercised for map, list and message fields.
+				if !fda.IsMap() && !fda.IsList() && fda.Message() == nil {
+					continue
+				}
+				_ = m.Mutable(fda)
+			} else {
+				// Set fields explicitly.
+				m.Set(fda, newValue(m, fda, 1, nil))
+			}
+			if got, want := m.WhichOneof(od), fda; got != want {
+				t.Errorf("after setting oneof field %q:\nWhichOneof(%q) = %v, want %v", fda.FullName(), od.Name(), got, want)
+			}
+			for j := 0; j < od.Fields().Len(); j++ {
+				fdb := od.Fields().Get(j)
+				if got, want := m.Has(fdb), i == j; got != want {
+					t.Errorf("after setting oneof field %q:\nHas(%q) = %v, want %v", fda.FullName(), fdb.FullName(), got, want)
+				}
+			}
+		}
+	}
+}
+
+// testUnknown tests the behavior of unknown fields.
+func testUnknown(t testing.TB, m pref.Message) {
+ var b []byte
+ b = protowire.AppendTag(b, 1000, protowire.VarintType)
+ b = protowire.AppendVarint(b, 1001)
+ m.SetUnknown(pref.RawFields(b))
+ if got, want := []byte(m.GetUnknown()), b; !bytes.Equal(got, want) {
+ t.Errorf("after setting unknown fields:\nGetUnknown() = %v, want %v", got, want)
+ }
+}
+
+// formatValue renders v for use in test failure messages.
+//
+// Lists and maps are formatted recursively (map entries are sorted by the
+// String form of their keys). Messages are shown as their full name followed
+// by their text-format encoding. Strings are quoted, and everything else
+// uses fmt.Sprint.
+func formatValue(v pref.Value) string {
+	switch x := v.Interface().(type) {
+	case pref.List:
+		elems := make([]string, 0, x.Len())
+		for i := 0; i < x.Len(); i++ {
+			elems = append(elems, formatValue(x.Get(i)))
+		}
+		return "list[" + strings.Join(elems, " ") + "]"
+	case pref.Map:
+		var keys []pref.MapKey
+		x.Range(func(k pref.MapKey, _ pref.Value) bool {
+			keys = append(keys, k)
+			return true
+		})
+		sort.Slice(keys, func(i, j int) bool {
+			return keys[i].String() < keys[j].String()
+		})
+		entries := make([]string, 0, len(keys))
+		for _, k := range keys {
+			entries = append(entries, formatValue(k.Value())+":"+formatValue(x.Get(k)))
+		}
+		return "map[" + strings.Join(entries, " ") + "]"
+	case pref.Message:
+		b, err := prototext.Marshal(x.Interface())
+		if err != nil {
+			return fmt.Sprintf("<%v>", err)
+		}
+		return fmt.Sprintf("%v{%v}", x.Descriptor().FullName(), string(b))
+	case string:
+		return fmt.Sprintf("%q", x)
+	default:
+		return fmt.Sprint(x)
+	}
+}
+
+// valueEqual reports whether a and b hold equal values.
+//
+// Messages are compared with proto.Equal, lists and maps element by element,
+// and byte slices by content. Floating-point values are compared by their
+// bit patterns, so identical NaNs are equal while -0 and +0 are not.
+// b is assumed to hold the same kind of value as a.
+func valueEqual(a, b pref.Value) bool {
+	ai := a.Interface()
+	bi := b.Interface()
+	switch x := ai.(type) {
+	case pref.Message:
+		return proto.Equal(x.Interface(), b.Message().Interface())
+	case pref.List:
+		lb := b.List()
+		if x.Len() != lb.Len() {
+			return false
+		}
+		for i := 0; i < x.Len(); i++ {
+			if !valueEqual(x.Get(i), lb.Get(i)) {
+				return false
+			}
+		}
+		return true
+	case pref.Map:
+		mb := b.Map()
+		if x.Len() != mb.Len() {
+			return false
+		}
+		same := true
+		x.Range(func(k pref.MapKey, va pref.Value) bool {
+			same = valueEqual(va, mb.Get(k))
+			return same
+		})
+		return same
+	case []byte:
+		return bytes.Equal(x, b.Bytes())
+	case float32:
+		// NaNs are equal, but must be the same NaN.
+		return math.Float32bits(x) == math.Float32bits(bi.(float32))
+	case float64:
+		// NaNs are equal, but must be the same NaN.
+		return math.Float64bits(x) == math.Float64bits(bi.(float64))
+	default:
+		return ai == bi
+	}
+}
+
+// A seed is used to vary the content of a value.
+//
+// A seed of 0 is the zero value. Messages do not have a zero value; a
+// 0-seeded message is unpopulated.
+//
+// A seed of minVal or maxVal requests an extreme value of the value type.
+// newScalarValue decides the exact value for each kind; for example,
+// unsigned integers use 1 rather than 0 for minVal, so that 0 only ever
+// appears as the zero value. Both sentinels are negative, which keeps them
+// distinct from the non-negative seeds normally derived by newSeed.
+type seed int
+
+const (
+	minVal seed = -1 // requests a low extreme value
+	maxVal seed = -2 // requests a high extreme value
+)
+
+// newSeed derives a new seed from the base seed n, for example to create
+// seeds for the elements of a list or the fields of a message. Each
+// adjustment a is folded in as n = 10*n + a, in order. If n is minVal or
+// maxVal it is returned unchanged, so extreme values propagate.
+func newSeed(n seed, adjust ...int) seed {
+	if n == minVal || n == maxVal {
+		return n
+	}
+	out := n
+	for i := range adjust {
+		out = out*10 + seed(adjust[i])
+	}
+	return out
+}
+
+// newValue returns a new value assignable to field fd of message m, varied
+// by seed n.
+//
+// For list and map fields a 0 seed yields an empty mutable container.
+// Otherwise a list receives four elements seeded with 0, minVal, maxVal and
+// n, and a map receives entries keyed by those same seeds. Message fields
+// are filled in by populateMessage, and scalars come from newScalarValue.
+//
+// The stack parameter is used to avoid infinite recursion when populating
+// circular data structures.
+func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
+	if fd.IsList() {
+		if n == 0 {
+			return m.New().Mutable(fd)
+		}
+		list := m.NewField(fd).List()
+		for _, s := range []seed{0, minVal, maxVal, n} {
+			list.Append(newListElement(fd, list, s, stack))
+		}
+		return pref.ValueOfList(list)
+	}
+	if fd.IsMap() {
+		if n == 0 {
+			return m.New().Mutable(fd)
+		}
+		mapv := m.NewField(fd).Map()
+		for _, s := range []seed{0, minVal, maxVal} {
+			mapv.Set(newMapKey(fd, s), newMapValue(fd, mapv, s, stack))
+		}
+		// The entry keyed by n gets a value derived from n rather than n itself.
+		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
+		return pref.ValueOfMap(mapv)
+	}
+	if fd.Message() != nil {
+		return populateMessage(m.NewField(fd).Message(), n, stack)
+	}
+	return newScalarValue(fd, n)
+}
+
+// newListElement returns a new element for list, the value of list field fd,
+// varied by seed n. Message elements are created with list.NewElement.
+func newListElement(fd pref.FieldDescriptor, list pref.List, n seed, stack []pref.MessageDescriptor) pref.Value {
+	if fd.Message() != nil {
+		return populateMessage(list.NewElement().Message(), n, stack)
+	}
+	return newScalarValue(fd, n)
+}
+
+// newMapKey returns a key for map field fd, varied by seed n.
+func newMapKey(fd pref.FieldDescriptor, n seed) pref.MapKey {
+	return newScalarValue(fd.MapKey(), n).MapKey()
+}
+
+// newMapValue returns a value for mapv, the value of map field fd, varied by
+// seed n. Message values are created with mapv.NewValue.
+func newMapValue(fd pref.FieldDescriptor, mapv pref.Map, n seed, stack []pref.MessageDescriptor) pref.Value {
+	if vd := fd.MapValue(); vd.Message() == nil {
+		return newScalarValue(vd, n)
+	}
+	return populateMessage(mapv.NewValue().Message(), n, stack)
+}
+
+// newScalarValue returns a scalar value of fd's kind, varied by seed n.
+//
+// A 0 seed yields the kind's zero value (for enums, the first declared
+// value). minVal and maxVal select extreme values:
+//   - signed integers: the minimum and the maximum of the type;
+//   - unsigned integers: 1 (so that 0 only appears as the zero value) and
+//     the maximum of the type;
+//   - floating point: the smallest positive value and the largest finite value;
+//   - enums: the first and last declared values.
+//
+// For bool, string and bytes kinds, minVal and maxVal are treated like any
+// other non-zero seed. It panics on an unhandled kind.
+func newScalarValue(fd pref.FieldDescriptor, n seed) pref.Value {
+	switch fd.Kind() {
+	case pref.BoolKind:
+		return pref.ValueOfBool(n != 0)
+	case pref.EnumKind:
+		vals := fd.Enum().Values()
+		var i int
+		switch n {
+		case minVal:
+			i = 0
+		case maxVal:
+			i = vals.Len() - 1
+		default:
+			i = int(n) % vals.Len()
+			if i < 0 {
+				// A seed derived by newSeed can overflow to a negative
+				// number; keep the index in range instead of panicking.
+				i += vals.Len()
+			}
+		}
+		return pref.ValueOfEnum(vals.Get(i).Number())
+	case pref.Int32Kind, pref.Sint32Kind, pref.Sfixed32Kind:
+		switch n {
+		case minVal:
+			return pref.ValueOfInt32(math.MinInt32)
+		case maxVal:
+			return pref.ValueOfInt32(math.MaxInt32)
+		default:
+			return pref.ValueOfInt32(int32(n))
+		}
+	case pref.Uint32Kind, pref.Fixed32Kind:
+		switch n {
+		case minVal:
+			// Only use 0 for the zero value.
+			return pref.ValueOfUint32(1)
+		case maxVal:
+			return pref.ValueOfUint32(math.MaxUint32)
+		default:
+			return pref.ValueOfUint32(uint32(n))
+		}
+	case pref.Int64Kind, pref.Sint64Kind, pref.Sfixed64Kind:
+		switch n {
+		case minVal:
+			return pref.ValueOfInt64(math.MinInt64)
+		case maxVal:
+			return pref.ValueOfInt64(math.MaxInt64)
+		default:
+			return pref.ValueOfInt64(int64(n))
+		}
+	case pref.Uint64Kind, pref.Fixed64Kind:
+		switch n {
+		case minVal:
+			// Only use 0 for the zero value.
+			return pref.ValueOfUint64(1)
+		case maxVal:
+			return pref.ValueOfUint64(math.MaxUint64)
+		default:
+			return pref.ValueOfUint64(uint64(n))
+		}
+	case pref.FloatKind:
+		switch n {
+		case minVal:
+			return pref.ValueOfFloat32(math.SmallestNonzeroFloat32)
+		case maxVal:
+			return pref.ValueOfFloat32(math.MaxFloat32)
+		default:
+			return pref.ValueOfFloat32(1.5 * float32(n))
+		}
+	case pref.DoubleKind:
+		switch n {
+		case minVal:
+			return pref.ValueOfFloat64(math.SmallestNonzeroFloat64)
+		case maxVal:
+			return pref.ValueOfFloat64(math.MaxFloat64)
+		default:
+			return pref.ValueOfFloat64(1.5 * float64(n))
+		}
+	case pref.StringKind:
+		if n == 0 {
+			return pref.ValueOfString("")
+		}
+		return pref.ValueOfString(fmt.Sprintf("%d", n))
+	case pref.BytesKind:
+		if n == 0 {
+			return pref.ValueOfBytes(nil)
+		}
+		return pref.ValueOfBytes([]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
+	}
+	panic("unhandled kind")
+}
+
+// populateMessage sets every non-weak field of m to a value derived from
+// seed n (field i uses newSeed(n, i)) and returns m wrapped in a Value.
+//
+// A 0 seed leaves m unpopulated. If m's descriptor already appears in
+// stack, m is also left unpopulated; this breaks recursion through
+// circular message types.
+func populateMessage(m pref.Message, n seed, stack []pref.MessageDescriptor) pref.Value {
+	result := pref.ValueOfMessage(m)
+	if n == 0 {
+		return result
+	}
+	md := m.Descriptor()
+	for _, seen := range stack {
+		if seen == md {
+			return result
+		}
+	}
+	stack = append(stack, md)
+	fields := md.Fields()
+	for i := 0; i < fields.Len(); i++ {
+		if fd := fields.Get(i); !fd.IsWeak() {
+			m.Set(fd, newValue(m, fd, newSeed(n, i), stack))
+		}
+	}
+	return result
+}
+
+// panics reports whether calling f panics.
+//
+// The answer is based on whether f returns normally, not on the recovered
+// value. This means a panic with a nil value is still detected: before
+// Go 1.21, recover returns nil for panic(nil).
+func panics(f func()) (didPanic bool) {
+	// Assume a panic; the normal return below overwrites this with false.
+	didPanic = true
+	defer func() {
+		// Stop the panic, if any. didPanic already holds the answer.
+		_ = recover()
+	}()
+	f()
+	return false
+}