blob: 00bd9dc516e19518047a3f5117d40fd76ec64f0d [file] [log] [blame]
Guido van Rossum1968ad32006-02-25 22:38:04 +00001"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
Amaury Forgeot d'Arcf43ee812008-10-30 20:58:42 +00005import pickle
Guido van Rossum1968ad32006-02-25 22:38:04 +00006import tempfile
7import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00008from test import support
Guido van Rossum1968ad32006-02-25 22:38:04 +00009
10from collections import defaultdict
11
def foobar():
    """Return the built-in ``list`` type itself (not a list instance).

    Used by the copy tests below as a module-level ``default_factory``.
    """
    return list
14
class TestDefaultDict(unittest.TestCase):
    """Behavioural tests for collections.defaultdict."""

    # NOTE: the test methods are documented with comments rather than
    # docstrings so that unittest's verbose output keeps showing the
    # test method names.

    def test_basic(self):
        # Construction, default_factory assignment, implicit creation of
        # missing keys, membership tests, and the KeyError raised once
        # the factory has been reset to None.
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every key must own a distinct list.  A chained
        # "a is not b is not c" would leave a/c uncompared, so each
        # pair is checked explicitly.
        self.assertTrue(d1[12] is not d1[13])
        self.assertTrue(d1[13] is not d1[14])
        self.assertTrue(d1[12] is not d1[14])
        d2 = defaultdict(list, foo=1, bar=2)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        # Looking up 42 inserts it, which the membership checks below
        # rely on.
        self.assertEqual(d2[42], [])
        self.assertTrue("foo" in d2)
        self.assertTrue("foo" in d2.keys())
        self.assertTrue("bar" in d2)
        self.assertTrue("bar" in d2.keys())
        self.assertTrue(42 in d2)
        self.assertTrue(42 in d2.keys())
        self.assertTrue(12 not in d2)
        self.assertTrue(12 not in d2.keys())
        d2.default_factory = None
        self.assertEqual(d2.default_factory, None)
        try:
            d2[15]
        except KeyError as err:
            self.assertEqual(err.args, (15,))
        else:
            self.fail("d2[15] didn't raise KeyError")
        # A non-callable, non-None first argument is rejected.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises KeyError without a factory and returns the
        # factory's value once one is set.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        # repr() shows the factory followed by the dict contents.
        d1 = defaultdict()
        self.assertEqual(d1.default_factory, None)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
        d2 = defaultdict(int)
        self.assertEqual(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")
        def foo(): return 43
        d3 = defaultdict(foo)
        self.assertTrue(d3.default_factory is foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        # Printing to a real file must produce the same text as repr().
        d1 = defaultdict()
        def foo(): return 42
        d2 = defaultdict(foo, {1: 2})
        # NOTE: We can't use tempfile.[Named]TemporaryFile since this
        # code must exercise the tp_print C code, which only gets
        # invoked for *real* files.
        # mkstemp() creates the file atomically (mktemp() only returns a
        # name and is race-prone), so the file always exists by the
        # time the cleanup below runs.
        fd, tfn = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w+") as f:
                print(d1, file=f)
                print(d2, file=f)
                f.seek(0)
                self.assertEqual(f.readline(), repr(d1) + "\n")
                self.assertEqual(f.readline(), repr(d2) + "\n")
        finally:
            os.remove(tfn)

    def test_copy(self):
        # .copy() preserves both the type and the default_factory, and
        # the copy keeps creating missing keys on its own.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertEqual(type(d2), defaultdict)
        self.assertEqual(d2.default_factory, None)
        self.assertEqual(d2, {})
        d1.default_factory = list
        d3 = d1.copy()
        self.assertEqual(type(d3), defaultdict)
        self.assertEqual(d3.default_factory, list)
        self.assertEqual(d3, {})
        d1[42]
        d4 = d1.copy()
        self.assertEqual(type(d4), defaultdict)
        self.assertEqual(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

    def test_shallow_copy(self):
        # copy.copy() keeps the factory and the contents.
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        # copy.deepcopy() keeps the factory but duplicates the values.
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertTrue(d1[1] is not d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # With no factory the KeyError must carry the key itself, even
        # when the key is a tuple.
        d1 = defaultdict()
        try:
            d1[(1,)]
        except KeyError as err:
            self.assertEqual(err.args[0], (1,))
        else:
            self.fail("expected KeyError")

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
        self.assertTrue(repr(d).startswith(
            "defaultdict(<bound method sub._factory of defaultdict(..."))

        # NOTE: printing a subclass of a builtin type does not call its
        # tp_print slot.  So this part is essentially the same test as above.
        # mkstemp() is used instead of the race-prone mktemp(); only the
        # fact that printing terminates without error is being checked.
        fd, tfn = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w+") as f:
                print(d, file=f)
        finally:
            os.remove(tfn)

    def test_pickleing(self):
        # A defaultdict must round-trip through every pickle protocol.
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
Guido van Rossum1968ad32006-02-25 22:38:04 +0000175
def test_main():
    """Run the TestDefaultDict case through test.support's unittest runner."""
    support.run_unittest(TestDefaultDict)
Thomas Wouters0e3f5912006-08-11 14:57:12 +0000178
# Allow the test module to be executed directly as a script.
if __name__ == "__main__":
    test_main()