internal/impl: implement Map fields
Add functions that wrap a Go map[K]V so that it implements protoreflect.Map.
Rather than relying on type-specific code, this implementation uses Go reflection
to provide a single implementation that can handle all Go map types.
Change-Id: Idcb8069ef836614a88e5df12ef7c5044e8aa3dea
Reviewed-on: https://go-review.googlesource.com/c/142778
Reviewed-by: Damien Neil <dneil@google.com>
diff --git a/internal/impl/message_test.go b/internal/impl/message_test.go
index c24fb04..a789247 100644
--- a/internal/impl/message_test.go
+++ b/internal/impl/message_test.go
@@ -40,6 +40,9 @@
NamedStrings []MyString
NamedBytes []MyBytes
+
+ MapStrings map[MyString]MyString
+ MapBytes map[MyString]MyBytes
)
// List of test operations to perform on messages, vectors, or maps.
@@ -50,8 +53,8 @@
vectorOp interface{} // equalVector | lenVector | getVector | setVector | appendVector | truncVector
vectorOps []vectorOp
- mapOp interface{} // TODO
- mapOps []mapOp // TODO
+ mapOp interface{} // equalMap | lenMap | hasMap | getMap | setMap | clearMap | rangeMap
+ mapOps []mapOp
)
// Test operations performed on a message.
@@ -78,6 +81,18 @@
// TODO: Mutable, MutableAppend
)
+// Test operations performed on a map.
+type (
+	// Queries: each checks the current state of the map against the operand.
+	equalMap pref.Map                   // entire map compared with cmp.Diff
+	lenMap   int                        // expected result of Map.Len
+	hasMap   map[interface{}]bool       // expected Map.Has result for each key
+	getMap   map[interface{}]pref.Value // expected Map.Get result for each key
+	rangeMap map[interface{}]pref.Value // expected entries visited by Map.Range
+
+	// Mutations: each modifies the map.
+	setMap   map[interface{}]pref.Value // Map.Set called for every entry
+	clearMap map[interface{}]bool       // Map.Clear called for every key mapped to true
+
+	// TODO: List, Mutable
+)
+
func TestScalarProto2(t *testing.T) {
type ScalarProto2 struct {
Bool *bool `protobuf:"1"`
@@ -396,7 +411,206 @@
hasFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true},
equalMessage(want),
clearFields{1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true, 14: true, 15: true, 16: true, 17: true, 18: true, 19: true},
- equalMessage(mi.MessageOf(&RepeatedScalars{})),
+ equalMessage(empty),
+ })
+}
+
+func TestMapScalars(t *testing.T) {
+	// MapScalars has one map field per case:
+	//   - fields 1-6 use each scalar kind as a map key;
+	//   - fields 7-17 use each scalar kind as a map value;
+	//   - fields 18-25 use named Go string/bytes types as keys, values, and whole map types.
+	// The "A" variants deliberately pair a proto string with a Go []byte (field 15)
+	// and a proto bytes with a Go string (field 17).
+	type MapScalars struct {
+		KeyBools   map[bool]string   `protobuf:"1"`
+		KeyInt32s  map[int32]string  `protobuf:"2"`
+		KeyInt64s  map[int64]string  `protobuf:"3"`
+		KeyUint32s map[uint32]string `protobuf:"4"`
+		KeyUint64s map[uint64]string `protobuf:"5"`
+		KeyStrings map[string]string `protobuf:"6"`
+
+		ValBools    map[string]bool    `protobuf:"7"`
+		ValInt32s   map[string]int32   `protobuf:"8"`
+		ValInt64s   map[string]int64   `protobuf:"9"`
+		ValUint32s  map[string]uint32  `protobuf:"10"`
+		ValUint64s  map[string]uint64  `protobuf:"11"`
+		ValFloat32s map[string]float32 `protobuf:"12"`
+		ValFloat64s map[string]float64 `protobuf:"13"`
+		ValStrings  map[string]string  `protobuf:"14"`
+		ValStringsA map[string][]byte  `protobuf:"15"`
+		ValBytes    map[string][]byte  `protobuf:"16"`
+		ValBytesA   map[string]string  `protobuf:"17"`
+
+		MyStrings1 map[MyString]MyString `protobuf:"18"`
+		MyStrings2 map[MyString]MyBytes  `protobuf:"19"`
+		MyBytes1   map[MyString]MyBytes  `protobuf:"20"`
+		MyBytes2   map[MyString]MyString `protobuf:"21"`
+
+		MyStrings3 MapStrings `protobuf:"22"`
+		MyStrings4 MapBytes   `protobuf:"23"`
+		MyBytes3   MapBytes   `protobuf:"24"`
+		MyBytes4   MapStrings `protobuf:"25"`
+	}
+
+	// mapEntryField returns repeated message field n. Its message type is a
+	// synthetic proto2 map-entry message with the given key and value kinds.
+	mapEntryField := func(n pref.FieldNumber, keyKind, valKind pref.Kind) ptype.Field {
+		entry := mustMakeMessageDesc(ptype.StandaloneMessage{
+			Syntax:   pref.Proto2,
+			FullName: pref.FullName(fmt.Sprintf("MapScalars.F%dEntry", n)),
+			Fields: []ptype.Field{
+				{Name: "key", Number: 1, Cardinality: pref.Optional, Kind: keyKind},
+				{Name: "value", Number: 2, Cardinality: pref.Optional, Kind: valKind},
+			},
+			IsMapEntry: true,
+		})
+		return ptype.Field{
+			Name:        pref.Name(fmt.Sprintf("f%d", n)),
+			Number:      n,
+			Cardinality: pref.Repeated,
+			Kind:        pref.MessageKind,
+			MessageType: entry,
+		}
+	}
+
+	// Key and value kinds for fields 1 through 25, in field-number order.
+	entryKinds := [...]struct{ key, val pref.Kind }{
+		{pref.BoolKind, pref.StringKind},   // 1
+		{pref.Int32Kind, pref.StringKind},  // 2
+		{pref.Int64Kind, pref.StringKind},  // 3
+		{pref.Uint32Kind, pref.StringKind}, // 4
+		{pref.Uint64Kind, pref.StringKind}, // 5
+		{pref.StringKind, pref.StringKind}, // 6
+		{pref.StringKind, pref.BoolKind},   // 7
+		{pref.StringKind, pref.Int32Kind},  // 8
+		{pref.StringKind, pref.Int64Kind},  // 9
+		{pref.StringKind, pref.Uint32Kind}, // 10
+		{pref.StringKind, pref.Uint64Kind}, // 11
+		{pref.StringKind, pref.FloatKind},  // 12
+		{pref.StringKind, pref.DoubleKind}, // 13
+		{pref.StringKind, pref.StringKind}, // 14
+		{pref.StringKind, pref.StringKind}, // 15
+		{pref.StringKind, pref.BytesKind},  // 16
+		{pref.StringKind, pref.BytesKind},  // 17
+		{pref.StringKind, pref.StringKind}, // 18
+		{pref.StringKind, pref.StringKind}, // 19
+		{pref.StringKind, pref.BytesKind},  // 20
+		{pref.StringKind, pref.BytesKind},  // 21
+		{pref.StringKind, pref.StringKind}, // 22
+		{pref.StringKind, pref.StringKind}, // 23
+		{pref.StringKind, pref.BytesKind},  // 24
+		{pref.StringKind, pref.BytesKind},  // 25
+	}
+	descFields := make([]ptype.Field, 0, len(entryKinds))
+	for i, k := range entryKinds {
+		descFields = append(descFields, mapEntryField(pref.FieldNumber(i+1), k.key, k.val))
+	}
+	mi := MessageType{Desc: mustMakeMessageDesc(ptype.StandaloneMessage{
+		Syntax:   pref.Proto2,
+		FullName: "MapScalars",
+		Fields:   descFields,
+	})}
+
+	empty := mi.MessageOf(&MapScalars{})
+	emptyKF := empty.KnownFields()
+
+	want := mi.MessageOf(&MapScalars{
+		KeyBools:   map[bool]string{true: "true", false: "false"},
+		KeyInt32s:  map[int32]string{0: "zero", -1: "one", 2: "two"},
+		KeyInt64s:  map[int64]string{0: "zero", -10: "ten", 20: "twenty"},
+		KeyUint32s: map[uint32]string{0: "zero", 1: "one", 2: "two"},
+		KeyUint64s: map[uint64]string{0: "zero", 10: "ten", 20: "twenty"},
+		KeyStrings: map[string]string{"": "", "foo": "bar"},
+
+		ValBools:    map[string]bool{"true": true, "false": false},
+		ValInt32s:   map[string]int32{"one": 1, "two": 2, "three": 3},
+		ValInt64s:   map[string]int64{"ten": 10, "twenty": -20, "thirty": 30},
+		ValUint32s:  map[string]uint32{"0x00": 0x00, "0xff": 0xff, "0xdead": 0xdead},
+		ValUint64s:  map[string]uint64{"0x00": 0x00, "0xff": 0xff, "0xdead": 0xdead},
+		ValFloat32s: map[string]float32{"nan": float32(math.NaN()), "pi": float32(math.Pi)},
+		ValFloat64s: map[string]float64{"nan": float64(math.NaN()), "pi": float64(math.Pi)},
+		ValStrings:  map[string]string{"s1": "s1", "s2": "s2"},
+		ValStringsA: map[string][]byte{"s1": []byte("s1"), "s2": []byte("s2")},
+		ValBytes:    map[string][]byte{"s1": []byte("s1"), "s2": []byte("s2")},
+		ValBytesA:   map[string]string{"s1": "s1", "s2": "s2"},
+
+		MyStrings1: map[MyString]MyString{"s1": "s1", "s2": "s2"},
+		MyStrings2: map[MyString]MyBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes1:   map[MyString]MyBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes2:   map[MyString]MyString{"s1": "s1", "s2": "s2"},
+
+		MyStrings3: MapStrings{"s1": "s1", "s2": "s2"},
+		MyStrings4: MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes3:   MapBytes{"s1": []byte("s1"), "s2": []byte("s2")},
+		MyBytes4:   MapStrings{"s1": "s1", "s2": "s2"},
+	})
+	wantKF := want.KnownFields()
+
+	// Odd fields are copied wholesale from want. Even fields are built up
+	// through the protoreflect.Map API. Both paths must end up equal to want.
+	testMessage(t, nil, mi.MessageOf(&MapScalars{}), messageOps{
+		hasFields{
+			1: false, 2: false, 3: false, 4: false, 5: false, 6: false, 7: false, 8: false, 9: false, 10: false, 11: false, 12: false, 13: false,
+			14: false, 15: false, 16: false, 17: false, 18: false, 19: false, 20: false, 21: false, 22: false, 23: false, 24: false, 25: false,
+		},
+		getFields{
+			1: emptyKF.Get(1), 3: emptyKF.Get(3), 5: emptyKF.Get(5), 7: emptyKF.Get(7), 9: emptyKF.Get(9), 11: emptyKF.Get(11), 13: emptyKF.Get(13),
+			15: emptyKF.Get(15), 17: emptyKF.Get(17), 19: emptyKF.Get(19), 21: emptyKF.Get(21), 23: emptyKF.Get(23), 25: emptyKF.Get(25),
+		},
+		setFields{
+			1: wantKF.Get(1), 3: wantKF.Get(3), 5: wantKF.Get(5), 7: wantKF.Get(7), 9: wantKF.Get(9), 11: wantKF.Get(11), 13: wantKF.Get(13),
+			15: wantKF.Get(15), 17: wantKF.Get(17), 19: wantKF.Get(19), 21: wantKF.Get(21), 23: wantKF.Get(23), 25: wantKF.Get(25),
+		},
+		mapFields{
+			// Grow the map one key at a time, checking Len and Has after each Set.
+			2: {
+				lenMap(0),
+				hasMap{int32(0): false, int32(-1): false, int32(2): false},
+				setMap{int32(0): V("zero")},
+				lenMap(1),
+				hasMap{int32(0): true, int32(-1): false, int32(2): false},
+				setMap{int32(-1): V("one")},
+				lenMap(2),
+				hasMap{int32(0): true, int32(-1): true, int32(2): false},
+				setMap{int32(2): V("two")},
+				lenMap(3),
+				hasMap{int32(0): true, int32(-1): true, int32(2): true},
+			},
+			4: {
+				setMap{uint32(0): V("zero"), uint32(1): V("one"), uint32(2): V("two")},
+				equalMap(wantKF.Get(4).Map()),
+			},
+			// Overwrite an existing key, read a missing key, and clear both present and missing keys.
+			6: {
+				clearMap{"noexist": true},
+				setMap{"foo": V("bar")},
+				setMap{"": V("empty")},
+				getMap{"": V("empty"), "foo": V("bar"), "noexist": V(nil)},
+				setMap{"": V(""), "extra": V("extra")},
+				clearMap{"extra": true, "noexist": true},
+			},
+			8: {
+				equalMap(emptyKF.Get(8).Map()),
+				setMap{"one": V(int32(1)), "two": V(int32(2)), "three": V(int32(3))},
+			},
+			10: {
+				setMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead))},
+				lenMap(3),
+				equalMap(wantKF.Get(10).Map()),
+				getMap{"0x00": V(uint32(0x00)), "0xff": V(uint32(0xff)), "0xdead": V(uint32(0xdead)), "0xdeadbeef": V(nil)},
+			},
+			12: {
+				setMap{"nan": V(float32(math.NaN())), "pi": V(float32(math.Pi)), "e": V(float32(math.E))},
+				clearMap{"e": true, "phi": true},
+				rangeMap{"nan": V(float32(math.NaN())), "pi": V(float32(math.Pi))},
+			},
+			14: {
+				equalMap(emptyKF.Get(14).Map()),
+				setMap{"s1": V("s1"), "s2": V("s2")},
+			},
+			16: {
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+				equalMap(wantKF.Get(16).Map()),
+			},
+			18: {
+				hasMap{"s1": false, "s2": false, "s3": false},
+				setMap{"s1": V("s1"), "s2": V("s2")},
+				hasMap{"s1": true, "s2": true, "s3": false},
+			},
+			20: {
+				equalMap(emptyKF.Get(20).Map()),
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+			},
+			22: {
+				rangeMap{},
+				setMap{"s1": V("s1"), "s2": V("s2")},
+				rangeMap{"s1": V("s1"), "s2": V("s2")},
+				lenMap(2),
+			},
+			24: {
+				setMap{"s1": V([]byte("s1")), "s2": V([]byte("s2"))},
+				equalMap(wantKF.Get(24).Map()),
+			},
+		},
+		hasFields{
+			1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true,
+			14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, 23: true, 24: true, 25: true,
+		},
+		equalMessage(want),
+		clearFields{
+			1: true, 2: true, 3: true, 4: true, 5: true, 6: true, 7: true, 8: true, 9: true, 10: true, 11: true, 12: true, 13: true,
+			14: true, 15: true, 16: true, 17: true, 18: true, 19: true, 20: true, 21: true, 22: true, 23: true, 24: true, 25: true,
+		},
+		equalMessage(empty),
})
}
@@ -465,6 +679,12 @@
testVectors(t, p, fs.Mutable(n).(pref.Vector), tt)
p.Pop()
}
+ case mapFields:
+ for n, tt := range op {
+ p.Push(int(n))
+ testMaps(t, p, fs.Mutable(n).(pref.Map), tt)
+ p.Pop()
+ }
default:
t.Fatalf("operation %v, invalid operation: %T", p, op)
}
@@ -510,6 +730,63 @@
}
}
+// testMaps applies the operations in tt to m, in order.
+// Query operations (equalMap, lenMap, hasMap, getMap, rangeMap) report
+// mismatches with t.Errorf. Mutation operations (setMap, clearMap) modify m.
+// Each operation's index is pushed onto p, so every message identifies the
+// failing step.
+func testMaps(t *testing.T, p path, m pref.Map, tt mapOps) {
+	for i := range tt {
+		p.Push(i)
+		switch op := tt[i].(type) {
+		case equalMap:
+			if diff := cmp.Diff(op, m, cmpOpts); diff != "" {
+				t.Errorf("operation %v, map mismatch (-want, +got):\n%s", p, diff)
+			}
+		case lenMap:
+			want := int(op)
+			if got := m.Len(); got != want {
+				t.Errorf("operation %v, Map.Len = %d, want %d", p, got, want)
+			}
+		case hasMap:
+			want := map[interface{}]bool(op)
+			got := make(map[interface{}]bool, len(want))
+			for key := range want {
+				got[key] = m.Has(V(key).MapKey())
+			}
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Has mismatch (-want, +got):\n%s", p, diff)
+			}
+		case getMap:
+			want := map[interface{}]pref.Value(op)
+			got := make(map[interface{}]pref.Value, len(want))
+			for key := range want {
+				got[key] = m.Get(V(key).MapKey())
+			}
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Get mismatch (-want, +got):\n%s", p, diff)
+			}
+		case setMap:
+			for key, val := range op {
+				m.Set(V(key).MapKey(), val)
+			}
+		case clearMap:
+			// Only keys mapped to true are cleared; false entries are ignored.
+			for key, doClear := range op {
+				if !doClear {
+					continue
+				}
+				m.Clear(V(key).MapKey())
+			}
+		case rangeMap:
+			// Collect every entry visited by Range, then compare it with the expected contents.
+			want := map[interface{}]pref.Value(op)
+			got := make(map[interface{}]pref.Value, len(want))
+			m.Range(func(key pref.MapKey, val pref.Value) bool {
+				got[key.Interface()] = val
+				return true
+			})
+			if diff := cmp.Diff(want, got, cmpOpts); diff != "" {
+				t.Errorf("operation %v, Map.Range mismatch (-want, +got):\n%s", p, diff)
+			}
+		default:
+			t.Fatalf("operation %v, invalid operation: %T", p, op)
+		}
+		p.Pop()
+	}
+}
+
type path []int
func (p *path) Push(i int) { *p = append(*p, i) }