internal/impl: pass *coderFieldInfo into fast-path functions

Refactor the fast-path size, marshal, unmarshal, and isInit functions to
take the *coderFieldInfo for the field as input.

This replaces a number of closures capturing field-specific information
with functions taking that information as an explicit parameter.

Change-Id: I8cb39701265edb7b673f6f04a0152d5f4dbb4d5d
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/218937
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_tables.go b/internal/impl/codec_tables.go
index 38c4e7e..2c62bda 100644
--- a/internal/impl/codec_tables.go
+++ b/internal/impl/codec_tables.go
@@ -15,10 +15,11 @@
 
 // pointerCoderFuncs is a set of pointer encoding functions.
 type pointerCoderFuncs struct {
-	size      func(p pointer, tagsize int, opts marshalOptions) int
-	marshal   func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error)
-	unmarshal func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (unmarshalOutput, error)
-	isInit    func(p pointer) error
+	mi        *MessageInfo
+	size      func(p pointer, f *coderFieldInfo, opts marshalOptions) int
+	marshal   func(b []byte, p pointer, f *coderFieldInfo, opts marshalOptions) ([]byte, error)
+	unmarshal func(b []byte, p pointer, wtyp wire.Type, f *coderFieldInfo, opts unmarshalOptions) (unmarshalOutput, error)
+	isInit    func(p pointer, f *coderFieldInfo) error
 }
 
 // valueCoderFuncs is a set of protoreflect.Value encoding functions.
@@ -31,7 +32,7 @@
 
 // fieldCoder returns pointer functions for a field, used for operating on
 // struct fields.
-func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) pointerCoderFuncs {
+func fieldCoder(fd pref.FieldDescriptor, ft reflect.Type) (*MessageInfo, pointerCoderFuncs) {
 	switch {
 	case fd.IsMap():
 		return encoderFuncsForMap(fd, ft)
@@ -44,84 +45,84 @@
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
-				return coderBoolSlice
+				return nil, coderBoolSlice
 			}
 		case pref.EnumKind:
 			if ft.Kind() == reflect.Int32 {
-				return coderEnumSlice
+				return nil, coderEnumSlice
 			}
 		case pref.Int32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderInt32Slice
+				return nil, coderInt32Slice
 			}
 		case pref.Sint32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSint32Slice
+				return nil, coderSint32Slice
 			}
 		case pref.Uint32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderUint32Slice
+				return nil, coderUint32Slice
 			}
 		case pref.Int64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderInt64Slice
+				return nil, coderInt64Slice
 			}
 		case pref.Sint64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSint64Slice
+				return nil, coderSint64Slice
 			}
 		case pref.Uint64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderUint64Slice
+				return nil, coderUint64Slice
 			}
 		case pref.Sfixed32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSfixed32Slice
+				return nil, coderSfixed32Slice
 			}
 		case pref.Fixed32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderFixed32Slice
+				return nil, coderFixed32Slice
 			}
 		case pref.FloatKind:
 			if ft.Kind() == reflect.Float32 {
-				return coderFloatSlice
+				return nil, coderFloatSlice
 			}
 		case pref.Sfixed64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSfixed64Slice
+				return nil, coderSfixed64Slice
 			}
 		case pref.Fixed64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderFixed64Slice
+				return nil, coderFixed64Slice
 			}
 		case pref.DoubleKind:
 			if ft.Kind() == reflect.Float64 {
-				return coderDoubleSlice
+				return nil, coderDoubleSlice
 			}
 		case pref.StringKind:
 			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
-				return coderStringSliceValidateUTF8
+				return nil, coderStringSliceValidateUTF8
 			}
 			if ft.Kind() == reflect.String {
-				return coderStringSlice
+				return nil, coderStringSlice
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
-				return coderBytesSliceValidateUTF8
+				return nil, coderBytesSliceValidateUTF8
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytesSlice
+				return nil, coderBytesSlice
 			}
 		case pref.BytesKind:
 			if ft.Kind() == reflect.String {
-				return coderStringSlice
+				return nil, coderStringSlice
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytesSlice
+				return nil, coderBytesSlice
 			}
 		case pref.MessageKind:
-			return makeMessageSliceFieldCoder(fd, ft)
+			return getMessageInfo(ft), makeMessageSliceFieldCoder(fd, ft)
 		case pref.GroupKind:
-			return makeGroupSliceFieldCoder(fd, ft)
+			return getMessageInfo(ft), makeGroupSliceFieldCoder(fd, ft)
 		}
 	case fd.Cardinality() == pref.Repeated && fd.IsPacked():
 		// Packed repeated fields.
@@ -135,144 +136,144 @@
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
-				return coderBoolPackedSlice
+				return nil, coderBoolPackedSlice
 			}
 		case pref.EnumKind:
 			if ft.Kind() == reflect.Int32 {
-				return coderEnumPackedSlice
+				return nil, coderEnumPackedSlice
 			}
 		case pref.Int32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderInt32PackedSlice
+				return nil, coderInt32PackedSlice
 			}
 		case pref.Sint32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSint32PackedSlice
+				return nil, coderSint32PackedSlice
 			}
 		case pref.Uint32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderUint32PackedSlice
+				return nil, coderUint32PackedSlice
 			}
 		case pref.Int64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderInt64PackedSlice
+				return nil, coderInt64PackedSlice
 			}
 		case pref.Sint64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSint64PackedSlice
+				return nil, coderSint64PackedSlice
 			}
 		case pref.Uint64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderUint64PackedSlice
+				return nil, coderUint64PackedSlice
 			}
 		case pref.Sfixed32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSfixed32PackedSlice
+				return nil, coderSfixed32PackedSlice
 			}
 		case pref.Fixed32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderFixed32PackedSlice
+				return nil, coderFixed32PackedSlice
 			}
 		case pref.FloatKind:
 			if ft.Kind() == reflect.Float32 {
-				return coderFloatPackedSlice
+				return nil, coderFloatPackedSlice
 			}
 		case pref.Sfixed64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSfixed64PackedSlice
+				return nil, coderSfixed64PackedSlice
 			}
 		case pref.Fixed64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderFixed64PackedSlice
+				return nil, coderFixed64PackedSlice
 			}
 		case pref.DoubleKind:
 			if ft.Kind() == reflect.Float64 {
-				return coderDoublePackedSlice
+				return nil, coderDoublePackedSlice
 			}
 		}
 	case fd.Kind() == pref.MessageKind:
-		return makeMessageFieldCoder(fd, ft)
+		return getMessageInfo(ft), makeMessageFieldCoder(fd, ft)
 	case fd.Kind() == pref.GroupKind:
-		return makeGroupFieldCoder(fd, ft)
+		return getMessageInfo(ft), makeGroupFieldCoder(fd, ft)
 	case fd.Syntax() == pref.Proto3 && fd.ContainingOneof() == nil:
 		// Populated oneof fields always encode even if set to the zero value,
 		// which normally are not encoded in proto3.
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
-				return coderBoolNoZero
+				return nil, coderBoolNoZero
 			}
 		case pref.EnumKind:
 			if ft.Kind() == reflect.Int32 {
-				return coderEnumNoZero
+				return nil, coderEnumNoZero
 			}
 		case pref.Int32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderInt32NoZero
+				return nil, coderInt32NoZero
 			}
 		case pref.Sint32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSint32NoZero
+				return nil, coderSint32NoZero
 			}
 		case pref.Uint32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderUint32NoZero
+				return nil, coderUint32NoZero
 			}
 		case pref.Int64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderInt64NoZero
+				return nil, coderInt64NoZero
 			}
 		case pref.Sint64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSint64NoZero
+				return nil, coderSint64NoZero
 			}
 		case pref.Uint64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderUint64NoZero
+				return nil, coderUint64NoZero
 			}
 		case pref.Sfixed32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSfixed32NoZero
+				return nil, coderSfixed32NoZero
 			}
 		case pref.Fixed32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderFixed32NoZero
+				return nil, coderFixed32NoZero
 			}
 		case pref.FloatKind:
 			if ft.Kind() == reflect.Float32 {
-				return coderFloatNoZero
+				return nil, coderFloatNoZero
 			}
 		case pref.Sfixed64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSfixed64NoZero
+				return nil, coderSfixed64NoZero
 			}
 		case pref.Fixed64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderFixed64NoZero
+				return nil, coderFixed64NoZero
 			}
 		case pref.DoubleKind:
 			if ft.Kind() == reflect.Float64 {
-				return coderDoubleNoZero
+				return nil, coderDoubleNoZero
 			}
 		case pref.StringKind:
 			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
-				return coderStringNoZeroValidateUTF8
+				return nil, coderStringNoZeroValidateUTF8
 			}
 			if ft.Kind() == reflect.String {
-				return coderStringNoZero
+				return nil, coderStringNoZero
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
-				return coderBytesNoZeroValidateUTF8
+				return nil, coderBytesNoZeroValidateUTF8
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytesNoZero
+				return nil, coderBytesNoZero
 			}
 		case pref.BytesKind:
 			if ft.Kind() == reflect.String {
-				return coderStringNoZero
+				return nil, coderStringNoZero
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytesNoZero
+				return nil, coderBytesNoZero
 			}
 		}
 	case ft.Kind() == reflect.Ptr:
@@ -280,146 +281,146 @@
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
-				return coderBoolPtr
+				return nil, coderBoolPtr
 			}
 		case pref.EnumKind:
 			if ft.Kind() == reflect.Int32 {
-				return coderEnumPtr
+				return nil, coderEnumPtr
 			}
 		case pref.Int32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderInt32Ptr
+				return nil, coderInt32Ptr
 			}
 		case pref.Sint32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSint32Ptr
+				return nil, coderSint32Ptr
 			}
 		case pref.Uint32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderUint32Ptr
+				return nil, coderUint32Ptr
 			}
 		case pref.Int64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderInt64Ptr
+				return nil, coderInt64Ptr
 			}
 		case pref.Sint64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSint64Ptr
+				return nil, coderSint64Ptr
 			}
 		case pref.Uint64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderUint64Ptr
+				return nil, coderUint64Ptr
 			}
 		case pref.Sfixed32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSfixed32Ptr
+				return nil, coderSfixed32Ptr
 			}
 		case pref.Fixed32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderFixed32Ptr
+				return nil, coderFixed32Ptr
 			}
 		case pref.FloatKind:
 			if ft.Kind() == reflect.Float32 {
-				return coderFloatPtr
+				return nil, coderFloatPtr
 			}
 		case pref.Sfixed64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSfixed64Ptr
+				return nil, coderSfixed64Ptr
 			}
 		case pref.Fixed64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderFixed64Ptr
+				return nil, coderFixed64Ptr
 			}
 		case pref.DoubleKind:
 			if ft.Kind() == reflect.Float64 {
-				return coderDoublePtr
+				return nil, coderDoublePtr
 			}
 		case pref.StringKind:
 			if ft.Kind() == reflect.String {
-				return coderStringPtr
+				return nil, coderStringPtr
 			}
 		case pref.BytesKind:
 			if ft.Kind() == reflect.String {
-				return coderStringPtr
+				return nil, coderStringPtr
 			}
 		}
 	default:
 		switch fd.Kind() {
 		case pref.BoolKind:
 			if ft.Kind() == reflect.Bool {
-				return coderBool
+				return nil, coderBool
 			}
 		case pref.EnumKind:
 			if ft.Kind() == reflect.Int32 {
-				return coderEnum
+				return nil, coderEnum
 			}
 		case pref.Int32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderInt32
+				return nil, coderInt32
 			}
 		case pref.Sint32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSint32
+				return nil, coderSint32
 			}
 		case pref.Uint32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderUint32
+				return nil, coderUint32
 			}
 		case pref.Int64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderInt64
+				return nil, coderInt64
 			}
 		case pref.Sint64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSint64
+				return nil, coderSint64
 			}
 		case pref.Uint64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderUint64
+				return nil, coderUint64
 			}
 		case pref.Sfixed32Kind:
 			if ft.Kind() == reflect.Int32 {
-				return coderSfixed32
+				return nil, coderSfixed32
 			}
 		case pref.Fixed32Kind:
 			if ft.Kind() == reflect.Uint32 {
-				return coderFixed32
+				return nil, coderFixed32
 			}
 		case pref.FloatKind:
 			if ft.Kind() == reflect.Float32 {
-				return coderFloat
+				return nil, coderFloat
 			}
 		case pref.Sfixed64Kind:
 			if ft.Kind() == reflect.Int64 {
-				return coderSfixed64
+				return nil, coderSfixed64
 			}
 		case pref.Fixed64Kind:
 			if ft.Kind() == reflect.Uint64 {
-				return coderFixed64
+				return nil, coderFixed64
 			}
 		case pref.DoubleKind:
 			if ft.Kind() == reflect.Float64 {
-				return coderDouble
+				return nil, coderDouble
 			}
 		case pref.StringKind:
 			if ft.Kind() == reflect.String && strs.EnforceUTF8(fd) {
-				return coderStringValidateUTF8
+				return nil, coderStringValidateUTF8
 			}
 			if ft.Kind() == reflect.String {
-				return coderString
+				return nil, coderString
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 && strs.EnforceUTF8(fd) {
-				return coderBytesValidateUTF8
+				return nil, coderBytesValidateUTF8
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytes
+				return nil, coderBytes
 			}
 		case pref.BytesKind:
 			if ft.Kind() == reflect.String {
-				return coderString
+				return nil, coderString
 			}
 			if ft.Kind() == reflect.Slice && ft.Elem().Kind() == reflect.Uint8 {
-				return coderBytes
+				return nil, coderBytes
 			}
 		}
 	}