blob: 532d535981b36c1efbbea0c2dd4cd6ede7caffef [file] [log] [blame]
Guido van Rossum1968ad32006-02-25 22:38:04 +00001"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
Amaury Forgeot d'Arcf43ee812008-10-30 20:58:42 +00005import pickle
Guido van Rossum1968ad32006-02-25 22:38:04 +00006import tempfile
7import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00008from test import support
Guido van Rossum1968ad32006-02-25 22:38:04 +00009
10from collections import defaultdict
11
def foobar():
    """Return the built-in ``list`` type.

    Referenced by the shallow/deep copy tests as a named, module-level
    ``default_factory`` whose identity can be compared after copying.
    """
    factory = list
    return factory
14
class TestDefaultDict(unittest.TestCase):
    """Behavioral tests for :class:`collections.defaultdict`."""

    def test_basic(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Bare lookups of missing keys insert default_factory() results.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every missing key gets its own freshly created list.
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        # 42 was inserted by the d2[42] lookup above.
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())
        # Without a factory, a missing key raises KeyError carrying the key.
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        try:
            d2[15]
        except KeyError as err:
            self.assertEqual(err.args, (15,))
        else:
            self.fail("d2[15] didn't raise KeyError")
        # The first positional argument must be callable (or None).
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # The repr of a factory-less defaultdict round-trips through eval().
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        def foo():
            return 43

        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        d1 = defaultdict()

        def foo():
            return 42

        d2 = defaultdict(foo, {1: 2})
        # print() to a real on-disk file must emit exactly repr() + newline.
        # mkstemp() creates the file atomically, unlike the deprecated and
        # race-prone tempfile.mktemp(), while still giving a real file.
        fd, tfn = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w+") as f:
                print(d1, file=f)
                print(d2, file=f)
                f.seek(0)
                self.assertEqual(f.readline(), repr(d1) + "\n")
                self.assertEqual(f.readline(), repr(d2) + "\n")
        finally:
            os.remove(tfn)

    def test_copy(self):
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertEqual(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertEqual(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertEqual(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        # The copy keeps a working factory of its own.
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        # Values are copied too, not shared.
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        d1 = defaultdict()
        try:
            d1[(1,)]
        except KeyError as err:
            # A tuple key must not be unpacked into the exception's args.
            self.assertEqual(err.args[0], (1,))
        else:
            self.fail("expected KeyError")

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory

            def _factory(self):
                return []

        d = sub()
        # The factory's repr embeds repr(d) again; the recursion guard must
        # cut that inner repr off with "...".  Older Pythons print
        # "defaultdict(" and the bare method name, while newer ones print the
        # subclass name and the method's qualified name, so accept either.
        self.assertRegex(
            repr(d),
            r"(defaultdict|sub)\(<bound method (.*\.)?sub\._factory "
            r"of (defaultdict|sub)\(\.\.\., \{\}\)>, \{\}\)")

        # print() of the recursive object must terminate as well.
        fd, tfn = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w+") as f:
                print(d, file=f)
        finally:
            os.remove(tfn)

    def test_callable_arg(self):
        # A non-callable default_factory (here a dict) is rejected.
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickleing(self):
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
            # Dict equality ignores the factory; check it round-trips too.
            self.assertEqual(type(o), defaultdict)
            self.assertEqual(o.default_factory, d.default_factory)
Guido van Rossum1968ad32006-02-25 22:38:04 +0000185
def test_main():
    """Run every test case defined in this module via test.support."""
    test_classes = (TestDefaultDict,)
    support.run_unittest(*test_classes)
Thomas Wouters0e3f5912006-08-11 14:57:12 +0000188
if __name__ == "__main__":
    # Executed only when the file is run directly, not when imported.
    test_main()