types/dynamicpb: support dynamic extensions
Add a dynamicpb.NewExtensionType function to permit creating extension
types from descriptors.
Also fix some bugs around extension field handling:
- When creating a new value for an extension field, use the
  ExtensionType's Zero or New method to create the value.
- Ensure that prototest exercises true zero values of fields (i.e.,
  getting a list, map, or message from an empty message rather than
  creating a new empty one with NewField).
Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index fbfe7c6..20d20c9 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -29,6 +29,12 @@
//
// If nil, TestMessage will look for extension types in the global registry.
ExtensionTypes []pref.ExtensionType
+
+ // Resolver is used for looking up types when unmarshaling extension fields.
+ // If nil, this defaults to using protoregistry.GlobalTypes.
+ Resolver interface {
+ preg.ExtensionTypeResolver
+ }
}
// TestMessage runs the provided m through a series of tests
@@ -57,12 +63,20 @@
// Test round-trip marshal/unmarshal.
m2 := m.ProtoReflect().New().Interface()
populateMessage(m2.ProtoReflect(), 1, nil)
- b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
+ for _, xt := range opts.ExtensionTypes {
+ m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
+ }
+ b, err := proto.MarshalOptions{
+ AllowPartial: true,
+ }.Marshal(m2)
if err != nil {
t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
}
m3 := m.ProtoReflect().New().Interface()
- if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
+ if err := (proto.UnmarshalOptions{
+ AllowPartial: true,
+ Resolver: opts.Resolver,
+ }.Unmarshal(b, m3)); err != nil {
t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
}
if !proto.Equal(m2, m3) {
@@ -150,7 +164,7 @@
}
case fd.IsMap():
if got := m.Get(fd); got.Map().Len() != 0 {
- t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
+ t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
}
case fd.Message() == nil:
if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
@@ -158,6 +172,21 @@
}
}
+ // Set to the default value.
+ switch {
+ case fd.IsList() || fd.IsMap():
+ m.Set(fd, m.Get(fd))
+ if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
+ t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+ }
+ case fd.Message() == nil:
+ m.Set(fd, m.Get(fd))
+ if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+ t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+ }
+ }
+ m.Clear(fd)
+
// Set to the wrong type.
v := pref.ValueOf("")
if fd.Kind() == pref.StringKind {
@@ -508,26 +537,29 @@
func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
switch {
case fd.IsList():
- list := m.NewField(fd).List()
if n == 0 {
- return pref.ValueOf(list)
+ return m.New().Get(fd)
}
+ list := m.NewField(fd).List()
list.Append(newListElement(fd, list, 0, stack))
list.Append(newListElement(fd, list, minVal, stack))
list.Append(newListElement(fd, list, maxVal, stack))
list.Append(newListElement(fd, list, n, stack))
return pref.ValueOf(list)
case fd.IsMap():
- mapv := m.NewField(fd).Map()
if n == 0 {
- return pref.ValueOf(mapv)
+ return m.New().Get(fd)
}
+ mapv := m.NewField(fd).Map()
mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
return pref.ValueOf(mapv)
case fd.Message() != nil:
+ //if n == 0 {
+ // return m.New().Get(fd)
+ //}
return populateMessage(m.NewField(fd).Message(), n, stack)
default:
return newScalarValue(fd, n)