types/dynamicpb: support dynamic extensions

Add a dynamicpb.NewExtensionType function to permit creating extension
types from descriptors.

Also fix some bugs around extension field handling:
When creating a new value for an extension field, use the
ExtensionType's Zero or New method to create the value.

Ensure that prototest exercises true zero values of fields (i.e.,
getting a list, map, or message from an empty message rather than
creating a new empty one with NewField).

Change-Id: Idb8e87cdc92692610e12a4b8a68c34b129fae617
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/186180
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go
index fbfe7c6..20d20c9 100644
--- a/testing/prototest/prototest.go
+++ b/testing/prototest/prototest.go
@@ -29,6 +29,12 @@
 	//
 	// If nil, TestMessage will look for extension types in the global registry.
 	ExtensionTypes []pref.ExtensionType
+
+	// Resolver is used for looking up types when unmarshaling extension fields.
+	// If nil, this defaults to using protoregistry.GlobalTypes.
+	Resolver interface {
+		preg.ExtensionTypeResolver
+	}
 }
 
 // TestMessage runs the provided m through a series of tests
@@ -57,12 +63,20 @@
 	// Test round-trip marshal/unmarshal.
 	m2 := m.ProtoReflect().New().Interface()
 	populateMessage(m2.ProtoReflect(), 1, nil)
-	b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2)
+	for _, xt := range opts.ExtensionTypes {
+		m2.ProtoReflect().Set(xt.TypeDescriptor(), newValue(m2.ProtoReflect(), xt.TypeDescriptor(), 1, nil))
+	}
+	b, err := proto.MarshalOptions{
+		AllowPartial: true,
+	}.Marshal(m2)
 	if err != nil {
 		t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
 	m3 := m.ProtoReflect().New().Interface()
-	if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil {
+	if err := (proto.UnmarshalOptions{
+		AllowPartial: true,
+		Resolver:     opts.Resolver,
+	}.Unmarshal(b, m3)); err != nil {
 		t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2))
 	}
 	if !proto.Equal(m2, m3) {
@@ -150,7 +164,7 @@
 		}
 	case fd.IsMap():
 		if got := m.Get(fd); got.Map().Len() != 0 {
-			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty list", name, num, formatValue(got))
+			t.Errorf("after clearing %q:\nMessage.Get(%v) = %v, want empty map", name, num, formatValue(got))
 		}
 	case fd.Message() == nil:
 		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
@@ -158,6 +172,21 @@
 		}
 	}
 
+	// Set to the default value.
+	switch {
+	case fd.IsList() || fd.IsMap():
+		m.Set(fd, m.Get(fd))
+		if got, want := m.Has(fd), fd.IsExtension() || fd.ContainingOneof() != nil; got != want {
+			t.Errorf("after setting %q to default:\nMessage.Has(%v) = %v, want %v", name, num, got, want)
+		}
+	case fd.Message() == nil:
+		m.Set(fd, m.Get(fd))
+		if got, want := m.Get(fd), fd.Default(); !valueEqual(got, want) {
+			t.Errorf("after setting %q to default:\nMessage.Get(%v) = %v, want default %v", name, num, formatValue(got), formatValue(want))
+		}
+	}
+	m.Clear(fd)
+
 	// Set to the wrong type.
 	v := pref.ValueOf("")
 	if fd.Kind() == pref.StringKind {
@@ -508,26 +537,29 @@
 func newValue(m pref.Message, fd pref.FieldDescriptor, n seed, stack []pref.MessageDescriptor) pref.Value {
 	switch {
 	case fd.IsList():
-		list := m.NewField(fd).List()
 		if n == 0 {
-			return pref.ValueOf(list)
+			return m.New().Get(fd)
 		}
+		list := m.NewField(fd).List()
 		list.Append(newListElement(fd, list, 0, stack))
 		list.Append(newListElement(fd, list, minVal, stack))
 		list.Append(newListElement(fd, list, maxVal, stack))
 		list.Append(newListElement(fd, list, n, stack))
 		return pref.ValueOf(list)
 	case fd.IsMap():
-		mapv := m.NewField(fd).Map()
 		if n == 0 {
-			return pref.ValueOf(mapv)
+			return m.New().Get(fd)
 		}
+		mapv := m.NewField(fd).Map()
 		mapv.Set(newMapKey(fd, 0), newMapValue(fd, mapv, 0, stack))
 		mapv.Set(newMapKey(fd, minVal), newMapValue(fd, mapv, minVal, stack))
 		mapv.Set(newMapKey(fd, maxVal), newMapValue(fd, mapv, maxVal, stack))
 		mapv.Set(newMapKey(fd, n), newMapValue(fd, mapv, newSeed(n, 0), stack))
 		return pref.ValueOf(mapv)
 	case fd.Message() != nil:
+		//if n == 0 {
+		//	return m.New().Get(fd)
+		//}
 		return populateMessage(m.NewField(fd).Message(), n, stack)
 	default:
 		return newScalarValue(fd, n)