internal/impl: support wrapping Go structs to implement proto.Message

Given a pointer to a Go struct (that is well-formed according to the v1
struct field layout), wrap the type such that it implements the v2
protoreflect.Message interface.

Change-Id: I5987cad0d22e53970c613cdbbb1cfd4210897f69
Reviewed-on: https://go-review.googlesource.com/c/138897
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message.go b/internal/impl/message.go
index 2ffa207..f51526b 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -5,19 +5,86 @@
 package impl
 
 import (
+	"fmt"
 	"reflect"
 	"strconv"
 	"strings"
+	"sync"
 
 	pref "github.com/golang/protobuf/v2/reflect/protoreflect"
+	ptype "github.com/golang/protobuf/v2/reflect/prototype"
 )
 
-type MessageInfo struct {
+// MessageType provides protobuf-related functionality for a given Go type
+// that represents a message. A given instance of MessageType is tied to
+// exactly one Go type, which must be a pointer to a struct type.
+type MessageType struct {
+	// Desc is an optionally provided message descriptor. If nil, the descriptor
+	// is lazily derived from the Go type information of generated messages
+	// for the v1 API (derivation is still a TODO in init).
+	//
+	// Once set, this field must never be mutated.
+	Desc pref.MessageDescriptor
+
+	once sync.Once // guards the one-time initialization of all unexported fields
+
+	goType reflect.Type     // pointer to struct
+	pbType pref.MessageType // only set if goType does not implement pref.ProtoMessage
+
 	// TODO: Split fields into dense and sparse maps similar to the current
 	// table-driven implementation in v1?
 	fields map[pref.FieldNumber]*fieldInfo
 }
 
+// init lazily initializes the MessageType upon first use and checks that
+// p has the Go type this MessageType is tied to, panicking if it does not.
+//
+// It must be called at the start of every exported method.
+func (mi *MessageType) init(p interface{}) {
+	mi.once.Do(func() {
+		// Reject nil p and non-pointer types before t.Elem(), which would panic.
+		t := reflect.TypeOf(p)
+		if t == nil || t.Kind() != reflect.Ptr || t.Elem().Kind() != reflect.Struct {
+			panic(fmt.Sprintf("got %v, want *struct kind", t))
+		}
+		mi.goType = t
+
+		// Derive the message descriptor if unspecified.
+		md := mi.Desc
+		if md == nil {
+			// TODO: derive the message type from the Go struct type
+		}
+
+		// Initialize the Go message type wrapper if the Go type does not
+		// implement the proto.Message interface.
+		//
+		// Otherwise, we assume that the Go type manually implements the
+		// interface and is internally consistent such that:
+		//	goType == reflect.New(goType.Elem()).Interface().(proto.Message).ProtoReflect().Type().GoType()
+		//
+		// Generated code ensures that this property holds.
+		if _, ok := p.(pref.ProtoMessage); !ok {
+			mi.pbType = ptype.NewGoMessage(&ptype.GoMessage{
+				MessageDescriptor: md,
+				New: func(pref.MessageType) pref.ProtoMessage {
+					v := reflect.New(t.Elem()).Interface()
+					return (*message)(mi.dataTypeOf(v))
+				},
+			})
+		}
+
+		mi.generateFieldFuncs(t.Elem(), md)
+	})
+
+	// TODO: Remove this check? This API is primarily used by generated code,
+	// and should not violate this assumption. Leave this check in for now to
+	// provide some sanity checks during development. This can be removed if
+	// it proves to be detrimental to performance.
+	if reflect.TypeOf(p) != mi.goType {
+		panic(fmt.Sprintf("type mismatch: got %T, want %v", p, mi.goType))
+	}
+}
+
 // generateFieldFuncs generates per-field functions for all common operations
 // to be performed on each field. It takes in a reflect.Type representing the
 // Go struct, and a protoreflect.MessageDescriptor to match with the fields
@@ -25,7 +92,7 @@
 //
 // This code assumes that the struct is well-formed and panics if there are
 // any discrepancies.
-func (mi *MessageInfo) generateFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
+func (mi *MessageType) generateFieldFuncs(t reflect.Type, md pref.MessageDescriptor) {
 	// Generate a mapping of field numbers and names to Go struct field or type.
 	fields := map[pref.FieldNumber]reflect.StructField{}
 	oneofs := map[pref.Name]reflect.StructField{}
@@ -81,11 +148,170 @@
 			fi = fieldInfoForMap(fd, fs)
 		case fd.Cardinality() == pref.Repeated:
 			fi = fieldInfoForVector(fd, fs)
-		case fd.Kind() != pref.MessageKind && fd.Kind() != pref.GroupKind:
-			fi = fieldInfoForScalar(fd, fs)
-		default:
+		case fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind:
 			fi = fieldInfoForMessage(fd, fs)
+		default:
+			fi = fieldInfoForScalar(fd, fs)
 		}
 		mi.fields[fd.Number()] = &fi
 	}
 }
+
+// MessageOf returns p as a protoreflect.Message. p must be a pointer of the
+// Go type this MessageType is tied to; init panics otherwise.
+func (mi *MessageType) MessageOf(p interface{}) pref.Message {
+	mi.init(p)
+	if m, ok := p.(pref.ProtoMessage); ok {
+		return m.ProtoReflect() // assumed to be consistent; see MessageType.init
+	}
+	return (*message)(mi.dataTypeOf(p))
+}
+
+// KnownFieldsOf returns the known fields of the message pointed to by p.
+func (mi *MessageType) KnownFieldsOf(p interface{}) pref.KnownFields {
+	mi.init(p)
+	return (*knownFields)(mi.dataTypeOf(p))
+}
+// UnknownFieldsOf returns the unknown fields of the message pointed to by p.
+func (mi *MessageType) UnknownFieldsOf(p interface{}) pref.UnknownFields {
+	mi.init(p)
+	return (*unknownFields)(mi.dataTypeOf(p))
+}
+func (mi *MessageType) dataTypeOf(p interface{}) *messageDataType {
+	return &messageDataType{pointerOfIface(&p), mi}
+}
+
+// messageDataType is a tuple of a pointer to the message data and
+// a pointer to the message type.
+//
+// TODO: Unfortunately, we need to pair a pointer with its MessageType,
+// which incurs an allocation. This pair is similar to a Go interface,
+// which is essentially a tuple of the same thing. We can make this efficient
+// with reflect.NamedOf (see https://golang.org/issues/16522).
+//
+// With that hypothetical API, we could dynamically create a new named type
+// that has the same underlying type as MessageType.goType, and
+// dynamically create methods that close over MessageType.
+// Since the new type would have the same underlying type, we could directly
+// convert between pointers of those types, giving us an efficient way to swap
+// out the method set.
+//
+// Barring the ability to dynamically create named types, the workaround is to
+//	1. accept the cost of an allocation for this wrapper struct, or
+//	2. generate more types and methods, at the expense of binary size.
+type messageDataType struct {
+	p  pointer
+	mi *MessageType
+}
+
+type message messageDataType // used as both pref.Message and pref.ProtoMessage
+
+func (m *message) Type() pref.MessageType {
+	return m.mi.pbType // only set when goType does not implement pref.ProtoMessage
+}
+func (m *message) KnownFields() pref.KnownFields {
+	return (*knownFields)(m)
+}
+func (m *message) UnknownFields() pref.UnknownFields {
+	return (*unknownFields)(m)
+}
+func (m *message) Unwrap() interface{} {
+	return m.p.asType(m.mi.goType.Elem()).Interface()
+}
+func (m *message) Interface() pref.ProtoMessage {
+	return m // the wrapper itself serves as the pref.ProtoMessage
+}
+func (m *message) ProtoReflect() pref.Message {
+	return m
+}
+func (m *message) ProtoMutable() {}
+
+type knownFields messageDataType // used as pref.KnownFields
+
+func (fs *knownFields) List() (nums []pref.FieldNumber) {
+	for n, fi := range fs.mi.fields { // NOTE(review): map order; nums is unsorted — confirm callers accept this
+		if fi.has(fs.p) {
+			nums = append(nums, n)
+		}
+	}
+	// TODO: Handle extension fields.
+	return nums
+}
+func (fs *knownFields) Len() (cnt int) {
+	for _, fi := range fs.mi.fields {
+		if fi.has(fs.p) {
+			cnt++
+		}
+	}
+	// TODO: Handle extension fields.
+	return cnt
+}
+func (fs *knownFields) Has(n pref.FieldNumber) bool {
+	if fi := fs.mi.fields[n]; fi != nil {
+		return fi.has(fs.p)
+	}
+	// TODO: Handle extension fields.
+	return false
+}
+func (fs *knownFields) Get(n pref.FieldNumber) pref.Value {
+	if fi := fs.mi.fields[n]; fi != nil {
+		return fi.get(fs.p)
+	}
+	// TODO: Handle extension fields.
+	return pref.Value{}
+}
+func (fs *knownFields) Set(n pref.FieldNumber, v pref.Value) {
+	if fi := fs.mi.fields[n]; fi != nil {
+		fi.set(fs.p, v)
+		return
+	}
+	// TODO: Handle extension fields.
+	panic("invalid field")
+}
+func (fs *knownFields) Clear(n pref.FieldNumber) {
+	if fi := fs.mi.fields[n]; fi != nil {
+		fi.clear(fs.p)
+		return
+	}
+	// TODO: Handle extension fields.
+	panic("invalid field")
+}
+func (fs *knownFields) Mutable(n pref.FieldNumber) pref.Mutable {
+	if fi := fs.mi.fields[n]; fi != nil {
+		return fi.mutable(fs.p)
+	}
+	// TODO: Handle extension fields.
+	panic("invalid field")
+}
+func (fs *knownFields) Range(f func(pref.FieldNumber, pref.Value) bool) {
+	for n, fi := range fs.mi.fields { // NOTE(review): map order; f sees fields in unspecified order
+		if fi.has(fs.p) {
+			if !f(n, fi.get(fs.p)) {
+				return
+			}
+		}
+	}
+	// TODO: Handle extension fields.
+}
+func (fs *knownFields) ExtensionTypes() pref.ExtensionFieldTypes {
+	return (*extensionFieldTypes)(fs)
+}
+
+type extensionFieldTypes messageDataType // TODO: extensions unsupported; every method is a stub
+
+func (fs *extensionFieldTypes) List() []pref.ExtensionType                   { return nil }
+func (fs *extensionFieldTypes) Len() int                                     { return 0 }
+func (fs *extensionFieldTypes) Register(pref.ExtensionType)                  {}
+func (fs *extensionFieldTypes) Remove(pref.ExtensionType)                    {}
+func (fs *extensionFieldTypes) ByNumber(pref.FieldNumber) pref.ExtensionType { return nil }
+func (fs *extensionFieldTypes) ByName(pref.FullName) pref.ExtensionType      { return nil }
+func (fs *extensionFieldTypes) Range(f func(pref.ExtensionType) bool)        {}
+
+type unknownFields messageDataType // TODO: unknown fields unsupported; every method is a stub
+
+func (fs *unknownFields) List() []pref.FieldNumber                            { return nil }
+func (fs *unknownFields) Len() int                                            { return 0 }
+func (fs *unknownFields) Get(n pref.FieldNumber) pref.RawFields               { return nil }
+func (fs *unknownFields) Set(n pref.FieldNumber, b pref.RawFields)            {}
+func (fs *unknownFields) Range(f func(pref.FieldNumber, pref.RawFields) bool) {}
+func (fs *unknownFields) IsSupported() bool                                   { return false }