internal/impl: implement Map fields

Add functions that wrap a map[K]V to implement protoreflect.Map.
Rather than a separate implementation for each map type, this uses Go
reflection to provide a single implementation that handles all Go map types.

Change-Id: Idcb8069ef836614a88e5df12ef7c5044e8aa3dea
Reviewed-on: https://go-review.googlesource.com/c/142778
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index c24fb04..a789247 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -40,6 +40,9 @@
 
 	NamedStrings []MyString
 	NamedBytes   []MyBytes
+
+	MapStrings map[MyString]MyString
+	MapBytes   map[MyString]MyBytes
 )
 
 // List of test operations to perform on messages, vectors, or maps.
@@ -50,8 +53,8 @@
 	vectorOp  interface{} // equalVector | lenVector | getVector | setVector | appendVector | truncVector
 	vectorOps []vectorOp
 
-	mapOp  interface{} // TODO
-	mapOps []mapOp     // TODO
+	mapOp  interface{} // equalMap | lenMap | hasMap | getMap | setMap | clearMap | rangeMap
+	mapOps []mapOp
 )
 
 // Test operations performed on a message.
@@ -78,6 +81,18 @@
 	// TODO: Mutable, MutableAppend
 )
 
+// Test operations performed on a map, applied in order by testMaps.
+type (
+	equalMap pref.Map                   // want the whole map to equal this map
+	lenMap   int                        // want Map.Len to return this length
+	hasMap   map[interface{}]bool       // want Map.Has to report these results per key
+	getMap   map[interface{}]pref.Value // want Map.Get to return these values per key
+	setMap   map[interface{}]pref.Value // call Map.Set for each key/value pair
+	clearMap map[interface{}]bool       // call Map.Clear for each key mapped to true
+	rangeMap map[interface{}]pref.Value // want Map.Range to visit exactly these entries
+	// TODO: List, Mutable
+)
+
 func TestScalarProto2(t *testing.T) {
 	type ScalarProto2 struct {
 		Bool    *bool    `protobuf:"1"`
@@ -396,7 +411,212 @@
 		hasFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true},
 		equalMessage(want),
 		clearFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true},
-		equalMessage(mi.MessageOf(&RepeatedScalars{})),
+		equalMessage(empty),
+	})
+}
+
+// TestMapScalars tests map fields with scalar key and value kinds, including
+// fields backed by named Go map, key, and value types. Some fields pair a Go
+// []byte value type with a proto string kind, or a Go string value type with
+// a proto bytes kind (for example, fields 15 and 17).
+func TestMapScalars(t *testing.T) {
+	type MapScalars struct {
+		KeyBools   map[bool]string   `protobuf:"1"`
+		KeyInt32s  map[int32]string  `protobuf:"2"`
+		KeyInt64s  map[int64]string  `protobuf:"3"`
+		KeyUint32s map[uint32]string `protobuf:"4"`
+		KeyUint64s map[uint64]string `protobuf:"5"`
+		KeyStrings map[string]string `protobuf:"6"`
+
+		ValBools    map[string]bool    `protobuf:"7"`
+		ValInt32s   map[string]int32   `protobuf:"8"`
+		ValInt64s   map[string]int64   `protobuf:"9"`
+		ValUint32s  map[string]uint32  `protobuf:"10"`
+		ValUint64s  map[string]uint64  `protobuf:"11"`
+		ValFloat32s map[string]float32 `protobuf:"12"`
+		ValFloat64s map[string]float64 `protobuf:"13"`
+		ValStrings  map[string]string  `protobuf:"14"`
+		ValStringsA map[string][]byte  `protobuf:"15"`
+		ValBytes    map[string][]byte  `protobuf:"16"`
+		ValBytesA   map[string]string  `protobuf:"17"`
+
+		MyStrings1 map[MyString]MyString `protobuf:"18"`
+		MyStrings2 map[MyString]MyBytes  `protobuf:"19"`
+		MyBytes1   map[MyString]MyBytes  `protobuf:"20"`
+		MyBytes2   map[MyString]MyString `protobuf:"21"`
+
+		MyStrings3 MapStrings `protobuf:"22"`
+		MyStrings4 MapBytes   `protobuf:"23"`
+		MyBytes3   MapBytes   `protobuf:"24"`
+		MyBytes4   MapStrings `protobuf:"25"`
+	}
+
+	// mustMakeMapEntry returns repeated field n whose message type is a proto2
+	// map entry (IsMapEntry) with "key" and "value" fields of the given kinds.
+	mustMakeMapEntry := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
+		return ptype.Field{
+			Name:        pref.Name(fmt.Sprintf("f%d", n)),
+			Number:      n,
+			Cardinality: pref.Repeated,
+			Kind:        pref.MessageKind,
+			MessageType: mustMakeMessageDesc(ptype.StandaloneMessage{
+				Syntax:   pref.Proto2,
+				FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)),
+				Fields: []ptype.Field{
+					{Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind},
+					{Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind},
+				},
+				IsMapEntry: true,
+			}),
+		}
+	}
+	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
+		Syntax:   pref.Proto2,
+		FullName: "MapScalars",
+		Fields: []ptype.Field{
+			mustMakeMapEntry(1, pref.BoolKind, pref.StringKind),
+			mustMakeMapEntry(2, pref.Int32Kind, pref.StringKind),
+			mustMakeMapEntry(3, pref.Int64Kind, pref.StringKind),
+			mustMakeMapEntry(4, pref.Uint32Kind, pref.StringKind),
+			mustMakeMapEntry(5, pref.Uint64Kind, pref.StringKind),
+			mustMakeMapEntry(6, pref.StringKind, pref.StringKind),
+
+			mustMakeMapEntry(7, pref.StringKind, pref.BoolKind),
+			mustMakeMapEntry(8, pref.StringKind, pref.Int32Kind),
+			mustMakeMapEntry(9, pref.StringKind, pref.Int64Kind),
+			mustMakeMapEntry(10, pref.StringKind, pref.Uint32Kind),
+			mustMakeMapEntry(11, pref.StringKind, pref.Uint64Kind),
+			mustMakeMapEntry(12, pref.StringKind, pref.FloatKind),
+			mustMakeMapEntry(13, pref.StringKind, pref.DoubleKind),
+			mustMakeMapEntry(14, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(15, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(16, pref.StringKind, pref.BytesKind),
+			mustMakeMapEntry(17, pref.StringKind, pref.BytesKind),
+
+			mustMakeMapEntry(18, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(19, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(20, pref.StringKind, pref.BytesKind),
+			mustMakeMapEntry(21, pref.StringKind, pref.BytesKind),
+
+			mustMakeMapEntry(22, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(23, pref.StringKind, pref.StringKind),
+			mustMakeMapEntry(24, pref.StringKind, pref.BytesKind),
+			mustMakeMapEntry(25, pref.StringKind, pref.BytesKind),
+		},
+	})}
+
+	empty := mi.MessageOf(&MapScalars{})
+	emptyFS := empty.KnownFields()
+
+	want := mi.MessageOf(&MapScalars{
+		KeyBools:   map[bool]string{true: "true", false: "false"},
+		KeyInt32s:  map[int32]string{0: "zero", -1: "one", 2: "two"},
+		KeyInt64s:  map[int64]string{0: "zero", -10: "ten", 20: "twenty"},
+		KeyUint32s: map[uint32]string{0: "zero", 1: "one", 2: "two"},
+		KeyUint64s: map[uint64]string{0: "zero", 10: "ten", 20: "twenty"},
+		KeyStrings: map[string]string{"": "", "foo": "bar"},
+
+		ValBools:    map[string]bool{"true": true, "false": false},
+		ValInt32s:   map[string]int32{"one": 1, "two": 2, "three": 3},
+		ValInt64s:   map[string]int64{"ten": 10, "twenty": -20, "thirty": 30},
+		ValUint32s:  map[string]uint32{"0x00": 0x00, "0xff": 0xff, "0xdead": 0xdead},
+		ValUint64s:  map[string]uint64{"0x00": 0x00, "0xff": 0xff, "0xdead": 0xdead},
+		ValFloat32s: map[string]float32{"nan": float32(math.NaN()), "pi": float32(math.Pi)},
+		ValFloat64s: map[string]float64{"nan": math.NaN(), "pi": math.Pi},
+		ValStrings:  map[string]string{"s1": "s1", "s2": "s2"},
+		ValStringsA: map[string][]byte{"s1": []byte("s1"), "s2": []byte("s2")},
+		ValBytes:    map[string][]byte{"s1": []byte("s1"), "s2": []byte("s2")},
+		ValBytesA:   map[string]string{"s1": "s1", "s2": "s2"},
+
+		MyStrings1: map[MyString]MyString{"s1": "s1", "s2": "s2"},
+		MyStrings2: map[MyString]MyBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes1:   map[MyString]MyBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes2:   map[MyString]MyString{"s1": "s1", "s2": "s2"},
+
+		MyStrings3: MapStrings{"s1": "s1", "s2": "s2"},
+		MyStrings4: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes3:   MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes4:   MapStrings{"s1": "s1", "s2": "s2"},
+	})
+	wantFS := want.KnownFields()
+
+	testMessage(t, nil, mi.MessageOf(&MapScalars{}), messageOps{
+		hasFields{1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false, 14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false},
+		getFields{1: emptyFS.Get(1), 3: emptyFS.Get(3), 5: emptyFS.Get(5), 7: emptyFS.Get(7), 9: emptyFS.Get(9), 11: emptyFS.Get(11), 13: emptyFS.Get(13), 15: emptyFS.Get(15), 17: emptyFS.Get(17), 19: emptyFS.Get(19), 21: emptyFS.Get(21), 23: emptyFS.Get(23), 25: emptyFS.Get(25)},
+		setFields{1: wantFS.Get(1), 3: wantFS.Get(3), 5: wantFS.Get(5), 7: wantFS.Get(7), 9: wantFS.Get(9), 11: wantFS.Get(11), 13: wantFS.Get(13), 15: wantFS.Get(15), 17: wantFS.Get(17), 19: wantFS.Get(19), 21: wantFS.Get(21), 23: wantFS.Get(23), 25: wantFS.Get(25)},
+		mapFields{
+			2: {
+				lenMap(0),
+				hasMap{int32(0): false, int32(-1): false, int32(2): false},
+				setMap{int32(0): V("zero")},
+				lenMap(1),
+				hasMap{int32(0): true, int32(-1): false, int32(2): false},
+				setMap{int32(-1): V("one")},
+				lenMap(2),
+				hasMap{int32(0): true, int32(-1): true, int32(2): false},
+				setMap{int32(2): V("two")},
+				lenMap(3),
+				hasMap{int32(0): true, int32(-1): true, int32(2): true},
+			},
+			4: {
+				setMap{uint32(0): V("zero"), uint32(1): V("one"), uint32(2): V("two")},
+				equalMap(wantFS.Get(4).Map()),
+			},
+			6: {
+				clearMap{"noexist": true},
+				setMap{"foo": V("bar")},
+				setMap{"": V("empty")},
+				getMap{"": V("empty"), "foo": V("bar"), "noexist": V(nil)},
+				setMap{"": V(""), "extra": V("extra")},
+				clearMap{"extra": true, "noexist": true},
+			},
+			8: {
+				equalMap(emptyFS.Get(8).Map()),
+				setMap{"one": V(int32(1)), "two": V(int32(2)), "three": V(int32(3))},
+			},
+			10: {
+				setMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead))},
+				lenMap(3),
+				equalMap(wantFS.Get(10).Map()),
+				getMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead)), "0xdeadbeef": V(nil)},
+			},
+			12: {
+				setMap{"nan": V(float32(math.NaN())), "pi": V(float32(math.Pi)), "e": V(float32(math.E))},
+				clearMap{"e": true, "phi": true},
+				rangeMap{"nan": V(float32(math.NaN())), "pi": V(float32(math.Pi))},
+			},
+			14: {
+				equalMap(emptyFS.Get(14).Map()),
+				setMap{"s1": V("s1"), "s2": V("s2")},
+			},
+			16: {
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+				equalMap(wantFS.Get(16).Map()),
+			},
+			18: {
+				hasMap{"s1": false, "s2": false, "s3": false},
+				setMap{"s1": V("s1"), "s2": V("s2")},
+				hasMap{"s1": true, "s2": true, "s3": false},
+			},
+			20: {
+				equalMap(emptyFS.Get(20).Map()),
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+			},
+			22: {
+				rangeMap{},
+				setMap{"s1": V("s1"), "s2": V("s2")},
+				rangeMap{"s1": V("s1"), "s2": V("s2")},
+				lenMap(2),
+			},
+			24: {
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+				equalMap(wantFS.Get(24).Map()),
+			},
+		},
+		hasFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, 23: true, 24: true, 25: true},
+		equalMessage(want),
+		clearFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, 23: true, 24: true, 25: true},
+		equalMessage(empty),
 	})
 }
 
@@ -465,6 +679,12 @@
 				testVectors(t, p, fs.Mutable(n).(pref.Vector), tt)
 				p.Pop()
 			}
+		case mapFields:
+			for n, tt := range op {
+				p.Push(int(n))
+				testMaps(t, p, fs.Mutable(n).(pref.Map), tt)
+				p.Pop()
+			}
 		default:
 			t.Fatalf("operation %v, invalid operation: %T", p, op)
 		}
@@ -510,6 +730,68 @@
 	}
 }
 
+// testMaps applies each operation in tt to m, in order. Assertion operations
+// (equalMap, lenMap, hasMap, getMap, rangeMap) report mismatches with
+// t.Errorf; setMap and clearMap mutate m. The index of each operation is
+// pushed onto p so that failure messages identify which operation failed.
+func testMaps(t *testing.T, p path, m pref.Map, tt mapOps) {
+	for i, op := range tt {
+		p.Push(i)
+		switch op := op.(type) {
+		case equalMap:
+			if diff := cmp.Diff(op, m, cmpOpts); diff != "" {
+				t.Errorf("operation %v, map mismatch (-want, +got):\n%s", p, diff)
+			}
+		case lenMap:
+			if got, want := m.Len(), int(op); got != want {
+				t.Errorf("operation %v, Map.Len = %d, want %d", p, got, want)
+			}
+		case hasMap:
+			got := map[interface{}]bool{}
+			want := map[interface{}]bool(op)
+			for k := range want {
+				got[k] = m.Has(V(k).MapKey())
+			}
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Has mismatch (-want, +got):\n%s", p, diff)
+			}
+		case getMap:
+			got := map[interface{}]pref.Value{}
+			want := map[interface{}]pref.Value(op)
+			for k := range want {
+				got[k] = m.Get(V(k).MapKey())
+			}
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Get mismatch (-want, +got):\n%s", p, diff)
+			}
+		case setMap:
+			for k, v := range op {
+				m.Set(V(k).MapKey(), v)
+			}
+		case clearMap:
+			// Only keys mapped to true are cleared; false entries are ignored.
+			for k, ok := range op {
+				if ok {
+					m.Clear(V(k).MapKey())
+				}
+			}
+		case rangeMap:
+			got := map[interface{}]pref.Value{}
+			want := map[interface{}]pref.Value(op)
+			m.Range(func(k pref.MapKey, v pref.Value) bool {
+				got[k.Interface()] = v
+				return true
+			})
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Range mismatch (-want, +got):\n%s", p, diff)
+			}
+		default:
+			t.Fatalf("operation %v, invalid operation: %T", p, op)
+		}
+		p.Pop()
+	}
+}
+
 type path []int
 
 func (p *path) Push(i int) { *p = append(*p, i) }