all: refactor extensions, add proto.GetExtension etc.

Change protoiface.ExtensionDescV1 to implement protoreflect.ExtensionType.

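ExtensionDescV1's Name field conflicts with the Descriptor Name method
(Go does not allow a type to have a field and a method with the same
name), so change the protoreflect.{Message,Enum,Extension}Type types to
no longer implement the corresponding Descriptor interface; the
descriptor is now obtained through the Descriptor method instead. This
also draws a clearer distinction between types and descriptors.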

Introduce a protoreflect.ExtensionTypeDescriptor type which bridges
between ExtensionType and ExtensionDescriptor: it is an
ExtensionDescriptor whose Type method returns the ExtensionType.

Add extension accessor functions to the proto package:
proto.{Has,Clear,Get,Set}Extension. These functions take a
protoreflect.ExtensionType parameter. Since ExtensionDescV1 now
implements ExtensionType, the same call can be written with extensions
generated by either the old or the new API:

  proto.GetExtension(message, somepb.E_ExtensionFoo)

Fixes golang/protobuf#908

Change-Id: Ibc65d12a46666297849114fd3aefbc4a597d9f08
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189199
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index 3a23bb0..7e430ee 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -29,25 +29,26 @@
 		return e
 	}
 
+	xd := xt.Descriptor()
 	var wiretag uint64
-	if !xt.IsPacked() {
-		wiretag = wire.EncodeTag(xt.Number(), wireTypes[xt.Kind()])
+	if !xd.IsPacked() {
+		wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
 	} else {
-		wiretag = wire.EncodeTag(xt.Number(), wire.BytesType)
+		wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
 	}
 	e = &extensionFieldInfo{
 		wiretag: wiretag,
 		tagsize: wire.SizeVarint(wiretag),
-		funcs:   encoderFuncsForValue(xt, xt.GoType()),
+		funcs:   encoderFuncsForValue(xd, xt.GoType()),
 	}
 	// Does the unmarshal function need a value passed to it?
 	// This is true for composite types, where we pass in a message, list, or map to fill in,
 	// and for enums, where we pass in a prototype value to specify the concrete enum type.
-	switch xt.Kind() {
+	switch xd.Kind() {
 	case pref.MessageKind, pref.GroupKind, pref.EnumKind:
 		e.unmarshalNeedsValue = true
 	default:
-		if xt.Cardinality() == pref.Repeated {
+		if xd.Cardinality() == pref.Repeated {
 			e.unmarshalNeedsValue = true
 		}
 	}
diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go
index e43c812..7014861 100644
--- a/internal/impl/codec_message.go
+++ b/internal/impl/codec_message.go
@@ -44,8 +44,9 @@
 	mi.extensionOffset = si.extensionOffset
 
 	mi.coderFields = make(map[wire.Number]*coderFieldInfo)
-	for i := 0; i < mi.PBType.Fields().Len(); i++ {
-		fd := mi.PBType.Fields().Get(i)
+	fields := mi.PBType.Descriptor().Fields()
+	for i := 0; i < fields.Len(); i++ {
+		fd := fields.Get(i)
 
 		fs := si.fieldsByNumber[fd.Number()]
 		if fd.ContainingOneof() != nil {
@@ -81,7 +82,7 @@
 	}
 	if messageset.IsMessageSet(mi.PBType.Descriptor()) {
 		if !mi.extensionOffset.IsValid() {
-			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName()))
+			panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.Descriptor().FullName()))
 		}
 		cf := &coderFieldInfo{
 			num:       messageset.FieldItem,
@@ -113,7 +114,7 @@
 		mi.denseCoderFields[cf.num] = cf
 	}
 
-	mi.needsInitCheck = needsInitCheck(mi.PBType)
+	mi.needsInitCheck = needsInitCheck(mi.PBType.Descriptor())
 	mi.methods = piface.Methods{
 		Flags:         piface.SupportMarshalDeterministic | piface.SupportUnmarshalDiscardUnknown,
 		MarshalAppend: mi.marshalAppend,
diff --git a/internal/impl/decode.go b/internal/impl/decode.go
index 4e4c7f3..4821852 100644
--- a/internal/impl/decode.go
+++ b/internal/impl/decode.go
@@ -138,7 +138,7 @@
 	xt := x.GetType()
 	if xt == nil {
 		var err error
-		xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num)
+		xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.Descriptor().FullName(), num)
 		if err != nil {
 			if err == preg.NotFound {
 				return 0, errUnknown
diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go
index ca00012..079afe0 100644
--- a/internal/impl/isinit.go
+++ b/internal/impl/isinit.go
@@ -29,7 +29,7 @@
 	if p.IsNil() {
 		for _, f := range mi.orderedCoderFields {
 			if f.isRequired {
-				return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+				return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
 			}
 		}
 		return nil
@@ -47,7 +47,7 @@
 		fptr := p.Apply(f.offset)
 		if f.isPointer && fptr.Elem().IsNil() {
 			if f.isRequired {
-				return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName()))
+				return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName()))
 			}
 			continue
 		}
diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go
index aaf8fcc..2da4d71 100644
--- a/internal/impl/legacy_extension.go
+++ b/internal/impl/legacy_extension.go
@@ -5,11 +5,9 @@
 package impl
 
 import (
-	"fmt"
 	"reflect"
 	"sync"
 
-	"google.golang.org/protobuf/internal/descfmt"
 	ptag "google.golang.org/protobuf/internal/encoding/tag"
 	"google.golang.org/protobuf/internal/filedesc"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
@@ -62,8 +60,9 @@
 	}
 
 	// Determine the parent type if possible.
+	xd := xt.Descriptor()
 	var parent piface.MessageV1
-	messageName := xt.ContainingMessage().FullName()
+	messageName := xd.ContainingMessage().FullName()
 	if mt, _ := preg.GlobalTypes.FindMessageByName(messageName); mt != nil {
 		// Create a new parent message and unwrap it if possible.
 		mv := mt.New().Interface()
@@ -94,7 +93,7 @@
 	// Reconstruct the legacy enum full name, which is an odd mixture of the
 	// proto package name with the Go type name.
 	var enumName string
-	if xt.Kind() == pref.EnumKind {
+	if xd.Kind() == pref.EnumKind {
 		// Derive Go type name.
 		t := extType
 		if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice {
@@ -105,7 +104,7 @@
 		// Derive the proto package name.
 		// For legacy enums, obtain the proto package from the raw descriptor.
 		var protoPkg string
-		if fd := xt.Enum().ParentFile(); fd != nil {
+		if fd := xd.Enum().ParentFile(); fd != nil {
 			protoPkg = string(fd.Package())
 		}
 		if ed, ok := reflect.Zero(t).Interface().(enumV1); ok && protoPkg == "" {
@@ -120,7 +119,7 @@
 
 	// Derive the proto file that the extension was declared within.
 	var filename string
-	if fd := xt.ParentFile(); fd != nil {
+	if fd := xd.ParentFile(); fd != nil {
 		filename = fd.Path()
 	}
 
@@ -129,9 +128,9 @@
 		Type:          xt,
 		ExtendedType:  parent,
 		ExtensionType: reflect.Zero(extType).Interface(),
-		Field:         int32(xt.Number()),
-		Name:          string(xt.FullName()),
-		Tag:           ptag.Marshal(xt, enumName),
+		Field:         int32(xd.Number()),
+		Name:          string(xd.FullName()),
+		Tag:           ptag.Marshal(xd, enumName),
 		Filename:      filename,
 	}
 	if d, ok := legacyExtensionDescCache.LoadOrStore(xt, d); ok {
@@ -217,15 +216,16 @@
 //
 // This is exported for testing purposes.
 func LegacyExtensionTypeOf(xd pref.ExtensionDescriptor, t reflect.Type) pref.ExtensionType {
-	return &legacyExtensionType{
-		ExtensionDescriptor: xd,
-		typ:                 t,
-		conv:                NewConverter(t, xd),
+	xt := &legacyExtensionType{
+		typ:  t,
+		conv: NewConverter(t, xd),
 	}
+	xt.desc = &extDesc{xd, xt}
+	return xt
 }
 
 type legacyExtensionType struct {
-	pref.ExtensionDescriptor
+	desc pref.ExtensionTypeDescriptor
 	typ  reflect.Type
 	conv Converter
 }
@@ -239,5 +239,17 @@
 func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} {
 	return x.conv.GoValueOf(v).Interface()
 }
-func (x *legacyExtensionType) Descriptor() pref.ExtensionDescriptor { return x.ExtensionDescriptor }
-func (x *legacyExtensionType) Format(s fmt.State, r rune)           { descfmt.FormatDesc(s, r, x) }
+func (x *legacyExtensionType) Descriptor() pref.ExtensionTypeDescriptor { return x.desc }
+
+// extDesc implements pref.ExtensionTypeDescriptor for a legacyExtensionType
+// by embedding the underlying ExtensionDescriptor and adding a Type method.
+type extDesc struct {
+	pref.ExtensionDescriptor
+
+	// xt is the legacyExtensionType whose desc field points back to this
+	// extDesc (see LegacyExtensionTypeOf), so Type returns the owning type.
+	xt *legacyExtensionType
+}
+
+func (t *extDesc) Type() pref.ExtensionType             { return t.xt }
+func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor }
diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go
index 70c5603..89cd0bc 100644
--- a/internal/impl/legacy_test.go
+++ b/internal/impl/legacy_test.go
@@ -52,7 +52,7 @@
 
 func init() {
 	mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil))
-	preg.GlobalFiles.Register(mt.ParentFile())
+	preg.GlobalFiles.Register(mt.Descriptor().ParentFile())
 	preg.GlobalTypes.Register(mt)
 }
 
@@ -357,19 +357,21 @@
 	}
 	for i, xt := range extensionTypes {
 		var got interface{}
-		if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) {
-			got = xt.InterfaceOf(m.Get(xt))
+		xd := xt.Descriptor()
+		if !(xd.IsList() || xd.IsMap() || xd.Message() != nil) {
+			got = xt.InterfaceOf(m.Get(xd))
 		}
 		want := defaultValues[i]
 		if diff := cmp.Diff(want, got, opts); diff != "" {
-			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
 		}
 	}
 
 	// All fields should be unpopulated.
 	for _, xt := range extensionTypes {
-		if m.Has(xt) {
-			t.Errorf("Message.Has(%d) = true, want false", xt.Number())
+		xd := xt.Descriptor()
+		if m.Has(xd) {
+			t.Errorf("Message.Has(%d) = true, want false", xd.Number())
 		}
 	}
 
@@ -401,11 +403,11 @@
 		19: &[]*EnumMessages{m2b},
 	}
 	for i, xt := range extensionTypes {
-		m.Set(xt, xt.ValueOf(setValues[i]))
+		m.Set(xt.Descriptor(), xt.ValueOf(setValues[i]))
 	}
 	for i, xt := range extensionTypes[len(extensionTypes)/2:] {
 		v := extensionTypes[i].ValueOf(setValues[i])
-		m.Get(xt).List().Append(v)
+		m.Get(xt.Descriptor()).List().Append(v)
 	}
 
 	// Get the values and check for equality.
@@ -432,24 +434,25 @@
 		19: &[]*EnumMessages{m2b, m2a},
 	}
 	for i, xt := range extensionTypes {
-		got := xt.InterfaceOf(m.Get(xt))
+		xd := xt.Descriptor()
+		got := xt.InterfaceOf(m.Get(xd))
 		want := getValues[i]
 		if diff := cmp.Diff(want, got, opts); diff != "" {
-			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff)
+			t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff)
 		}
 	}
 
 	// Clear all singular fields and truncate all repeated fields.
 	for _, xt := range extensionTypes[:len(extensionTypes)/2] {
-		m.Clear(xt)
+		m.Clear(xt.Descriptor())
 	}
 	for _, xt := range extensionTypes[len(extensionTypes)/2:] {
-		m.Get(xt).List().Truncate(0)
+		m.Get(xt.Descriptor()).List().Truncate(0)
 	}
 
 	// Clear all repeated fields.
 	for _, xt := range extensionTypes[len(extensionTypes)/2:] {
-		m.Clear(xt)
+		m.Clear(xt.Descriptor())
 	}
 }
 
@@ -491,8 +494,6 @@
 							switch name {
 							case "ParentFile", "Parent":
 							// Ignore parents to avoid recursive cycle.
-							case "New", "Zero":
-								// Ignore constructors.
 							case "Options":
 								// Ignore descriptor options since protos are not cmperable.
 							case "ContainingOneof", "ContainingMessage", "Enum", "Message":
@@ -504,6 +505,8 @@
 								if !v.IsNil() {
 									out[name] = v.Interface().(pref.Descriptor).FullName()
 								}
+							case "Type":
+								// Ignore ExtensionTypeDescriptor.Type method to avoid cycle.
 							default:
 								out[name] = m.Call(nil)[0].Interface()
 							}
@@ -511,6 +514,12 @@
 					}
 					return out
 				}),
+				cmp.Transformer("", func(xt pref.ExtensionType) map[string]interface{} {
+					return map[string]interface{}{
+						"Descriptor": xt.Descriptor(),
+						"GoType":     xt.GoType(),
+					}
+				}),
 				cmp.Transformer("", func(v pref.Value) interface{} {
 					return v.Interface()
 				}),
@@ -605,23 +614,23 @@
 
 	var (
 		wantMTA = messageATypes[0]
-		wantMDA = messageATypes[0].Fields().ByNumber(1).Message()
+		wantMDA = messageATypes[0].Descriptor().Fields().ByNumber(1).Message()
 		wantMTB = messageBTypes[0]
-		wantMDB = messageBTypes[0].Fields().ByNumber(2).Message()
-		wantED  = messageATypes[0].Fields().ByNumber(3).Enum()
+		wantMDB = messageBTypes[0].Descriptor().Fields().ByNumber(2).Message()
+		wantED  = messageATypes[0].Descriptor().Fields().ByNumber(3).Enum()
 	)
 
 	for _, gotMT := range messageATypes[1:] {
 		if gotMT != wantMTA {
 			t.Error("MessageType(MessageA) mismatch")
 		}
-		if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+		if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
 			t.Error("MessageDescriptor(MessageA) mismatch")
 		}
-		if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+		if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
 			t.Error("MessageDescriptor(MessageB) mismatch")
 		}
-		if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+		if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
 			t.Error("EnumDescriptor(Enum) mismatch")
 		}
 	}
@@ -629,13 +638,13 @@
 		if gotMT != wantMTB {
 			t.Error("MessageType(MessageB) mismatch")
 		}
-		if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA {
+		if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA {
 			t.Error("MessageDescriptor(MessageA) mismatch")
 		}
-		if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB {
+		if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB {
 			t.Error("MessageDescriptor(MessageB) mismatch")
 		}
-		if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED {
+		if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED {
 			t.Error("EnumDescriptor(Enum) mismatch")
 		}
 	}
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 305e17d..6100663 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -222,8 +222,9 @@
 // any discrepancies.
 func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) {
 	mi.fields = map[pref.FieldNumber]*fieldInfo{}
-	for i := 0; i < mi.PBType.Fields().Len(); i++ {
-		fd := mi.PBType.Fields().Get(i)
+	md := mi.PBType.Descriptor()
+	for i := 0; i < md.Fields().Len(); i++ {
+		fd := md.Fields().Get(i)
 		fs := si.fieldsByNumber[fd.Number()]
 		var fi fieldInfo
 		switch {
@@ -244,8 +245,8 @@
 	}
 
 	mi.oneofs = map[pref.Name]*oneofInfo{}
-	for i := 0; i < mi.PBType.Oneofs().Len(); i++ {
-		od := mi.PBType.Oneofs().Get(i)
+	for i := 0; i < md.Oneofs().Len(); i++ {
+		od := md.Oneofs().Get(i)
 		mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType)
 	}
 }
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
index 699ac2c..b0f1778 100644
--- a/internal/impl/message_reflect.go
+++ b/internal/impl/message_reflect.go
@@ -121,7 +121,7 @@
 	if m != nil {
 		for _, x := range *m {
 			xt := x.GetType()
-			if !f(xt, xt.ValueOf(x.GetValue())) {
+			if !f(xt.Descriptor(), xt.ValueOf(x.GetValue())) {
 				return
 			}
 		}
@@ -129,16 +129,17 @@
 }
 func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
 	if m != nil {
-		_, ok = (*m)[int32(xt.Number())]
+		_, ok = (*m)[int32(xt.Descriptor().Number())]
 	}
 	return ok
 }
 func (m *extensionMap) Clear(xt pref.ExtensionType) {
-	delete(*m, int32(xt.Number()))
+	delete(*m, int32(xt.Descriptor().Number()))
 }
 func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
+	xd := xt.Descriptor()
 	if m != nil {
-		if x, ok := (*m)[int32(xt.Number())]; ok {
+		if x, ok := (*m)[int32(xd.Number())]; ok {
 			return xt.ValueOf(x.GetValue())
 		}
 	}
@@ -151,13 +152,14 @@
 	var x ExtensionField
 	x.SetType(xt)
 	x.SetEagerValue(xt.InterfaceOf(v))
-	(*m)[int32(xt.Number())] = x
+	(*m)[int32(xt.Descriptor().Number())] = x
 }
 func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
-	if !isComposite(xt) {
+	xd := xt.Descriptor()
+	if !isComposite(xd) {
 		panic("invalid Mutable on field with non-composite type")
 	}
-	if x, ok := (*m)[int32(xt.Number())]; ok {
+	if x, ok := (*m)[int32(xd.Number())]; ok {
 		return xt.ValueOf(x.GetValue())
 	}
 	v := xt.New()
@@ -179,14 +181,18 @@
 		return fi, nil
 	}
 	if fd.IsExtension() {
-		if fd.ContainingMessage().FullName() != mi.PBType.FullName() {
+		if fd.ContainingMessage().FullName() != mi.PBType.Descriptor().FullName() {
 			// TODO: Should this be exact containing message descriptor match?
 			panic("mismatching containing message")
 		}
-		if !mi.PBType.ExtensionRanges().Has(fd.Number()) {
+		if !mi.PBType.Descriptor().ExtensionRanges().Has(fd.Number()) {
 			panic("invalid extension field")
 		}
-		return nil, fd.(pref.ExtensionType)
+		xtd, ok := fd.(pref.ExtensionTypeDescriptor)
+		if !ok {
+			panic("extension descriptor does not implement ExtensionTypeDescriptor")
+		}
+		return nil, xtd.Type()
 	}
 	panic("invalid field descriptor")
 }