blob: 35468b4b4f8f65f24b8483191966ed2b0d0b1bb6 [file] [log] [blame]
Guido van Rossum1968ad32006-02-25 22:38:04 +00001"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
Amaury Forgeot d'Arcf43ee812008-10-30 20:58:42 +00005import pickle
Guido van Rossum1968ad32006-02-25 22:38:04 +00006import tempfile
7import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00008from test import support
Guido van Rossum1968ad32006-02-25 22:38:04 +00009
10from collections import defaultdict
11
def foobar():
    """Return the built-in ``list`` type.

    Serves as a named, module-level ``default_factory`` for the shallow
    and deep copy tests below.
    """
    factory = list
    return factory
14
class TestDefaultDict(unittest.TestCase):
    """Tests for collections.defaultdict: construction, __missing__, repr,
    printing, copying, KeyError behavior and pickling."""

    def _print_and_read_back(self, *objects):
        """Print each of *objects* to a real on-disk file; return its lines.

        tempfile.mkstemp() creates and opens the file atomically, unlike the
        deprecated tempfile.mktemp(), which only picks a name and leaves a
        window in which another process could create that path first.
        The file is always removed afterwards.
        """
        fd, path = tempfile.mkstemp()
        try:
            # The with-block closes the file before the finally clause runs,
            # so os.remove() also works on platforms that lock open files.
            with os.fdopen(fd, "w+", encoding="utf-8") as f:
                for obj in objects:
                    print(obj, file=f)
                f.seek(0)
                return f.readlines()
        finally:
            os.remove(path)

    def test_basic(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every missing key must receive its own fresh list. Compare each
        # pair explicitly: a chained `a is not b is not c` skips a vs c.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        # With no factory, a missing key raises KeyError carrying the key.
        with self.assertRaises(KeyError) as cm:
            d2[15]
        self.assertEqual(cm.exception.args, (15,))
        # A non-callable default_factory is rejected at construction time.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
        def foo(): return 43
        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        """print() of a defaultdict writes exactly its repr plus a newline."""
        d1 = defaultdict()
        def foo(): return 42
        d2 = defaultdict(foo, {1: 2})
        # Write through a real file rather than an in-memory stream, keeping
        # the original intent of exercising on-disk output (historically the
        # C-level tp_print slot, which only ran for real files).
        lines = self._print_and_read_back(d1, d2)
        self.assertEqual(lines, [repr(d1) + "\n", repr(d2) + "\n"])

    def test_copy(self):
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertEqual(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertEqual(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertEqual(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        # A deep copy must not share the mutable values.
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        d1 = defaultdict()
        # A tuple key must be reported as-is, not unpacked into args.
        with self.assertRaises(KeyError) as cm:
            d1[(1,)]
        self.assertEqual(cm.exception.args[0], (1,))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
        self.assertTrue(repr(d).startswith(
            "defaultdict(<bound method sub._factory of defaultdict(..."))

        # Printing must also terminate and emit the same text as repr().
        self.assertEqual(self._print_and_read_back(d), [repr(d) + "\n"])

    def test_pickleing(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
            # Dict equality ignores the factory, so verify separately that
            # the type and default_factory survive the round trip.
            self.assertIsInstance(o, defaultdict)
            self.assertIs(o.default_factory, int)
Guido van Rossum1968ad32006-02-25 22:38:04 +0000175
def test_main():
    """Run every test in TestDefaultDict through test.support's runner."""
    cases = (TestDefaultDict,)
    support.run_unittest(*cases)
Thomas Wouters0e3f5912006-08-11 14:57:12 +0000178
if "__main__" == __name__:
    # Only when executed as a script (not on import): run the suite.
    test_main()