internal/impl: fix race over messageState.mi

In generated code, the messageState.mi field is atomically checked and
set to the *MessageInfo associated with that message. However, the
messageState methods read the mi field with plain, non-atomic loads,
which is a potential data race. Fix this by always going through a
messageInfo method that reads the *MessageInfo with an
atomic.LoadPointer. The messageReflectWrapper methods in the same
generated file are updated the same way.

On x86 this change has no performance cost, since an atomic.LoadPointer
compiles to the same MOV instruction as a plain load. At the machine
level, then, the previous code did not race. However, without an
atomic.LoadPointer the compiler was free, in theory, to reorder or
otherwise transform the plain load and produce genuinely racy code.

Change-Id: I8afefaf35c1916872781abc0239cbb63d62edf16
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/189017
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message_reflect_gen.go b/internal/impl/message_reflect_gen.go
index 40447a6..e2f6d17 100644
--- a/internal/impl/message_reflect_gen.go
+++ b/internal/impl/message_reflect_gen.go
@@ -12,23 +12,23 @@
 )
 
 func (m *messageState) Descriptor() protoreflect.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
+	return m.messageInfo().PBType.Descriptor()
 }
 func (m *messageState) Type() protoreflect.MessageType {
-	return m.mi.PBType
+	return m.messageInfo().PBType
 }
 func (m *messageState) New() protoreflect.Message {
-	return m.mi.PBType.New()
+	return m.messageInfo().PBType.New()
 }
 func (m *messageState) Interface() protoreflect.ProtoMessage {
 	return m.ProtoUnwrap().(protoreflect.ProtoMessage)
 }
 func (m *messageState) ProtoUnwrap() interface{} {
-	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+	return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
 }
 func (m *messageState) ProtoMethods() *protoiface.Methods {
-	m.mi.init()
-	return &m.mi.methods
+	m.messageInfo().init()
+	return &m.messageInfo().methods
 }
 
 // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@@ -37,92 +37,92 @@
 // WARNING: This method is exempt from the compatibility promise and
 // may be removed in the future without warning.
 func (m *messageState) ProtoMessageInfo() *MessageInfo {
-	return m.mi
+	return m.messageInfo()
 }
 
 func (m *messageState) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
+	m.messageInfo().init()
+	for _, fi := range m.messageInfo().fields {
 		if fi.has(m.pointer()) {
 			if !f(fi.fieldDesc, fi.get(m.pointer())) {
 				return
 			}
 		}
 	}
-	m.mi.extensionMap(m.pointer()).Range(f)
+	m.messageInfo().extensionMap(m.pointer()).Range(f)
 }
 func (m *messageState) Has(fd protoreflect.FieldDescriptor) bool {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.has(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Has(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Has(xt)
 	}
 }
 func (m *messageState) Clear(fd protoreflect.FieldDescriptor) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.clear(m.pointer())
 	} else {
-		m.mi.extensionMap(m.pointer()).Clear(xt)
+		m.messageInfo().extensionMap(m.pointer()).Clear(xt)
 	}
 }
 func (m *messageState) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.get(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Get(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Get(xt)
 	}
 }
 func (m *messageState) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.set(m.pointer(), v)
 	} else {
-		m.mi.extensionMap(m.pointer()).Set(xt, v)
+		m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
 	}
 }
 func (m *messageState) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.mutable(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
 	}
 }
 func (m *messageState) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.newMessage()
 	} else {
 		return xt.New().Message()
 	}
 }
 func (m *messageState) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+	m.messageInfo().init()
+	if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
 		return od.Fields().ByNumber(oi.which(m.pointer()))
 	}
 	panic("invalid oneof descriptor")
 }
 func (m *messageState) GetUnknown() protoreflect.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.pointer())
+	m.messageInfo().init()
+	return m.messageInfo().getUnknown(m.pointer())
 }
 func (m *messageState) SetUnknown(b protoreflect.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.pointer(), b)
+	m.messageInfo().init()
+	m.messageInfo().setUnknown(m.pointer(), b)
 }
 
 func (m *messageReflectWrapper) Descriptor() protoreflect.MessageDescriptor {
-	return m.mi.PBType.Descriptor()
+	return m.messageInfo().PBType.Descriptor()
 }
 func (m *messageReflectWrapper) Type() protoreflect.MessageType {
-	return m.mi.PBType
+	return m.messageInfo().PBType
 }
 func (m *messageReflectWrapper) New() protoreflect.Message {
-	return m.mi.PBType.New()
+	return m.messageInfo().PBType.New()
 }
 func (m *messageReflectWrapper) Interface() protoreflect.ProtoMessage {
 	if m, ok := m.ProtoUnwrap().(protoreflect.ProtoMessage); ok {
@@ -131,11 +131,11 @@
 	return (*messageIfaceWrapper)(m)
 }
 func (m *messageReflectWrapper) ProtoUnwrap() interface{} {
-	return m.pointer().AsIfaceOf(m.mi.GoType.Elem())
+	return m.pointer().AsIfaceOf(m.messageInfo().GoType.Elem())
 }
 func (m *messageReflectWrapper) ProtoMethods() *protoiface.Methods {
-	m.mi.init()
-	return &m.mi.methods
+	m.messageInfo().init()
+	return &m.messageInfo().methods
 }
 
 // ProtoMessageInfo is a pseudo-internal API for allowing the v1 code
@@ -144,80 +144,80 @@
 // WARNING: This method is exempt from the compatibility promise and
 // may be removed in the future without warning.
 func (m *messageReflectWrapper) ProtoMessageInfo() *MessageInfo {
-	return m.mi
+	return m.messageInfo()
 }
 
 func (m *messageReflectWrapper) Range(f func(protoreflect.FieldDescriptor, protoreflect.Value) bool) {
-	m.mi.init()
-	for _, fi := range m.mi.fields {
+	m.messageInfo().init()
+	for _, fi := range m.messageInfo().fields {
 		if fi.has(m.pointer()) {
 			if !f(fi.fieldDesc, fi.get(m.pointer())) {
 				return
 			}
 		}
 	}
-	m.mi.extensionMap(m.pointer()).Range(f)
+	m.messageInfo().extensionMap(m.pointer()).Range(f)
 }
 func (m *messageReflectWrapper) Has(fd protoreflect.FieldDescriptor) bool {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.has(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Has(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Has(xt)
 	}
 }
 func (m *messageReflectWrapper) Clear(fd protoreflect.FieldDescriptor) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.clear(m.pointer())
 	} else {
-		m.mi.extensionMap(m.pointer()).Clear(xt)
+		m.messageInfo().extensionMap(m.pointer()).Clear(xt)
 	}
 }
 func (m *messageReflectWrapper) Get(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.get(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Get(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Get(xt)
 	}
 }
 func (m *messageReflectWrapper) Set(fd protoreflect.FieldDescriptor, v protoreflect.Value) {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		fi.set(m.pointer(), v)
 	} else {
-		m.mi.extensionMap(m.pointer()).Set(xt, v)
+		m.messageInfo().extensionMap(m.pointer()).Set(xt, v)
 	}
 }
 func (m *messageReflectWrapper) Mutable(fd protoreflect.FieldDescriptor) protoreflect.Value {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.mutable(m.pointer())
 	} else {
-		return m.mi.extensionMap(m.pointer()).Mutable(xt)
+		return m.messageInfo().extensionMap(m.pointer()).Mutable(xt)
 	}
 }
 func (m *messageReflectWrapper) NewMessage(fd protoreflect.FieldDescriptor) protoreflect.Message {
-	m.mi.init()
-	if fi, xt := m.mi.checkField(fd); fi != nil {
+	m.messageInfo().init()
+	if fi, xt := m.messageInfo().checkField(fd); fi != nil {
 		return fi.newMessage()
 	} else {
 		return xt.New().Message()
 	}
 }
 func (m *messageReflectWrapper) WhichOneof(od protoreflect.OneofDescriptor) protoreflect.FieldDescriptor {
-	m.mi.init()
-	if oi := m.mi.oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
+	m.messageInfo().init()
+	if oi := m.messageInfo().oneofs[od.Name()]; oi != nil && oi.oneofDesc == od {
 		return od.Fields().ByNumber(oi.which(m.pointer()))
 	}
 	panic("invalid oneof descriptor")
 }
 func (m *messageReflectWrapper) GetUnknown() protoreflect.RawFields {
-	m.mi.init()
-	return m.mi.getUnknown(m.pointer())
+	m.messageInfo().init()
+	return m.messageInfo().getUnknown(m.pointer())
 }
 func (m *messageReflectWrapper) SetUnknown(b protoreflect.RawFields) {
-	m.mi.init()
-	m.mi.setUnknown(m.pointer(), b)
+	m.messageInfo().init()
+	m.messageInfo().setUnknown(m.pointer(), b)
 }