blob: e5a9bd5df743a7e802b9db7ef0d49bb4ab51b726 [file] [log] [blame]
"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
Amaury Forgeot d'Arcf43ee812008-10-30 20:58:42 +00005import pickle
Guido van Rossum1968ad32006-02-25 22:38:04 +00006import tempfile
7import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00008from test import support
Guido van Rossum1968ad32006-02-25 22:38:04 +00009
10from collections import defaultdict
11
def foobar():
    """Return the ``list`` type.

    Defined at module level and used by the copy tests in this file as a
    ``default_factory`` value.
    """
    return list
14
class TestDefaultDict(unittest.TestCase):
    """Behavioural tests for ``collections.defaultdict``."""

    def _print_to_real_file(self, *objects):
        """Print each object to a real on-disk file; return the lines written.

        NOTE: We can't use tempfile.[Named]TemporaryFile since this code
        must exercise the tp_print C code, which only gets invoked for
        *real* files.  The path is reserved with ``tempfile.mkstemp()``
        (created atomically, so there is no name race as with
        ``tempfile.mktemp()``) and then reopened by name with ``open()``.
        The file is always removed afterwards.
        """
        fd, path = tempfile.mkstemp()
        try:
            # Only the reserved path is needed; the file is reopened by name.
            os.close(fd)
            with open(path, "w+") as stream:
                for obj in objects:
                    print(obj, file=stream)
                stream.seek(0)
                return stream.readlines()
        finally:
            os.remove(path)

    def test_basic(self):
        # Construction, default_factory assignment, implicit insertion on
        # lookup, membership, and the KeyError raised without a factory.
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every key must own a distinct list; check all three pairs, since a
        # chained ``a is not b is not c`` never compares a with c.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        try:
            d2[15]
        except KeyError as err:
            self.assertEqual(err.args, (15,))
        else:
            self.fail("d2[15] didn't raise KeyError")
        # A non-callable, non-None first argument is rejected.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises KeyError without a factory, else calls it.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # The repr of a factory-less defaultdict must round-trip; the
        # evaluated string is produced by this test itself.
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
        def foo(): return 43
        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        # What print() writes to a real file must match repr() plus newline.
        d1 = defaultdict()
        def foo(): return 42
        d2 = defaultdict(foo, {1: 2})
        lines = self._print_to_real_file(d1, d2)
        self.assertEqual(lines, [repr(d1) + "\n", repr(d2) + "\n"])

    def test_copy(self):
        # .copy() keeps the type, the factory and the items.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        # copy.copy() preserves both the factory and the contents.
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        # copy.deepcopy() preserves the factory and duplicates the values.
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # The KeyError must carry the tuple key itself as its sole argument.
        d1 = defaultdict()
        try:
            d1[(1,)]
        except KeyError as err:
            self.assertEqual(err.args[0], (1,))
        else:
            self.fail("expected KeyError")

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
        # NOTE(review): the expected prefix depends on how the running
        # interpreter formats bound methods and subclass names -- confirm
        # it against the targeted Python version.
        self.assertTrue(repr(d).startswith(
            "defaultdict(<bound method sub._factory of defaultdict(..."))

        # NOTE: printing a subclass of a builtin type does not call its
        # tp_print slot.  So this part is essentially the same test as above.
        self._print_to_real_file(d)

    def test_pickleing(self):
        # Round-trip through every pickle protocol; the payload is created
        # by this test, so loading it is safe.
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
Guido van Rossum1968ad32006-02-25 22:38:04 +0000182
def test_main():
    """Run the TestDefaultDict suite through ``support.run_unittest``."""
    support.run_unittest(TestDefaultDict)
Thomas Wouters0e3f5912006-08-11 14:57:12 +0000185
if __name__ == "__main__":
    # Run the suite when this file is executed directly as a script.
    test_main()