internal/impl: improve extension fast-path performance

Stash fast-path information for extensions on the ExtensionInfo. In
the usual case where an ExtensionType's underlying implementation is
an *ExtensionInfo, fetching the fast-path information becomes a type
assertion rather than a mutex-guarded map access.

Fall back to a global sync.Map cache when an ExtensionType is not an
*ExtensionInfo.

Substantially improves performance for fast-path operations on
extensions:

Encode/MessageSet_type_id_before_message_content-12      267ns ± 1%   185ns ± 1%  -30.44%  (p=0.001 n=7+7)
Encode/basic_scalar_types_(*test.TestAllExtensions)-12  1.94µs ± 1%  0.40µs ± 1%  -79.32%  (p=0.000 n=8+7)

Change-Id: If048b521deb3665a090ea3d0a178c61691d4201e
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/210540
Reviewed-by: Joe Tsai <joetsai@google.com>
diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go
index f2dca63..22a25dc 100644
--- a/internal/impl/codec_extension.go
+++ b/internal/impl/codec_extension.go
@@ -19,24 +19,50 @@
 	funcs               valueCoderFuncs
 }
 
-func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
-	// As of this time (Go 1.12, linux/amd64), an RWMutex benchmarks as faster
-	// than a sync.Map.
-	mi.extensionFieldInfosMu.RLock()
-	e, ok := mi.extensionFieldInfos[xt]
-	mi.extensionFieldInfosMu.RUnlock()
-	if ok {
-		return e
-	}
+// legacyExtensionFieldInfoCache holds fast-path information for extension
+// types that are not *ExtensionInfo values. Entries are added with
+// LoadOrStore so that concurrent callers converge on one value per type.
+var legacyExtensionFieldInfoCache sync.Map // map[protoreflect.ExtensionType]*extensionFieldInfo
 
-	xd := xt.TypeDescriptor()
+// getExtensionFieldInfo returns the fast-path coding information for xt.
+//
+// When xt is an *ExtensionInfo, the information is stashed on xt itself and
+// is returned after lazy initialization, without any map lookup. Other
+// ExtensionType implementations fall back to legacyLoadExtensionFieldInfo.
+func getExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+	if xi, ok := xt.(*ExtensionInfo); ok {
+		xi.lazyInit()
+		return xi.info
+	}
+	return legacyLoadExtensionFieldInfo(xt)
+}
+
+// legacyLoadExtensionFieldInfo returns the *extensionFieldInfo for xt,
+// building it on first use and caching it in legacyExtensionFieldInfoCache.
+func legacyLoadExtensionFieldInfo(xt pref.ExtensionType) *extensionFieldInfo {
+	if v, ok := legacyExtensionFieldInfoCache.Load(xt); ok {
+		return v.(*extensionFieldInfo)
+	}
+	e := makeExtensionFieldInfo(xt.TypeDescriptor())
+	// Store into the same cache read above; if another goroutine stored a
+	// value first, return that one so every caller shares a single instance.
+	if v, loaded := legacyExtensionFieldInfoCache.LoadOrStore(xt, e); loaded {
+		return v.(*extensionFieldInfo)
+	}
+	return e
+}
+
+// makeExtensionFieldInfo builds the fast-path coding information for the
+// extension field described by xd. Packed fields are tagged with the
+// length-delimited wire.BytesType; all others use the wire type of their kind.
+func makeExtensionFieldInfo(xd pref.ExtensionDescriptor) *extensionFieldInfo {
 	var wiretag uint64
 	if !xd.IsPacked() {
 		wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()])
 	} else {
 		wiretag = wire.EncodeTag(xd.Number(), wire.BytesType)
 	}
-	e = &extensionFieldInfo{
+	e := &extensionFieldInfo{
 		wiretag: wiretag,
 		tagsize: wire.SizeVarint(wiretag),
 		funcs:   encoderFuncsForValue(xd),
@@ -52,12 +78,6 @@
 			e.unmarshalNeedsValue = true
 		}
 	}
-	mi.extensionFieldInfosMu.Lock()
-	if mi.extensionFieldInfos == nil {
-		mi.extensionFieldInfos = make(map[pref.ExtensionType]*extensionFieldInfo)
-	}
-	mi.extensionFieldInfos[xt] = e
-	mi.extensionFieldInfosMu.Unlock()
 	return e
 }