internal/impl: add MessageState to every generated message

We define MessageState, which is essentially an atomically set *MessageInfo.
By nesting it as the first field in every generated message, we can
implement the reflective methods on a *MessageState obtained by
unsafely casting a concrete message pointer to a *MessageState.
The MessageInfo held by the MessageState provides the additional Go type
information needed to interpret the memory that follows the MessageState.

Since we are nesting a MessageState in every message,
the memory use of every message instance grows by 8B.

On average, the body of ProtoReflect grows from 133B to 202B (+50%).
However, this growth is offset by the planned removal of XXX_Methods (108B)
in a future CL. Taking into account the eventual removal
of XXX_Methods, this is a net reduction of 25%.

name          old time/op    new time/op    delta
Name/Value-4    70.3ns ± 2%    17.5ns ± 6%   -75.08%  (p=0.000 n=10+10)
Name/Nil-4      70.6ns ± 3%    33.4ns ± 2%   -52.66%  (p=0.000 n=10+10)

name          old alloc/op   new alloc/op   delta
Name/Value-4     16.0B ± 0%      0.0B       -100.00%  (p=0.000 n=10+10)
Name/Nil-4       16.0B ± 0%      0.0B       -100.00%  (p=0.000 n=10+10)

name          old allocs/op  new allocs/op  delta
Name/Value-4      1.00 ± 0%      0.00       -100.00%  (p=0.000 n=10+10)
Name/Nil-4        1.00 ± 0%      0.00       -100.00%  (p=0.000 n=10+10)

Change-Id: I92bd58dc681c57c92612fd5ba7fc066aea34e95a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/185460
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message.go b/internal/impl/message.go
index fe7cd37..201e1d6 100644
--- a/internal/impl/message.go
+++ b/internal/impl/message.go
@@ -12,7 +12,6 @@
 	"sync"
 	"sync/atomic"
 
-	pvalue "google.golang.org/protobuf/internal/value"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 	piface "google.golang.org/protobuf/runtime/protoiface"
 )
@@ -39,13 +38,7 @@
 	initMu   sync.Mutex // protects all unexported fields
 	initDone uint32
 
-	fields map[pref.FieldNumber]*fieldInfo
-	oneofs map[pref.Name]*oneofInfo
-
-	getUnknown func(pointer) pref.RawFields
-	setUnknown func(pointer, pref.RawFields)
-
-	extensionMap func(pointer) *extensionMap
+	reflectMessageInfo
 
 	// Information used by the fast-path methods.
 	methods piface.Methods
@@ -55,6 +48,17 @@
 	extensionFieldInfos   map[pref.ExtensionType]*extensionFieldInfo
 }
 
+type reflectMessageInfo struct { // data backing the protoreflect.Message methods
+	fields map[pref.FieldNumber]*fieldInfo // indexed by field number
+	oneofs map[pref.Name]*oneofInfo        // indexed by oneof name
+
+	getUnknown   func(pointer) pref.RawFields  // reads a message's unknown fields
+	setUnknown   func(pointer, pref.RawFields) // sets a message's unknown fields
+	extensionMap func(pointer) *extensionMap   // a message's extension fields
+
+	nilMessage atomicNilMessage // cached view returned by MessageOf for nil pointers
+}
+
 // exporter is a function that returns a reference to the ith field of v,
 // where v is a pointer to a struct. It returns nil if it does not support
 // exporting the requested field (e.g., already exported).
@@ -88,10 +92,9 @@
 	// This function is called in the hot path. Inline the sync.Once
 	// logic, since allocating a closure for Once.Do is expensive.
 	// Keep init small to ensure that it can be inlined.
-	if atomic.LoadUint32(&mi.initDone) == 1 {
-		return
+	if atomic.LoadUint32(&mi.initDone) == 0 {
+		mi.initOnce()
 	}
-	mi.initOnce()
 }
 
 func (mi *MessageInfo) initOnce() {
@@ -293,247 +296,8 @@
 	}
 }
 
-func (mi *MessageInfo) MessageOf(p interface{}) pref.Message {
-	return (*messageReflectWrapper)(mi.dataTypeOf(p))
-}
-
+// TODO: Move this to be on the reflect message instance.
 func (mi *MessageInfo) Methods() *piface.Methods {
 	mi.init()
 	return &mi.methods
 }
-
-func (mi *MessageInfo) dataTypeOf(p interface{}) *messageDataType {
-	// TODO: Remove this check? This API is primarily used by generated code,
-	// and should not violate this assumption. Leave this check in for now to
-	// provide some sanity checks during development. This can be removed if
-	// it proves to be detrimental to performance.
-	if reflect.TypeOf(p) != mi.GoType {
-		panic(fmt.Sprintf("type mismatch: got %T, want %v", p, mi.GoType))
-	}
-	return &messageDataType{pointerOfIface(p), mi}
-}
-
-// messageDataType is a tuple of a pointer to the message data and
-// a pointer to the message type.
-//
-// TODO: Unfortunately, we need to close over a pointer and MessageInfo,
-// which incurs an an allocation. This pair is similar to a Go interface,
-// which is essentially a tuple of the same thing. We can make this efficient
-// with reflect.NamedOf (see https://golang.org/issues/16522).
-//
-// With that hypothetical API, we could dynamically create a new named type
-// that has the same underlying type as MessageInfo.GoType, and
-// dynamically create methods that close over MessageInfo.
-// Since the new type would have the same underlying type, we could directly
-// convert between pointers of those types, giving us an efficient way to swap
-// out the method set.
-//
-// Barring the ability to dynamically create named types, the workaround is
-//	1. either to accept the cost of an allocation for this wrapper struct or
-//	2. generate more types and methods, at the expense of binary size increase.
-type messageDataType struct {
-	p  pointer
-	mi *MessageInfo
-}
-
-type messageReflectWrapper messageDataType
-
-func (m *messageReflectWrapper) Descriptor() pref.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
-}
-func (m *messageReflectWrapper) New() pref.Message {
-	return m.mi.PBType.New()
-}
-func (m *messageReflectWrapper) Interface() pref.ProtoMessage {
-	if m, ok := m.ProtoUnwrap().(pref.ProtoMessage); ok {
-		return m
-	}
-	return (*messageIfaceWrapper)(m)
-}
-func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
-	return m.p.AsIfaceOf(m.mi.GoType.Elem())
-}
-
-func (m *messageReflectWrapper) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
-		if fi.has(m.p) {
-			if !f(fi.fieldDesc, fi.get(m.p)) {
-				return
-			}
-		}
-	}
-	m.mi.extensionMap(m.p).Range(f)
-}
-func (m *messageReflectWrapper) Has(fd pref.FieldDescriptor) bool {
-	if fi, xt := m.checkField(fd); fi != nil {
-		return fi.has(m.p)
-	} else {
-		return m.mi.extensionMap(m.p).Has(xt)
-	}
-}
-func (m *messageReflectWrapper) Clear(fd pref.FieldDescriptor) {
-	if fi, xt := m.checkField(fd); fi != nil {
-		fi.clear(m.p)
-	} else {
-		m.mi.extensionMap(m.p).Clear(xt)
-	}
-}
-func (m *messageReflectWrapper) Get(fd pref.FieldDescriptor) pref.Value {
-	if fi, xt := m.checkField(fd); fi != nil {
-		return fi.get(m.p)
-	} else {
-		return m.mi.extensionMap(m.p).Get(xt)
-	}
-}
-func (m *messageReflectWrapper) Set(fd pref.FieldDescriptor, v pref.Value) {
-	if fi, xt := m.checkField(fd); fi != nil {
-		fi.set(m.p, v)
-	} else {
-		m.mi.extensionMap(m.p).Set(xt, v)
-	}
-}
-func (m *messageReflectWrapper) Mutable(fd pref.FieldDescriptor) pref.Value {
-	if fi, xt := m.checkField(fd); fi != nil {
-		return fi.mutable(m.p)
-	} else {
-		return m.mi.extensionMap(m.p).Mutable(xt)
-	}
-}
-func (m *messageReflectWrapper) NewMessage(fd pref.FieldDescriptor) pref.Message {
-	if fi, xt := m.checkField(fd); fi != nil {
-		return fi.newMessage()
-	} else {
-		return xt.New().Message()
-	}
-}
-func (m *messageReflectWrapper) WhichOneof(od pref.OneofDescriptor) pref.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
-		return od.Fields().ByNumber(oi.which(m.p))
-	}
-	panic("invalid oneof descriptor")
-}
-func (m *messageReflectWrapper) GetUnknown() pref.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.p)
-}
-func (m *messageReflectWrapper) SetUnknown(b pref.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.p, b)
-}
-
-// checkField verifies that the provided field descriptor is valid.
-// Exactly one of the returned values is populated.
-func (m *messageReflectWrapper) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
-	m.mi.init()
-	if fi := m.mi.fields[fd.Number()]; fi != nil {
-		if fi.fieldDesc != fd {
-			panic("mismatching field descriptor")
-		}
-		return fi, nil
-	}
-	if fd.IsExtension() {
-		if fd.ContainingMessage().FullName() != m.mi.PBType.FullName() {
-			// TODO: Should this be exact containing message descriptor match?
-			panic("mismatching containing message")
-		}
-		if !m.mi.PBType.ExtensionRanges().Has(fd.Number()) {
-			panic("invalid extension field")
-		}
-		return nil, fd.(pref.ExtensionType)
-	}
-	panic("invalid field descriptor")
-}
-
-type extensionMap map[int32]ExtensionField
-
-func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
-	if m != nil {
-		for _, x := range *m {
-			xt := x.GetType()
-			if !f(xt, xt.ValueOf(x.GetValue())) {
-				return
-			}
-		}
-	}
-}
-func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
-	if m != nil {
-		_, ok = (*m)[int32(xt.Number())]
-	}
-	return ok
-}
-func (m *extensionMap) Clear(xt pref.ExtensionType) {
-	delete(*m, int32(xt.Number()))
-}
-func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
-	if m != nil {
-		if x, ok := (*m)[int32(xt.Number())]; ok {
-			return xt.ValueOf(x.GetValue())
-		}
-	}
-	if !isComposite(xt) {
-		return defaultValueOf(xt)
-	}
-	return frozenValueOf(xt.New())
-}
-func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
-	if *m == nil {
-		*m = make(map[int32]ExtensionField)
-	}
-	var x ExtensionField
-	x.SetType(xt)
-	x.SetEagerValue(xt.InterfaceOf(v))
-	(*m)[int32(xt.Number())] = x
-}
-func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
-	if !isComposite(xt) {
-		panic("invalid Mutable on field with non-composite type")
-	}
-	if x, ok := (*m)[int32(xt.Number())]; ok {
-		return xt.ValueOf(x.GetValue())
-	}
-	v := xt.New()
-	m.Set(xt, v)
-	return v
-}
-
-func isComposite(fd pref.FieldDescriptor) bool {
-	return fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind || fd.IsList() || fd.IsMap()
-}
-
-var _ pvalue.Unwrapper = (*messageReflectWrapper)(nil)
-
-type messageIfaceWrapper messageDataType
-
-func (m *messageIfaceWrapper) ProtoReflect() pref.Message {
-	return (*messageReflectWrapper)(m)
-}
-func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
-	// TODO: Consider not recreating this on every call.
-	m.mi.init()
-	return &piface.Methods{
-		Flags:         piface.MethodFlagDeterministicMarshal,
-		MarshalAppend: m.marshalAppend,
-		Unmarshal:     m.unmarshal,
-		Size:          m.size,
-		IsInitialized: m.isInitialized,
-	}
-}
-func (m *messageIfaceWrapper) ProtoUnwrap() interface{} {
-	return m.p.AsIfaceOf(m.mi.GoType.Elem())
-}
-func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
-	return m.mi.marshalAppendPointer(b, m.p, newMarshalOptions(opts))
-}
-func (m *messageIfaceWrapper) unmarshal(b []byte, _ pref.ProtoMessage, opts piface.UnmarshalOptions) error {
-	_, err := m.mi.unmarshalPointer(b, m.p, 0, newUnmarshalOptions(opts))
-	return err
-}
-func (m *messageIfaceWrapper) size(msg pref.ProtoMessage) (size int) {
-	return m.mi.sizePointer(m.p, 0)
-}
-func (m *messageIfaceWrapper) isInitialized(_ pref.ProtoMessage) error {
-	return m.mi.isInitializedPointer(m.p)
-}
diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go
new file mode 100644
index 0000000..a9eb7a9
--- /dev/null
+++ b/internal/impl/message_reflect.go
@@ -0,0 +1,220 @@
+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package impl
+
+import (
+	"fmt"
+	"reflect"
+
+	"google.golang.org/protobuf/internal/pragma"
+	pvalue "google.golang.org/protobuf/internal/value"
+	pref "google.golang.org/protobuf/reflect/protoreflect"
+	piface "google.golang.org/protobuf/runtime/protoiface"
+)
+
+// MessageState is a data structure that is nested as the first field in a
+// concrete message. It provides a way to implement the ProtoReflect method
+// in an allocation-free way without needing to have a shadow Go type generated
+// for every message type. This technique only works using unsafe; otherwise
+// (or when m is nil) generated code falls back to MessageInfo.MessageOf.
+//
+// Example generated code:
+//
+//	type M struct {
+//		state protoimpl.MessageState
+//
+//		Field1 int32
+//		Field2 string
+//		Field3 *BarMessage
+//		...
+//	}
+//
+//	func (m *M) ProtoReflect() protoreflect.Message {
+//		mi := &file_fizz_buzz_proto_msgInfos[5]
+//		if protoimpl.UnsafeEnabled && m != nil {
+//			ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(m))
+//			if ms.LoadMessageInfo() == nil {
+//				ms.StoreMessageInfo(mi)
+//			}
+//			return ms
+//		}
+//		return mi.MessageOf(m)
+//	}
+//
+// The MessageState type holds a *MessageInfo, which must be atomically set to
+// the message info associated with a given message instance.
+// By unsafely converting a *M into a *MessageState, the MessageState object
+// has access to all the information needed to implement protobuf reflection.
+// It has access to the message info as its first field, and a pointer to the
+// MessageState is identical to a pointer to the concrete message value.
+// The MessageInfo's Go type information then describes the memory that follows.
+//
+// Requirements:
+//	• The type M must implement protoreflect.ProtoMessage.
+//	• The address of m must not be nil.
+//	• The address of m and the address of m.state must be equal,
+//	even though they are different Go types.
+type MessageState struct {
+	pragma.NoUnkeyedLiterals
+	pragma.DoNotCompare
+	pragma.DoNotCopy
+
+	mi *MessageInfo
+}
+
+// messageState carries the reflective method set for a MessageState.
+type messageState MessageState
+
+// Compile-time checks that *messageState satisfies the reflective interfaces.
+var _ pref.Message = (*messageState)(nil)
+var _ pvalue.Unwrapper = (*messageState)(nil)
+
+// messageDataType is a tuple of a pointer to the message data and
+// a pointer to the message type. It is a generalized way of providing a
+// reflective view over a message instance. The disadvantage of this approach
+// is the need to allocate this tuple (16B on 64-bit with the unsafe pointer).
+type messageDataType struct {
+	p  pointer
+	mi *MessageInfo
+}
+
+// Both wrapper types share messageDataType's layout, so a pointer to one can
+// be converted to a pointer to the other without copying.
+type messageIfaceWrapper messageDataType
+type messageReflectWrapper messageDataType
+
+var (
+	_ pvalue.Unwrapper  = (*messageIfaceWrapper)(nil)
+	_ pref.ProtoMessage = (*messageIfaceWrapper)(nil)
+	_ pvalue.Unwrapper  = (*messageReflectWrapper)(nil)
+	_ pref.Message      = (*messageReflectWrapper)(nil)
+)
+
+// MessageOf returns a reflective view over m, which must be a pointer to a named
+// Go struct; a ProtoReflect method on that type must be implemented via this.
+func (mi *MessageInfo) MessageOf(m interface{}) pref.Message {
+	// TODO: Switch the input to be an opaque Pointer.
+	if got := reflect.TypeOf(m); got != mi.GoType {
+		panic(fmt.Sprintf("type mismatch: got %T, want %v", m, mi.GoType))
+	}
+	switch p := pointerOfIface(m); {
+	case p.IsNil():
+		return mi.nilMessage.Init(mi)
+	default:
+		return &messageReflectWrapper{p: p, mi: mi}
+	}
+}
+
+func (m *messageReflectWrapper) pointer() pointer { return m.p } // same accessor as messageState.pointer, used by the generated methods
+
+// ProtoReflect converts back to the reflective view of the same data.
+func (m *messageIfaceWrapper) ProtoReflect() pref.Message { return (*messageReflectWrapper)(m) }
+
+// XXX_Methods exposes the fast-path methods. TODO: avoid rebuilding on every call.
+func (m *messageIfaceWrapper) XXX_Methods() *piface.Methods {
+	m.mi.init()
+	methods := piface.Methods{Flags: piface.MethodFlagDeterministicMarshal}
+	methods.MarshalAppend = m.marshalAppend
+	methods.Unmarshal = m.unmarshal
+	methods.Size = m.size
+	methods.IsInitialized = m.isInitialized
+	return &methods
+}
+
+func (m *messageIfaceWrapper) ProtoUnwrap() interface{} { return m.p.AsIfaceOf(m.mi.GoType.Elem()) }
+
+// The helpers below ignore their ProtoMessage argument and operate on m.p.
+func (m *messageIfaceWrapper) marshalAppend(b []byte, _ pref.ProtoMessage, opts piface.MarshalOptions) ([]byte, error) {
+	mopts := newMarshalOptions(opts)
+	return m.mi.marshalAppendPointer(b, m.p, mopts)
+}
+func (m *messageIfaceWrapper) unmarshal(b []byte, _ pref.ProtoMessage, opts piface.UnmarshalOptions) error {
+	_, err := m.mi.unmarshalPointer(b, m.p, 0, newUnmarshalOptions(opts))
+	return err
+}
+func (m *messageIfaceWrapper) size(_ pref.ProtoMessage) int { return m.mi.sizePointer(m.p, 0) }
+
+func (m *messageIfaceWrapper) isInitialized(_ pref.ProtoMessage) error {
+	return m.mi.isInitializedPointer(m.p)
+}
+
+type extensionMap map[int32]ExtensionField // keyed by extension field number
+
+func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) {
+	if m != nil {
+		for _, x := range *m {
+			if xt := x.GetType(); !f(xt, xt.ValueOf(x.GetValue())) {
+				return
+			}
+		}
+	}
+}
+func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) {
+	if m != nil {
+		_, ok = (*m)[int32(xt.Number())]
+	}
+	return ok
+}
+func (m *extensionMap) Clear(xt pref.ExtensionType) {
+	if m != nil { // like Range, Has and Get, treat a nil map pointer as empty
+		delete(*m, int32(xt.Number()))
+	}
+}
+func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value {
+	if m != nil {
+		if x, ok := (*m)[int32(xt.Number())]; ok {
+			return xt.ValueOf(x.GetValue())
+		}
+	}
+	if !isComposite(xt) {
+		return defaultValueOf(xt)
+	}
+	return frozenValueOf(xt.New())
+}
+func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) {
+	if *m == nil {
+		*m = make(map[int32]ExtensionField)
+	}
+	var x ExtensionField
+	x.SetType(xt)
+	x.SetEagerValue(xt.InterfaceOf(v))
+	(*m)[int32(xt.Number())] = x
+}
+func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value {
+	if !isComposite(xt) {
+		panic("invalid Mutable on field with non-composite type")
+	}
+	if x, ok := (*m)[int32(xt.Number())]; ok {
+		return xt.ValueOf(x.GetValue())
+	}
+	v := xt.New()
+	m.Set(xt, v)
+	return v
+}
+func isComposite(fd pref.FieldDescriptor) bool { // message, group, list, or map
+	return fd.Kind() == pref.MessageKind || fd.Kind() == pref.GroupKind || fd.IsList() || fd.IsMap()
+}
+
+// checkField verifies that fd is a valid field of this message and resolves it.
+// Exactly one of the returned values is populated; invalid descriptors panic.
+func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.ExtensionType) {
+	if fi := mi.fields[fd.Number()]; fi != nil {
+		if fi.fieldDesc == fd {
+			return fi, nil
+		}
+		panic("mismatching field descriptor")
+	}
+	if !fd.IsExtension() {
+		panic("invalid field descriptor")
+	}
+	switch {
+	case fd.ContainingMessage().FullName() != mi.PBType.FullName():
+		// TODO: Should this be exact containing message descriptor match?
+		panic("mismatching containing message")
+	case !mi.PBType.ExtensionRanges().Has(fd.Number()):
+		panic("invalid extension field")
+	}
+	return nil, fd.(pref.ExtensionType)
+}
diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go
new file mode 100644
index 0000000..41ba0b9
--- /dev/null
+++ b/internal/impl/message_reflect_gen.go
@@ -0,0 +1,190 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style.
+// license that can be found in the LICENSE file.
+
+// Code generated by generate-types. DO NOT EDIT.
+
+package impl
+
+import (
+	"google.golang.org/protobuf/reflect/protoreflect"
+)
+
+func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
+	return m.mi.PBType.Descriptor()
+}
+func (m *messageState) New() protoreflect.Message {
+	return m.mi.PBType.New()
+}
+func (m *messageState) Interface() protoreflect.ProtoMessage {
+	return m.ProtoUnwrap().(protoreflect.ProtoMessage) // panics unless the Go type implements ProtoMessage
+}
+func (m *messageState) ProtoUnwrap() interface{} {
+	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+}
+// Each method below calls m.mi.init() before consulting the MessageInfo tables.
+func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
+	m.mi.init()
+	for _, fi := range m.mi.fields {
+		if fi.has(m.pointer()) {
+			if !f(fi.fieldDesc, fi.get(m.pointer())) {
+				return
+			}
+		}
+	}
+	m.mi.extensionMap(m.pointer()).Range(f) // extensions are visited after regular fields
+}
+func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil { // checkField panics if fd does not belong to this message
+		return fi.has(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Has(xt)
+	}
+}
+func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		fi.clear(m.pointer())
+	} else {
+		m.mi.extensionMap(m.pointer()).Clear(xt)
+	}
+}
+func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.get(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Get(xt)
+	}
+}
+func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		fi.set(m.pointer(), v)
+	} else {
+		m.mi.extensionMap(m.pointer()).Set(xt, v)
+	}
+}
+func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.mutable(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+	}
+}
+func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.newMessage()
+	} else {
+		return xt.New().Message()
+	}
+}
+func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
+	m.mi.init()
+	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+		return od.Fields().ByNumber(oi.which(m.pointer()))
+	}
+	panic("invalid oneof descriptor") // od must be a oneof of this message
+}
+func (m *messageState) GetUnknown() protoreflect.RawFields {
+	m.mi.init()
+	return m.mi.getUnknown(m.pointer())
+}
+func (m *messageState) SetUnknown(b protoreflect.RawFields) {
+	m.mi.init()
+	m.mi.setUnknown(m.pointer(), b)
+}
+
+func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
+	return m.mi.PBType.Descriptor()
+}
+func (m *messageReflectWrapper) New() protoreflect.Message {
+	return m.mi.PBType.New()
+}
+func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
+	if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok {
+		return m
+	}
+	return (*messageIfaceWrapper)(m) // the Go value is not a ProtoMessage, so wrap it
+}
+func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
+	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+}
+// Each method below calls m.mi.init() before consulting the MessageInfo tables.
+func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
+	m.mi.init()
+	for _, fi := range m.mi.fields {
+		if fi.has(m.pointer()) {
+			if !f(fi.fieldDesc, fi.get(m.pointer())) {
+				return
+			}
+		}
+	}
+	m.mi.extensionMap(m.pointer()).Range(f) // extensions are visited after regular fields
+}
+func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil { // checkField panics if fd does not belong to this message
+		return fi.has(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Has(xt)
+	}
+}
+func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		fi.clear(m.pointer())
+	} else {
+		m.mi.extensionMap(m.pointer()).Clear(xt)
+	}
+}
+func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.get(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Get(xt)
+	}
+}
+func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		fi.set(m.pointer(), v)
+	} else {
+		m.mi.extensionMap(m.pointer()).Set(xt, v)
+	}
+}
+func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.mutable(m.pointer())
+	} else {
+		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+	}
+}
+func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
+	m.mi.init()
+	if fi, xt := m.mi.checkField(fd); fi != nil {
+		return fi.newMessage()
+	} else {
+		return xt.New().Message()
+	}
+}
+func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
+	m.mi.init()
+	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+		return od.Fields().ByNumber(oi.which(m.pointer()))
+	}
+	panic("invalid oneof descriptor") // od must be a oneof of this message
+}
+func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
+	m.mi.init()
+	return m.mi.getUnknown(m.pointer())
+}
+func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
+	m.mi.init()
+	m.mi.setUnknown(m.pointer(), b)
+}
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index af7196d..f2fd290 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -8,7 +8,9 @@
 	"fmt"
 	"math"
 	"reflect"
+	"runtime"
 	"strings"
+	"sync"
 	"testing"
 
 	cmp "github.com/google/go-cmp/cmp"
@@ -23,6 +25,7 @@
 	"google.golang.org/protobuf/reflect/prototype"
 
 	proto2_20180125 "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v1.0.0-20180125-92554152"
+	testpb "google.golang.org/protobuf/internal/testprotos/test"
 	"google.golang.org/protobuf/types/descriptorpb"
 )
 
@@ -1435,3 +1438,71 @@
 	}
 	return strings.Join(ss, ".")
 }
+
+// TestUnsafeAssumptions checks the MessageState assumption that when a concrete
+// message is unsafely cast to a *MessageState, the Go GC does not reclaim
+// the memory for the remainder of the concrete message.
+func TestUnsafeAssumptions(t *testing.T) {
+	if !pimpl.UnsafeEnabled {
+		t.Skip()
+	}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done() // first, so no exit path can leave wg.Wait blocked
+			var ms [10]pref.Message
+
+			// Store the message only in its reflective form.
+			// Trigger the GC after each iteration.
+			for j := 0; j < 10; j++ {
+				ms[j] = (&testpb.TestAllTypes{
+					OptionalInt32: scalar.Int32(int32(j)),
+					OptionalFloat: scalar.Float32(float32(j)),
+					RepeatedInt32: []int32{int32(j)},
+					RepeatedFloat: []float32{float32(j)},
+					DefaultInt32:  scalar.Int32(int32(j)),
+					DefaultFloat:  scalar.Float32(float32(j)),
+				}).ProtoReflect()
+				runtime.GC()
+			}
+
+			// Convert each reflective form back to the concrete message. Any
+			// matching case is a changed value: it leaves the switch and errors.
+			for j := 0; j < 10; j++ {
+				switch m := ms[j].Interface().(*testpb.TestAllTypes); {
+				case m.GetOptionalInt32() != int32(j):
+				case m.GetOptionalFloat() != float32(j):
+				case m.GetRepeatedInt32()[0] != int32(j):
+				case m.GetRepeatedFloat()[0] != float32(j):
+				case m.GetDefaultInt32() != int32(j):
+				case m.GetDefaultFloat() != float32(j):
+				default:
+					continue
+				}
+				t.Errorf("message %d: memory corruption detected", j)
+			}
+		}()
+	}
+	wg.Wait()
+}
+
+func BenchmarkName(b *testing.B) {
+	var sink pref.FullName
+	bench := func(newMsg func() *descriptorpb.FileDescriptorProto) func(*testing.B) {
+		return func(b *testing.B) {
+			b.ReportAllocs()
+			m := newMsg()
+			for i := 0; i < b.N; i++ {
+				sink = m.ProtoReflect().Descriptor().FullName()
+			}
+		}
+	}
+	b.Run("Value", bench(func() *descriptorpb.FileDescriptorProto {
+		return new(descriptorpb.FileDescriptorProto)
+	}))
+	b.Run("Nil", bench(func() *descriptorpb.FileDescriptorProto { return nil }))
+	// Use sink after the runs so its assignments remain observable work.
+	runtime.KeepAlive(sink)
+}
diff --git a/internal/impl/pointer_reflect.go b/internal/impl/pointer_reflect.go
index 0a9c53e..020af8e 100644
--- a/internal/impl/pointer_reflect.go
+++ b/internal/impl/pointer_reflect.go
@@ -9,10 +9,14 @@
 import (
 	"fmt"
 	"reflect"
+	"sync"
 )
 
 const UnsafeEnabled = false
 
+// Pointer is an opaque pointer type; in this (non-unsafe) build it is an interface{}.
+type Pointer interface{}
+
 // offset represents the offset to a struct field, accessible from a pointer.
 // The offset is the field index into a struct.
 type offset struct {
@@ -46,6 +50,11 @@
 // pointer is an abstract representation of a pointer to a struct or field.
 type pointer struct{ v reflect.Value }
 
+// pointerOf returns p, an interface{} in this build, as a pointer via pointerOfIface.
+func pointerOf(p Pointer) pointer {
+	return pointerOfIface(p)
+}
+
 // pointerOfValue returns v as a pointer.
 func pointerOfValue(v reflect.Value) pointer {
 	return pointer{v: v}
@@ -146,3 +155,21 @@
 func (p pointer) SetPointer(v pointer) {
 	p.v.Elem().Set(v.v)
 }
+
+func (Export) MessageStateOf(p Pointer) *messageState     { panic("not supported") }
+func (ms *messageState) pointer() pointer                 { panic("not supported") }
+func (ms *messageState) LoadMessageInfo() *MessageInfo    { panic("not supported") }
+func (ms *messageState) StoreMessageInfo(mi *MessageInfo) { panic("not supported") }
+
+type atomicNilMessage struct {
+	once sync.Once
+	m    messageReflectWrapper
+}
+
+func (m *atomicNilMessage) Init(mi *MessageInfo) *messageReflectWrapper {
+	m.once.Do(func() {
+		m.m.p = pointerOfIface(reflect.Zero(mi.GoType).Interface())
+		m.m.mi = mi
+	})
+	return &m.m
+}
diff --git a/internal/impl/pointer_unsafe.go b/internal/impl/pointer_unsafe.go
index ceca3b2..ab0d6ee 100644
--- a/internal/impl/pointer_unsafe.go
+++ b/internal/impl/pointer_unsafe.go
@@ -8,11 +8,15 @@
 
 import (
 	"reflect"
+	"sync/atomic"
 	"unsafe"
 )
 
 const UnsafeEnabled = true
 
+// Pointer is an opaque pointer type; in this (unsafe) build it is an unsafe.Pointer.
+type Pointer unsafe.Pointer
+
 // offset represents the offset to a struct field, accessible from a pointer.
 // The offset is the byte offset to the field from the start of the struct.
 type offset uintptr
@@ -34,6 +38,11 @@
 // pointer is a pointer to a message struct or field.
 type pointer struct{ p unsafe.Pointer }
 
+// pointerOf returns p as a pointer; only the static type changes, not the address.
+func pointerOf(p Pointer) pointer {
+	return pointer{p: unsafe.Pointer(p)}
+}
+
 // pointerOfValue returns v as a pointer.
 func pointerOfValue(v reflect.Value) pointer {
 	return pointer{p: unsafe.Pointer(v.Pointer())}
@@ -125,3 +134,30 @@
 func (p pointer) SetPointer(v pointer) {
 	*(*unsafe.Pointer)(p.p) = (unsafe.Pointer)(v.p)
 }
+
+// Static check that MessageState does not exceed the size of a pointer.
+const _ = uint(unsafe.Sizeof(unsafe.Pointer(nil)) - unsafe.Sizeof(MessageState{}))
+
+func (Export) MessageStateOf(p Pointer) *messageState {
+	// Super-tricky - see documentation on MessageState.
+	return (*messageState)(unsafe.Pointer(p))
+}
+func (ms *messageState) pointer() pointer {
+	// Super-tricky - see documentation on MessageState.
+	return pointer{p: unsafe.Pointer(ms)}
+}
+func (ms *messageState) LoadMessageInfo() *MessageInfo {
+	return (*MessageInfo)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi))))
+}
+func (ms *messageState) StoreMessageInfo(mi *MessageInfo) {
+	atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&ms.mi)), unsafe.Pointer(mi))
+}
+
+type atomicNilMessage struct{ m messageReflectWrapper }
+
+func (m *atomicNilMessage) Init(mi *MessageInfo) *messageReflectWrapper {
+	if (*messageReflectWrapper)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&m.m.mi)))) == nil {
+		atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&m.m.mi)), unsafe.Pointer(mi))
+	}
+	return &m.m
+}