blob: b5a662898767b132edc48116b9f81faaa837ea00 [file] [log] [blame]
Guido van Rossum1968ad32006-02-25 22:38:04 +00001"""Unit tests for collections.defaultdict."""
2
3import os
4import copy
5import tempfile
6import unittest
7
8from collections import defaultdict
9
10def foobar():
11 return list
12
13class TestDefaultDict(unittest.TestCase):
14
15 def test_basic(self):
16 d1 = defaultdict()
17 self.assertEqual(d1.default_factory, None)
18 d1.default_factory = list
19 d1[12].append(42)
20 self.assertEqual(d1, {12: [42]})
21 d1[12].append(24)
22 self.assertEqual(d1, {12: [42, 24]})
23 d1[13]
24 d1[14]
25 self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
26 self.assert_(d1[12] is not d1[13] is not d1[14])
27 d2 = defaultdict(list, foo=1, bar=2)
28 self.assertEqual(d2.default_factory, list)
29 self.assertEqual(d2, {"foo": 1, "bar": 2})
30 self.assertEqual(d2["foo"], 1)
31 self.assertEqual(d2["bar"], 2)
32 self.assertEqual(d2[42], [])
33 self.assert_("foo" in d2)
34 self.assert_("foo" in d2.keys())
35 self.assert_("bar" in d2)
36 self.assert_("bar" in d2.keys())
37 self.assert_(42 in d2)
38 self.assert_(42 in d2.keys())
39 self.assert_(12 not in d2)
40 self.assert_(12 not in d2.keys())
41 d2.default_factory = None
42 self.assertEqual(d2.default_factory, None)
43 try:
44 d2[15]
45 except KeyError, err:
46 self.assertEqual(err.args, (15,))
47 else:
48 self.fail("d2[15] didn't raise KeyError")
49
50 def test_missing(self):
51 d1 = defaultdict()
52 self.assertRaises(KeyError, d1.__missing__, 42)
53 d1.default_factory = list
54 self.assertEqual(d1.__missing__(42), [])
55
56 def test_repr(self):
57 d1 = defaultdict()
58 self.assertEqual(d1.default_factory, None)
59 self.assertEqual(repr(d1), "defaultdict(None, {})")
60 d1[11] = 41
61 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
62 d2 = defaultdict(0)
63 self.assertEqual(d2.default_factory, 0)
64 d2[12] = 42
65 self.assertEqual(repr(d2), "defaultdict(0, {12: 42})")
66 def foo(): return 43
67 d3 = defaultdict(foo)
68 self.assert_(d3.default_factory is foo)
69 d3[13]
70 self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
71
72 def test_print(self):
73 d1 = defaultdict()
74 def foo(): return 42
75 d2 = defaultdict(foo, {1: 2})
76 # NOTE: We can't use tempfile.[Named]TemporaryFile since this
77 # code must exercise the tp_print C code, which only gets
78 # invoked for *real* files.
79 tfn = tempfile.mktemp()
80 try:
81 f = open(tfn, "w+")
82 try:
83 print >>f, d1
84 print >>f, d2
85 f.seek(0)
86 self.assertEqual(f.readline(), repr(d1) + "\n")
87 self.assertEqual(f.readline(), repr(d2) + "\n")
88 finally:
89 f.close()
90 finally:
91 os.remove(tfn)
92
93 def test_copy(self):
94 d1 = defaultdict()
95 d2 = d1.copy()
96 self.assertEqual(type(d2), defaultdict)
97 self.assertEqual(d2.default_factory, None)
98 self.assertEqual(d2, {})
99 d1.default_factory = list
100 d3 = d1.copy()
101 self.assertEqual(type(d3), defaultdict)
102 self.assertEqual(d3.default_factory, list)
103 self.assertEqual(d3, {})
104 d1[42]
105 d4 = d1.copy()
106 self.assertEqual(type(d4), defaultdict)
107 self.assertEqual(d4.default_factory, list)
108 self.assertEqual(d4, {42: []})
109 d4[12]
110 self.assertEqual(d4, {42: [], 12: []})
111
112 def test_shallow_copy(self):
113 d1 = defaultdict(foobar, {1: 1})
114 d2 = copy.copy(d1)
115 self.assertEqual(d2.default_factory, foobar)
116 self.assertEqual(d2, d1)
117 d1.default_factory = list
118 d2 = copy.copy(d1)
119 self.assertEqual(d2.default_factory, list)
120 self.assertEqual(d2, d1)
121
122 def test_deep_copy(self):
123 d1 = defaultdict(foobar, {1: [1]})
124 d2 = copy.deepcopy(d1)
125 self.assertEqual(d2.default_factory, foobar)
126 self.assertEqual(d2, d1)
127 self.assert_(d1[1] is not d2[1])
128 d1.default_factory = list
129 d2 = copy.deepcopy(d1)
130 self.assertEqual(d2.default_factory, list)
131 self.assertEqual(d2, d1)
132
133
134if __name__ == "__main__":
135 unittest.main()