blob: 71d69c35d24d86f1f8249e42b924bab18d8f2d78 [file] [log] [blame]
Guido van Rossum1968ad32006-02-25 22:38:04 +00001"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
5import tempfile
6import unittest
Georg Brandlf102fc52006-07-27 15:05:36 +00007from test import test_support
Guido van Rossum1968ad32006-02-25 22:38:04 +00008
9from collections import defaultdict
10
def foobar():
    """Return the builtin ``list`` type itself.

    Used as a module-level default_factory so the copy tests can check that
    copy.copy/copy.deepcopy carry an arbitrary named function across.
    """
    factory = list
    return factory
13
14class TestDefaultDict(unittest.TestCase):
15
16 def test_basic(self):
17 d1 = defaultdict()
18 self.assertEqual(d1.default_factory, None)
19 d1.default_factory = list
20 d1[12].append(42)
21 self.assertEqual(d1, {12: [42]})
22 d1[12].append(24)
23 self.assertEqual(d1, {12: [42, 24]})
24 d1[13]
25 d1[14]
26 self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
Benjamin Peterson5c8da862009-06-30 22:57:08 +000027 self.assertTrue(d1[12] is not d1[13] is not d1[14])
Guido van Rossum1968ad32006-02-25 22:38:04 +000028 d2 = defaultdict(list, foo=1, bar=2)
29 self.assertEqual(d2.default_factory, list)
30 self.assertEqual(d2, {"foo": 1, "bar": 2})
31 self.assertEqual(d2["foo"], 1)
32 self.assertEqual(d2["bar"], 2)
33 self.assertEqual(d2[42], [])
Benjamin Peterson5c8da862009-06-30 22:57:08 +000034 self.assertTrue("foo" in d2)
35 self.assertTrue("foo" in d2.keys())
36 self.assertTrue("bar" in d2)
37 self.assertTrue("bar" in d2.keys())
38 self.assertTrue(42 in d2)
39 self.assertTrue(42 in d2.keys())
40 self.assertTrue(12 not in d2)
41 self.assertTrue(12 not in d2.keys())
Guido van Rossum1968ad32006-02-25 22:38:04 +000042 d2.default_factory = None
43 self.assertEqual(d2.default_factory, None)
44 try:
45 d2[15]
46 except KeyError, err:
47 self.assertEqual(err.args, (15,))
48 else:
49 self.fail("d2[15] didn't raise KeyError")
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000050 self.assertRaises(TypeError, defaultdict, 1)
Guido van Rossum1968ad32006-02-25 22:38:04 +000051
52 def test_missing(self):
53 d1 = defaultdict()
54 self.assertRaises(KeyError, d1.__missing__, 42)
55 d1.default_factory = list
56 self.assertEqual(d1.__missing__(42), [])
57
58 def test_repr(self):
59 d1 = defaultdict()
60 self.assertEqual(d1.default_factory, None)
61 self.assertEqual(repr(d1), "defaultdict(None, {})")
62 d1[11] = 41
63 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000064 d2 = defaultdict(int)
65 self.assertEqual(d2.default_factory, int)
Guido van Rossum1968ad32006-02-25 22:38:04 +000066 d2[12] = 42
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000067 self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
Guido van Rossum1968ad32006-02-25 22:38:04 +000068 def foo(): return 43
69 d3 = defaultdict(foo)
Benjamin Peterson5c8da862009-06-30 22:57:08 +000070 self.assertTrue(d3.default_factory is foo)
Guido van Rossum1968ad32006-02-25 22:38:04 +000071 d3[13]
72 self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
73
74 def test_print(self):
75 d1 = defaultdict()
76 def foo(): return 42
77 d2 = defaultdict(foo, {1: 2})
78 # NOTE: We can't use tempfile.[Named]TemporaryFile since this
79 # code must exercise the tp_print C code, which only gets
80 # invoked for *real* files.
81 tfn = tempfile.mktemp()
82 try:
83 f = open(tfn, "w+")
84 try:
85 print >>f, d1
86 print >>f, d2
87 f.seek(0)
88 self.assertEqual(f.readline(), repr(d1) + "\n")
89 self.assertEqual(f.readline(), repr(d2) + "\n")
90 finally:
91 f.close()
92 finally:
93 os.remove(tfn)
94
95 def test_copy(self):
96 d1 = defaultdict()
97 d2 = d1.copy()
98 self.assertEqual(type(d2), defaultdict)
99 self.assertEqual(d2.default_factory, None)
100 self.assertEqual(d2, {})
101 d1.default_factory = list
102 d3 = d1.copy()
103 self.assertEqual(type(d3), defaultdict)
104 self.assertEqual(d3.default_factory, list)
105 self.assertEqual(d3, {})
106 d1[42]
107 d4 = d1.copy()
108 self.assertEqual(type(d4), defaultdict)
109 self.assertEqual(d4.default_factory, list)
110 self.assertEqual(d4, {42: []})
111 d4[12]
112 self.assertEqual(d4, {42: [], 12: []})
113
114 def test_shallow_copy(self):
115 d1 = defaultdict(foobar, {1: 1})
116 d2 = copy.copy(d1)
117 self.assertEqual(d2.default_factory, foobar)
118 self.assertEqual(d2, d1)
119 d1.default_factory = list
120 d2 = copy.copy(d1)
121 self.assertEqual(d2.default_factory, list)
122 self.assertEqual(d2, d1)
123
124 def test_deep_copy(self):
125 d1 = defaultdict(foobar, {1: [1]})
126 d2 = copy.deepcopy(d1)
127 self.assertEqual(d2.default_factory, foobar)
128 self.assertEqual(d2, d1)
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000129 self.assertTrue(d1[1] is not d2[1])
Guido van Rossum1968ad32006-02-25 22:38:04 +0000130 d1.default_factory = list
131 d2 = copy.deepcopy(d1)
132 self.assertEqual(d2.default_factory, list)
133 self.assertEqual(d2, d1)
134
Georg Brandl72363032007-03-06 13:35:00 +0000135 def test_keyerror_without_factory(self):
136 d1 = defaultdict()
137 try:
138 d1[(1,)]
139 except KeyError, err:
Brett Cannon229cee22007-05-05 01:34:02 +0000140 self.assertEqual(err.args[0], (1,))
Georg Brandl72363032007-03-06 13:35:00 +0000141 else:
142 self.fail("expected KeyError")
143
Amaury Forgeot d'Arcb01aa432008-02-08 00:56:02 +0000144 def test_recursive_repr(self):
145 # Issue2045: stack overflow when default_factory is a bound method
146 class sub(defaultdict):
147 def __init__(self):
148 self.default_factory = self._factory
149 def _factory(self):
150 return []
151 d = sub()
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000152 self.assertTrue(repr(d).startswith(
Amaury Forgeot d'Arcb01aa432008-02-08 00:56:02 +0000153 "defaultdict(<bound method sub._factory of defaultdict(..."))
154
155 # NOTE: printing a subclass of a builtin type does not call its
156 # tp_print slot. So this part is essentially the same test as above.
157 tfn = tempfile.mktemp()
158 try:
159 f = open(tfn, "w+")
160 try:
161 print >>f, d
162 finally:
163 f.close()
164 finally:
165 os.remove(tfn)
166
Guido van Rossum1968ad32006-02-25 22:38:04 +0000167
def test_main():
    """Run every test case in this module via test_support.run_unittest."""
    cases = [TestDefaultDict]
    test_support.run_unittest(*cases)


if __name__ == "__main__":
    # Executed directly as a script rather than imported by a test runner.
    test_main()