blob: f9a6a17ad18baf7b40811ee0481133e658527bd8 [file] [log] [blame]
"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
5import tempfile
6import unittest
Georg Brandlf102fc52006-07-27 15:05:36 +00007from test import test_support
Guido van Rossum1968ad32006-02-25 22:38:04 +00008
9from collections import defaultdict
10
def foobar():
    """Return the built-in ``list`` type (the type itself, not an instance).

    Defined at module level so the copy tests in this file can use it as a
    ``default_factory`` and compare the copied factory against this function.
    """
    produced = list
    return produced
13
14class TestDefaultDict(unittest.TestCase):
15
16 def test_basic(self):
17 d1 = defaultdict()
18 self.assertEqual(d1.default_factory, None)
19 d1.default_factory = list
20 d1[12].append(42)
21 self.assertEqual(d1, {12: [42]})
22 d1[12].append(24)
23 self.assertEqual(d1, {12: [42, 24]})
24 d1[13]
25 d1[14]
26 self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
Benjamin Peterson5c8da862009-06-30 22:57:08 +000027 self.assertTrue(d1[12] is not d1[13] is not d1[14])
Guido van Rossum1968ad32006-02-25 22:38:04 +000028 d2 = defaultdict(list, foo=1, bar=2)
29 self.assertEqual(d2.default_factory, list)
30 self.assertEqual(d2, {"foo": 1, "bar": 2})
31 self.assertEqual(d2["foo"], 1)
32 self.assertEqual(d2["bar"], 2)
33 self.assertEqual(d2[42], [])
Benjamin Peterson5c8da862009-06-30 22:57:08 +000034 self.assertTrue("foo" in d2)
35 self.assertTrue("foo" in d2.keys())
36 self.assertTrue("bar" in d2)
37 self.assertTrue("bar" in d2.keys())
38 self.assertTrue(42 in d2)
39 self.assertTrue(42 in d2.keys())
40 self.assertTrue(12 not in d2)
41 self.assertTrue(12 not in d2.keys())
Guido van Rossum1968ad32006-02-25 22:38:04 +000042 d2.default_factory = None
43 self.assertEqual(d2.default_factory, None)
44 try:
45 d2[15]
46 except KeyError, err:
47 self.assertEqual(err.args, (15,))
48 else:
49 self.fail("d2[15] didn't raise KeyError")
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000050 self.assertRaises(TypeError, defaultdict, 1)
Guido van Rossum1968ad32006-02-25 22:38:04 +000051
52 def test_missing(self):
53 d1 = defaultdict()
54 self.assertRaises(KeyError, d1.__missing__, 42)
55 d1.default_factory = list
56 self.assertEqual(d1.__missing__(42), [])
57
58 def test_repr(self):
59 d1 = defaultdict()
60 self.assertEqual(d1.default_factory, None)
61 self.assertEqual(repr(d1), "defaultdict(None, {})")
Raymond Hettinger8fdab952009-08-04 19:08:05 +000062 self.assertEqual(eval(repr(d1)), d1)
Guido van Rossum1968ad32006-02-25 22:38:04 +000063 d1[11] = 41
64 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000065 d2 = defaultdict(int)
66 self.assertEqual(d2.default_factory, int)
Guido van Rossum1968ad32006-02-25 22:38:04 +000067 d2[12] = 42
Raymond Hettinger5a0217e2007-02-07 21:42:17 +000068 self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
Guido van Rossum1968ad32006-02-25 22:38:04 +000069 def foo(): return 43
70 d3 = defaultdict(foo)
Benjamin Peterson5c8da862009-06-30 22:57:08 +000071 self.assertTrue(d3.default_factory is foo)
Guido van Rossum1968ad32006-02-25 22:38:04 +000072 d3[13]
73 self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
74
75 def test_print(self):
76 d1 = defaultdict()
77 def foo(): return 42
78 d2 = defaultdict(foo, {1: 2})
79 # NOTE: We can't use tempfile.[Named]TemporaryFile since this
80 # code must exercise the tp_print C code, which only gets
81 # invoked for *real* files.
82 tfn = tempfile.mktemp()
83 try:
84 f = open(tfn, "w+")
85 try:
86 print >>f, d1
87 print >>f, d2
88 f.seek(0)
89 self.assertEqual(f.readline(), repr(d1) + "\n")
90 self.assertEqual(f.readline(), repr(d2) + "\n")
91 finally:
92 f.close()
93 finally:
94 os.remove(tfn)
95
96 def test_copy(self):
97 d1 = defaultdict()
98 d2 = d1.copy()
99 self.assertEqual(type(d2), defaultdict)
100 self.assertEqual(d2.default_factory, None)
101 self.assertEqual(d2, {})
102 d1.default_factory = list
103 d3 = d1.copy()
104 self.assertEqual(type(d3), defaultdict)
105 self.assertEqual(d3.default_factory, list)
106 self.assertEqual(d3, {})
107 d1[42]
108 d4 = d1.copy()
109 self.assertEqual(type(d4), defaultdict)
110 self.assertEqual(d4.default_factory, list)
111 self.assertEqual(d4, {42: []})
112 d4[12]
113 self.assertEqual(d4, {42: [], 12: []})
114
Raymond Hettinger8fdab952009-08-04 19:08:05 +0000115 # Issue 6637: Copy fails for empty default dict
116 d = defaultdict()
117 d['a'] = 42
118 e = d.copy()
119 self.assertEqual(e['a'], 42)
120
Guido van Rossum1968ad32006-02-25 22:38:04 +0000121 def test_shallow_copy(self):
122 d1 = defaultdict(foobar, {1: 1})
123 d2 = copy.copy(d1)
124 self.assertEqual(d2.default_factory, foobar)
125 self.assertEqual(d2, d1)
126 d1.default_factory = list
127 d2 = copy.copy(d1)
128 self.assertEqual(d2.default_factory, list)
129 self.assertEqual(d2, d1)
130
131 def test_deep_copy(self):
132 d1 = defaultdict(foobar, {1: [1]})
133 d2 = copy.deepcopy(d1)
134 self.assertEqual(d2.default_factory, foobar)
135 self.assertEqual(d2, d1)
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000136 self.assertTrue(d1[1] is not d2[1])
Guido van Rossum1968ad32006-02-25 22:38:04 +0000137 d1.default_factory = list
138 d2 = copy.deepcopy(d1)
139 self.assertEqual(d2.default_factory, list)
140 self.assertEqual(d2, d1)
141
Georg Brandl72363032007-03-06 13:35:00 +0000142 def test_keyerror_without_factory(self):
143 d1 = defaultdict()
144 try:
145 d1[(1,)]
146 except KeyError, err:
Brett Cannon229cee22007-05-05 01:34:02 +0000147 self.assertEqual(err.args[0], (1,))
Georg Brandl72363032007-03-06 13:35:00 +0000148 else:
149 self.fail("expected KeyError")
150
Amaury Forgeot d'Arcb01aa432008-02-08 00:56:02 +0000151 def test_recursive_repr(self):
152 # Issue2045: stack overflow when default_factory is a bound method
153 class sub(defaultdict):
154 def __init__(self):
155 self.default_factory = self._factory
156 def _factory(self):
157 return []
158 d = sub()
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000159 self.assertTrue(repr(d).startswith(
Amaury Forgeot d'Arcb01aa432008-02-08 00:56:02 +0000160 "defaultdict(<bound method sub._factory of defaultdict(..."))
161
162 # NOTE: printing a subclass of a builtin type does not call its
163 # tp_print slot. So this part is essentially the same test as above.
164 tfn = tempfile.mktemp()
165 try:
166 f = open(tfn, "w+")
167 try:
168 print >>f, d
169 finally:
170 f.close()
171 finally:
172 os.remove(tfn)
173
Guido van Rossum1968ad32006-02-25 22:38:04 +0000174
def test_main():
    """Run every test case class of this module through test_support."""
    case_classes = (TestDefaultDict,)
    test_support.run_unittest(*case_classes)
177
# Run the suite when this module is executed directly as a script.
if __name__ == "__main__":
    test_main()