internal/impl: fix race in aberrant message logic

Previously, when aberrantLoadMessageDesc returned, it was guaranteed
to have initialized the current message through the use of the done signal.
However, this does not guarantee that the descriptors for cyclically
referenced messages have also finished initialization.

Rather than add more complicated logic to wait until all cyclic references
have finished initializing, just add a global lock for the entire
aberrantLoadMessageDesc function.

This reduces concurrency and may hurt performance, but it is easier to reason about.

Change-Id: I4cdae8b955f71ee40fa6979f5a8d548d9749042c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/184657
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/legacy_message.go b/internal/impl/legacy_message.go
index 99c8e0c..f745a81 100644
--- a/internal/impl/legacy_message.go
+++ b/internal/impl/legacy_message.go
@@ -59,9 +59,6 @@
 //
 // This is exported for testing purposes.
 func LegacyLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
-	return legacyLoadMessageDesc(t, true)
-}
-func legacyLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
 	// Fast-path: check if a MessageDescriptor is cached for this concrete type.
 	if mi, ok := legacyMessageDescCache.Load(t); ok {
 		return mi.(pref.MessageDescriptor)
@@ -74,7 +71,7 @@
 	}
 	mdV1, ok := mv.(messageV1)
 	if !ok {
-		return aberrantLoadMessageDesc(t, finalized)
+		return aberrantLoadMessageDesc(t)
 	}
 	b, idxs := mdV1.Descriptor()
 
@@ -88,16 +85,10 @@
 	return md
 }
 
-var aberrantMessageDescCache sync.Map // map[reflect.Type]aberrantMessageDesc
-
-// aberrantMessageDesc is a tuple containing a MessageDescriptor and a channel
-// to signal whether the descriptor is initialized. For external lookups,
-// we must ensure that the descriptor is fully initialized. For internal lookups
-// to resolve cycles, we only need to obtain the descriptor reference.
-type aberrantMessageDesc struct {
-	desc protoreflect.MessageDescriptor
-	done chan struct{} // closed when desc is fully initialized
-}
+var (
+	aberrantMessageDescLock  sync.Mutex
+	aberrantMessageDescCache map[reflect.Type]protoreflect.MessageDescriptor
+)
 
 // aberrantLoadEnumDesc returns an EnumDescriptor derived from the Go type,
 // which must not implement protoreflect.ProtoMessage or messageV1.
@@ -107,31 +98,27 @@
 //
 // The finalized flag determines whether the returned message descriptor must
 // be fully initialized.
-func aberrantLoadMessageDesc(t reflect.Type, finalized bool) pref.MessageDescriptor {
+func aberrantLoadMessageDesc(t reflect.Type) pref.MessageDescriptor {
+	aberrantMessageDescLock.Lock()
+	defer aberrantMessageDescLock.Unlock()
+	if aberrantMessageDescCache == nil {
+		aberrantMessageDescCache = make(map[reflect.Type]protoreflect.MessageDescriptor)
+	}
+	return aberrantLoadMessageDescReentrant(t)
+}
+func aberrantLoadMessageDescReentrant(t reflect.Type) pref.MessageDescriptor {
 	// Fast-path: check if an MessageDescriptor is cached for this concrete type.
-	if mdi, ok := aberrantMessageDescCache.Load(t); ok {
-		if finalized {
-			<-mdi.(aberrantMessageDesc).done
-		}
-		return mdi.(aberrantMessageDesc).desc
+	if md, ok := aberrantMessageDescCache[t]; ok {
+		return md
 	}
 
-	// Medium-path: create an initial descriptor and cache it immediately,
-	// so that cyclic references can be resolved. Each descriptor is paired
-	// with a channel to signal when the descriptor is fully initialized.
-	md := &filedesc.Message{L2: new(filedesc.MessageL2)}
-	mdi := aberrantMessageDesc{desc: md, done: make(chan struct{})}
-	if mdi, ok := aberrantMessageDescCache.LoadOrStore(t, mdi); ok {
-		if finalized {
-			<-mdi.(aberrantMessageDesc).done
-		}
-		return mdi.(aberrantMessageDesc).desc
-	}
-	defer func() { close(mdi.done) }()
-
 	// Slow-path: construct a descriptor from the Go struct type (best-effort).
+	// Cache the MessageDescriptor early on so that we can resolve internal
+	// cyclic references.
+	md := &filedesc.Message{L2: new(filedesc.MessageL2)}
 	md.L0.FullName = aberrantDeriveFullName(t.Elem())
 	md.L0.ParentFile = filedesc.SurrogateProto2
+	aberrantMessageDescCache[t] = md
 
 	// Try to determine if the message is using proto3 by checking scalars.
 	for i := 0; i < t.Elem().NumField(); i++ {
@@ -257,6 +244,8 @@
 		switch v := reflect.Zero(t).Interface().(type) {
 		case pref.ProtoMessage:
 			fd.L1.Message = v.ProtoReflect().Descriptor()
+		case messageV1:
+			fd.L1.Message = LegacyLoadMessageDesc(t)
 		default:
 			if t.Kind() == reflect.Map {
 				n := len(md.L1.Messages.List)
@@ -280,7 +269,7 @@
 				fd.L1.Message = md2
 				break
 			}
-			fd.L1.Message = aberrantLoadMessageDesc(t, false)
+			fd.L1.Message = aberrantLoadMessageDescReentrant(t)
 		}
 	}
 }