internal/impl: store extension values as Values

Change the storage type of ExtensionField from interface{} to
protoreflect.Value.

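Replace the codec functions operating on interface{}s with ones
operating on Values. The map codecs now operate on protoreflect.Map
and use internal/mapsort for deterministic key ordering, replacing the
hand-rolled reflect-based key sorter.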

Values are potentially more efficient, since they can represent
non-pointer types without allocation. This also reduces the number of
types used to represent field values.

Additionally, this change lays groundwork for changing the
user-visible representation of repeated extension fields from
*[]T to []T. The storage type for extension fields must support mutation
(which is why *[]T is used today); storing a Value instead permits this
without introducing yet another view on field values.

Change-Id: Ida336be14112bb940f655236eb58df21bf312525
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/192218
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
diff --git a/internal/impl/codec_map.go b/internal/impl/codec_map.go
index 4916704..00d6511 100644
--- a/internal/impl/codec_map.go
+++ b/internal/impl/codec_map.go
@@ -5,11 +5,10 @@
 package impl
 
 import (
-	"fmt"
 	"reflect"
-	"sort"
 
 	"google.golang.org/protobuf/internal/encoding/wire"
+	"google.golang.org/protobuf/internal/mapsort"
 	pref "google.golang.org/protobuf/reflect/protoreflect"
 )
 
@@ -17,11 +16,10 @@
 	goType     reflect.Type
 	keyWiretag uint64
 	valWiretag uint64
-	keyFuncs   ifaceCoderFuncs
-	valFuncs   ifaceCoderFuncs
-	keyZero    interface{}
-	valZero    interface{}
-	newVal     func() interface{}
+	keyFuncs   valueCoderFuncs
+	valFuncs   valueCoderFuncs
+	keyZero    pref.Value
+	keyKind    pref.Kind
 }
 
 func encoderFuncsForMap(fd pref.FieldDescriptor, ft reflect.Type) (funcs pointerCoderFuncs) {
@@ -32,6 +30,7 @@
 	valWiretag := wire.EncodeTag(2, wireTypes[valField.Kind()])
 	keyFuncs := encoderFuncsForValue(keyField, ft.Key())
 	valFuncs := encoderFuncsForValue(valField, ft.Elem())
+	conv := NewConverter(ft, fd)
 
 	mapi := &mapInfo{
 		goType:     ft,
@@ -39,30 +38,32 @@
 		valWiretag: valWiretag,
 		keyFuncs:   keyFuncs,
 		valFuncs:   valFuncs,
-		keyZero:    reflect.Zero(ft.Key()).Interface(),
-		valZero:    reflect.Zero(ft.Elem()).Interface(),
-	}
-	switch valField.Kind() {
-	case pref.GroupKind, pref.MessageKind:
-		mapi.newVal = func() interface{} {
-			return reflect.New(ft.Elem().Elem()).Interface()
-		}
+		keyZero:    keyField.Default(),
+		keyKind:    keyField.Kind(),
 	}
 
 	funcs = pointerCoderFuncs{
 		size: func(p pointer, tagsize int, opts marshalOptions) int {
-			return sizeMap(p, tagsize, ft, keyFuncs, valFuncs, opts)
+			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+			return sizeMap(mapv, tagsize, mapi, opts)
 		},
 		marshal: func(b []byte, p pointer, wiretag uint64, opts marshalOptions) ([]byte, error) {
-			return appendMap(b, p, wiretag, keyWiretag, valWiretag, ft, keyFuncs, valFuncs, opts)
+			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+			return appendMap(b, mapv, wiretag, mapi, opts)
 		},
 		unmarshal: func(b []byte, p pointer, wtyp wire.Type, opts unmarshalOptions) (int, error) {
-			return consumeMap(b, p, wtyp, mapi, opts)
+			mp := p.AsValueOf(ft)
+			if mp.Elem().IsNil() {
+				mp.Elem().Set(reflect.MakeMap(mapi.goType))
+			}
+			mapv := conv.PBValueOf(mp.Elem()).Map()
+			return consumeMap(b, mapv, wtyp, mapi, opts)
 		},
 	}
 	if valFuncs.isInit != nil {
 		funcs.isInit = func(p pointer) error {
-			return isInitMap(p, ft, valFuncs.isInit)
+			mapv := conv.PBValueOf(p.AsValueOf(ft).Elem()).Map()
+			return isInitMap(mapv, mapi)
 		}
 	}
 	return funcs
@@ -73,13 +74,21 @@
 	mapValTagSize = 1 // field 2, tag size 2.
 )
 
-func consumeMap(b []byte, p pointer, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
-	mp := p.AsValueOf(mapi.goType)
-	if mp.Elem().IsNil() {
-		mp.Elem().Set(reflect.MakeMap(mapi.goType))
+func sizeMap(mapv pref.Map, tagsize int, mapi *mapInfo, opts marshalOptions) int {
+	if mapv.Len() == 0 {
+		return 0
 	}
-	m := mp.Elem()
+	n := 0
+	mapv.Range(func(key pref.MapKey, value pref.Value) bool {
+		n += tagsize + wire.SizeBytes(
+			mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)+
+				mapi.valFuncs.size(value, mapValTagSize, opts))
+		return true
+	})
+	return n
+}
 
+func consumeMap(b []byte, mapv pref.Map, wtyp wire.Type, mapi *mapInfo, opts unmarshalOptions) (int, error) {
 	if wtyp != wire.BytesType {
 		return 0, errUnknown
 	}
@@ -89,11 +98,8 @@
 	}
 	var (
 		key = mapi.keyZero
-		val = mapi.valZero
+		val = mapv.NewValue()
 	)
-	if mapi.newVal != nil {
-		val = mapi.newVal()
-	}
 	for len(b) > 0 {
 		num, wtyp, n := wire.ConsumeTag(b)
 		if n < 0 {
@@ -103,14 +109,14 @@
 		err := errUnknown
 		switch num {
 		case 1:
-			var v interface{}
+			var v pref.Value
 			v, n, err = mapi.keyFuncs.unmarshal(b, key, num, wtyp, opts)
 			if err != nil {
 				break
 			}
 			key = v
 		case 2:
-			var v interface{}
+			var v pref.Value
 			v, n, err = mapi.valFuncs.unmarshal(b, val, num, wtyp, opts)
 			if err != nil {
 				break
@@ -127,119 +133,44 @@
 		}
 		b = b[n:]
 	}
-	m.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(val))
+	mapv.Set(key.MapKey(), val)
 	return n, nil
 }
 
-func sizeMap(p pointer, tagsize int, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) int {
-	m := p.AsValueOf(goType).Elem()
-	n := 0
-	if m.Len() == 0 {
-		return 0
+func appendMap(b []byte, mapv pref.Map, wiretag uint64, mapi *mapInfo, opts marshalOptions) ([]byte, error) {
+	if mapv.Len() == 0 {
+		return b, nil
 	}
-	iter := mapRange(m)
-	for iter.Next() {
-		ki := iter.Key().Interface()
-		vi := iter.Value().Interface()
-		size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
-		n += wire.SizeBytes(size) + tagsize
-	}
-	return n
-}
-
-func appendMap(b []byte, p pointer, wiretag, keyWiretag, valWiretag uint64, goType reflect.Type, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
-	m := p.AsValueOf(goType).Elem()
 	var err error
-
-	if m.Len() == 0 {
-		return b, nil
-	}
-
-	if opts.Deterministic() {
-		keys := m.MapKeys()
-		sort.Sort(mapKeys(keys))
-		for _, k := range keys {
-			b, err = appendMapElement(b, k, m.MapIndex(k), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
-			if err != nil {
-				return b, err
-			}
-		}
-		return b, nil
-	}
-
-	iter := mapRange(m)
-	for iter.Next() {
-		b, err = appendMapElement(b, iter.Key(), iter.Value(), wiretag, keyWiretag, valWiretag, keyFuncs, valFuncs, opts)
+	fn := func(key pref.MapKey, value pref.Value) bool {
+		b = wire.AppendVarint(b, wiretag)
+		size := 0
+		size += mapi.keyFuncs.size(key.Value(), mapKeyTagSize, opts)
+		size += mapi.valFuncs.size(value, mapValTagSize, opts)
+		b = wire.AppendVarint(b, uint64(size))
+		b, err = mapi.keyFuncs.marshal(b, key.Value(), mapi.keyWiretag, opts)
 		if err != nil {
-			return b, err
+			return false
 		}
-	}
-	return b, nil
-}
-
-func appendMapElement(b []byte, key, value reflect.Value, wiretag, keyWiretag, valWiretag uint64, keyFuncs, valFuncs ifaceCoderFuncs, opts marshalOptions) ([]byte, error) {
-	ki := key.Interface()
-	vi := value.Interface()
-	b = wire.AppendVarint(b, wiretag)
-	size := keyFuncs.size(ki, mapKeyTagSize, opts) + valFuncs.size(vi, mapValTagSize, opts)
-	b = wire.AppendVarint(b, uint64(size))
-	b, err := keyFuncs.marshal(b, ki, keyWiretag, opts)
-	if err != nil {
-		return b, err
-	}
-	b, err = valFuncs.marshal(b, vi, valWiretag, opts)
-	if err != nil {
-		return b, err
-	}
-	return b, nil
-}
-
-func isInitMap(p pointer, goType reflect.Type, isInit func(interface{}) error) error {
-	m := p.AsValueOf(goType).Elem()
-	if m.Len() == 0 {
-		return nil
-	}
-	iter := mapRange(m)
-	for iter.Next() {
-		if err := isInit(iter.Value().Interface()); err != nil {
-			return err
+		b, err = mapi.valFuncs.marshal(b, value, mapi.valWiretag, opts)
+		if err != nil {
+			return false
 		}
+		return true
 	}
-	return nil
+	if opts.Deterministic() {
+		mapsort.Range(mapv, mapi.keyKind, fn)
+	} else {
+		mapv.Range(fn)
+	}
+	return b, err
 }
 
-// mapKeys returns a sort.Interface to be used for sorting the map keys.
-// Map fields may have key types of non-float scalars, strings and enums.
-func mapKeys(vs []reflect.Value) sort.Interface {
-	s := mapKeySorter{vs: vs}
-
-	// Type specialization per https://developers.google.com/protocol-buffers/docs/proto#maps.
-	if len(vs) == 0 {
-		return s
-	}
-	switch vs[0].Kind() {
-	case reflect.Int32, reflect.Int64:
-		s.less = func(a, b reflect.Value) bool { return a.Int() < b.Int() }
-	case reflect.Uint32, reflect.Uint64:
-		s.less = func(a, b reflect.Value) bool { return a.Uint() < b.Uint() }
-	case reflect.Bool:
-		s.less = func(a, b reflect.Value) bool { return !a.Bool() && b.Bool() } // false < true
-	case reflect.String:
-		s.less = func(a, b reflect.Value) bool { return a.String() < b.String() }
-	default:
-		panic(fmt.Sprintf("unsupported map key type: %v", vs[0].Kind()))
-	}
-
-	return s
-}
-
-type mapKeySorter struct {
-	vs   []reflect.Value
-	less func(a, b reflect.Value) bool
-}
-
-func (s mapKeySorter) Len() int      { return len(s.vs) }
-func (s mapKeySorter) Swap(i, j int) { s.vs[i], s.vs[j] = s.vs[j], s.vs[i] }
-func (s mapKeySorter) Less(i, j int) bool {
-	return s.less(s.vs[i], s.vs[j])
+func isInitMap(mapv pref.Map, mapi *mapInfo) error {
+	var err error
+	mapv.Range(func(_ pref.MapKey, value pref.Value) bool {
+		err = mapi.valFuncs.isInit(value)
+		return err == nil
+	})
+	return err
 }