internal/impl: pass *coderFieldInfo into fast-path functions

Refactor the fast-path size, marshal, unmarshal, and isInit functions to
take the field's *coderFieldInfo as an explicit parameter.

This replaces a number of closures capturing field-specific information
with functions taking that information as an explicit parameter.

Change-Id: I8cb39701265edb7b673f6f04a0152d5f4dbb4d5d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218937
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index c8c0925..b319583 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -14,18 +14,17 @@
 )
 
 type mapInfo struct {
-	goType         reflect.Type
-	keyWiretag     uint64
-	valWiretag     uint64
-	keyFuncs       valueCoderFuncs
-	valFuncs       valueCoderFuncs
-	keyZero        pref.Value
-	keyKind        pref.Kind
-	valMessageInfo *MessageInfo
-	conv           *mapConverter
+	goType     reflect.Type
+	keyWiretag uint64
+	valWiretag uint64
+	keyFuncs   valueCoderFuncs
+	valFuncs   valueCoderFuncs
+	keyZero    pref.Value
+	keyKind    pref.Kind
+	conv       *mapConverter
 }
 
-func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
+func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (valueMessage *MessageInfo, funcs pointerCoderFuncs) {
 	// TODO: Consider generating specialized map coders.
 	keyField := fd.MapKey()
 	valField := fd.MapValue()
@@ -46,34 +45,34 @@
 		conv:       conv,
 	}
 	if valField.Kind() == pref.MessageKind {
-		mapi.valMessageInfo = getMessageInfo(ft.Elem())
+		valueMessage = getMessageInfo(ft.Elem())
 	}
 
 	funcs = pointerCoderFuncs{
-		size: func(p pointer, tagsize int, opts marshalOptions) int {
-			return sizeMap(p.AsValueOf(ft).Elem(), tagsize, mapi, opts)
+		size: func(p pointer, f *coderFieldInfo, opts marshalOptions) int {
+			return sizeMap(p.AsValueOf(ft).Elem(), mapi, f, opts)
 		},
-		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			return appendMap(b, p.AsValueOf(ft).Elem(), wiretag, mapi, opts)
+		marshal: func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+			return appendMap(b, p.AsValueOf(ft).Elem(), mapi, f, opts)
 		},
-		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error) {
+		unmarshal: func(b []byte, p pointer, wtyp wire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error) {
 			mp := p.AsValueOf(ft)
 			if mp.Elem().IsNil() {
 				mp.Elem().Set(reflect.MakeMap(mapi.goType))
 			}
-			if mapi.valMessageInfo == nil {
-				return consumeMap(b, mp.Elem(), wtyp, mapi, opts)
+			if f.mi == nil {
+				return consumeMap(b, mp.Elem(), wtyp, mapi, f, opts)
 			} else {
-				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, opts)
+				return consumeMapOfMessage(b, mp.Elem(), wtyp, mapi, f, opts)
 			}
 		},
 	}
 	if valFuncs.isInit != nil {
-		funcs.isInit = func(p pointer) error {
-			return isInitMap(p.AsValueOf(ft).Elem(), mapi)
+		funcs.isInit = func(p pointer, f *coderFieldInfo) error {
+			return isInitMap(p.AsValueOf(ft).Elem(), mapi, f)
 		}
 	}
-	return funcs
+	return valueMessage, funcs
 }
 
 const (
@@ -81,7 +80,7 @@
 	mapValTagSize = 1 // field 2, tag size 2.
 )
 
-func sizeMap(mapv reflect.Value, tagsize int, mapi *mapInfo, opts marshalOptions) int {
+func sizeMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) int {
 	if mapv.Len() == 0 {
 		return 0
 	}
@@ -92,19 +91,19 @@
 		keySize := mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 		var valSize int
 		value := mapi.conv.valConv.PBValueOf(iter.Value())
-		if mapi.valMessageInfo == nil {
+		if f.mi == nil {
 			valSize = mapi.valFuncs.size(value, mapValTagSize, opts)
 		} else {
 			p := pointerOfValue(iter.Value())
 			valSize += mapValTagSize
-			valSize += wire.SizeBytes(mapi.valMessageInfo.sizePointer(p, opts))
+			valSize += wire.SizeBytes(f.mi.sizePointer(p, opts))
 		}
-		n += tagsize + wire.SizeBytes(keySize+valSize)
+		n += f.tagsize + wire.SizeBytes(keySize+valSize)
 	}
 	return n
 }
 
-func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+func consumeMap(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	if wtyp != wire.BytesType {
 		return out, errUnknown
 	}
@@ -161,7 +160,7 @@
 	return out, nil
 }
 
-func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
+func consumeMapOfMessage(b []byte, mapv reflect.Value, wtyp wire.Type, mapi *mapInfo, f *coderFieldInfo, opts unmarshalOptions) (out unmarshalOutput, err error) {
 	if wtyp != wire.BytesType {
 		return out, errUnknown
 	}
@@ -171,7 +170,7 @@
 	}
 	var (
 		key = mapi.keyZero
-		val = reflect.New(mapi.valMessageInfo.GoReflectType.Elem())
+		val = reflect.New(f.mi.GoReflectType.Elem())
 	)
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
@@ -203,7 +202,7 @@
 				return out, wire.ParseError(n)
 			}
 			var o unmarshalOutput
-			o, err = mapi.valMessageInfo.unmarshalPointer(v, pointerOfValue(val), 0, opts)
+			o, err = f.mi.unmarshalPointer(v, pointerOfValue(val), 0, opts)
 			if o.initialized {
 				// Consider this map item initialized so long as we see
 				// an initialized value.
@@ -225,8 +224,8 @@
 	return out, nil
 }
 
-func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
-	if mapi.valMessageInfo == nil {
+func appendMapItem(b []byte, keyrv, valrv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
+	if f.mi == nil {
 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 		val := mapi.conv.valConv.PBValueOf(valrv)
 		size := 0
@@ -241,7 +240,7 @@
 	} else {
 		key := mapi.conv.keyConv.PBValueOf(keyrv).MapKey()
 		val := pointerOfValue(valrv)
-		valSize := mapi.valMessageInfo.sizePointer(val, opts)
+		valSize := f.mi.sizePointer(val, opts)
 		size := 0
 		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
 		size += mapValTagSize + wire.SizeBytes(valSize)
@@ -252,22 +251,22 @@
 		}
 		b = wire.AppendVarint(b, mapi.valWiretag)
 		b = wire.AppendVarint(b, uint64(valSize))
-		return mapi.valMessageInfo.marshalAppendPointer(b, val, opts)
+		return f.mi.marshalAppendPointer(b, val, opts)
 	}
 }
 
-func appendMap(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+func appendMap(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 	if mapv.Len() == 0 {
 		return b, nil
 	}
 	if opts.Deterministic() {
-		return appendMapDeterministic(b, mapv, wiretag, mapi, opts)
+		return appendMapDeterministic(b, mapv, mapi, f, opts)
 	}
 	iter := mapRange(mapv)
 	for iter.Next() {
 		var err error
-		b = wire.AppendVarint(b, wiretag)
-		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, opts)
+		b = wire.AppendVarint(b, f.wiretag)
+		b, err = appendMapItem(b, iter.Key(), iter.Value(), mapi, f, opts)
 		if err != nil {
 			return b, err
 		}
@@ -275,7 +274,7 @@
 	return b, nil
 }
 
-func appendMapDeterministic(b []byte, mapv reflect.Value, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+func appendMapDeterministic(b []byte, mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo, opts marshalOptions) ([]byte, error) {
 	keys := mapv.MapKeys()
 	sort.Slice(keys, func(i, j int) bool {
 		switch keys[i].Kind() {
@@ -295,8 +294,8 @@
 	})
 	for _, key := range keys {
 		var err error
-		b = wire.AppendVarint(b, wiretag)
-		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, opts)
+		b = wire.AppendVarint(b, f.wiretag)
+		b, err = appendMapItem(b, key, mapv.MapIndex(key), mapi, f, opts)
 		if err != nil {
 			return b, err
 		}
@@ -304,8 +303,8 @@
 	return b, nil
 }
 
-func isInitMap(mapv reflect.Value, mapi *mapInfo) error {
-	if mi := mapi.valMessageInfo; mi != nil {
+func isInitMap(mapv reflect.Value, mapi *mapInfo, f *coderFieldInfo) error {
+	if mi := f.mi; mi != nil {
 		mi.init()
 		if !mi.needsInitCheck {
 			return nil