| Guido van Rossum | 1968ad3 | 2006-02-25 22:38:04 +0000 | [diff] [blame^] | 1 | """Unit tests for collections.defaultdict.""" | 
|  | 2 |  | 
|  | 3 | import os | 
|  | 4 | import copy | 
|  | 5 | import tempfile | 
|  | 6 | import unittest | 
|  | 7 |  | 
|  | 8 | from collections import defaultdict | 
|  | 9 |  | 
|  | 10 | def foobar(): | 
|  | 11 | return list | 
|  | 12 |  | 
|  | 13 | class TestDefaultDict(unittest.TestCase): | 
|  | 14 |  | 
|  | 15 | def test_basic(self): | 
|  | 16 | d1 = defaultdict() | 
|  | 17 | self.assertEqual(d1.default_factory, None) | 
|  | 18 | d1.default_factory = list | 
|  | 19 | d1[12].append(42) | 
|  | 20 | self.assertEqual(d1, {12: [42]}) | 
|  | 21 | d1[12].append(24) | 
|  | 22 | self.assertEqual(d1, {12: [42, 24]}) | 
|  | 23 | d1[13] | 
|  | 24 | d1[14] | 
|  | 25 | self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) | 
|  | 26 | self.assert_(d1[12] is not d1[13] is not d1[14]) | 
|  | 27 | d2 = defaultdict(list, foo=1, bar=2) | 
|  | 28 | self.assertEqual(d2.default_factory, list) | 
|  | 29 | self.assertEqual(d2, {"foo": 1, "bar": 2}) | 
|  | 30 | self.assertEqual(d2["foo"], 1) | 
|  | 31 | self.assertEqual(d2["bar"], 2) | 
|  | 32 | self.assertEqual(d2[42], []) | 
|  | 33 | self.assert_("foo" in d2) | 
|  | 34 | self.assert_("foo" in d2.keys()) | 
|  | 35 | self.assert_("bar" in d2) | 
|  | 36 | self.assert_("bar" in d2.keys()) | 
|  | 37 | self.assert_(42 in d2) | 
|  | 38 | self.assert_(42 in d2.keys()) | 
|  | 39 | self.assert_(12 not in d2) | 
|  | 40 | self.assert_(12 not in d2.keys()) | 
|  | 41 | d2.default_factory = None | 
|  | 42 | self.assertEqual(d2.default_factory, None) | 
|  | 43 | try: | 
|  | 44 | d2[15] | 
|  | 45 | except KeyError, err: | 
|  | 46 | self.assertEqual(err.args, (15,)) | 
|  | 47 | else: | 
|  | 48 | self.fail("d2[15] didn't raise KeyError") | 
|  | 49 |  | 
|  | 50 | def test_missing(self): | 
|  | 51 | d1 = defaultdict() | 
|  | 52 | self.assertRaises(KeyError, d1.__missing__, 42) | 
|  | 53 | d1.default_factory = list | 
|  | 54 | self.assertEqual(d1.__missing__(42), []) | 
|  | 55 |  | 
|  | 56 | def test_repr(self): | 
|  | 57 | d1 = defaultdict() | 
|  | 58 | self.assertEqual(d1.default_factory, None) | 
|  | 59 | self.assertEqual(repr(d1), "defaultdict(None, {})") | 
|  | 60 | d1[11] = 41 | 
|  | 61 | self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") | 
|  | 62 | d2 = defaultdict(0) | 
|  | 63 | self.assertEqual(d2.default_factory, 0) | 
|  | 64 | d2[12] = 42 | 
|  | 65 | self.assertEqual(repr(d2), "defaultdict(0, {12: 42})") | 
|  | 66 | def foo(): return 43 | 
|  | 67 | d3 = defaultdict(foo) | 
|  | 68 | self.assert_(d3.default_factory is foo) | 
|  | 69 | d3[13] | 
|  | 70 | self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) | 
|  | 71 |  | 
|  | 72 | def test_print(self): | 
|  | 73 | d1 = defaultdict() | 
|  | 74 | def foo(): return 42 | 
|  | 75 | d2 = defaultdict(foo, {1: 2}) | 
|  | 76 | # NOTE: We can't use tempfile.[Named]TemporaryFile since this | 
|  | 77 | # code must exercise the tp_print C code, which only gets | 
|  | 78 | # invoked for *real* files. | 
|  | 79 | tfn = tempfile.mktemp() | 
|  | 80 | try: | 
|  | 81 | f = open(tfn, "w+") | 
|  | 82 | try: | 
|  | 83 | print >>f, d1 | 
|  | 84 | print >>f, d2 | 
|  | 85 | f.seek(0) | 
|  | 86 | self.assertEqual(f.readline(), repr(d1) + "\n") | 
|  | 87 | self.assertEqual(f.readline(), repr(d2) + "\n") | 
|  | 88 | finally: | 
|  | 89 | f.close() | 
|  | 90 | finally: | 
|  | 91 | os.remove(tfn) | 
|  | 92 |  | 
|  | 93 | def test_copy(self): | 
|  | 94 | d1 = defaultdict() | 
|  | 95 | d2 = d1.copy() | 
|  | 96 | self.assertEqual(type(d2), defaultdict) | 
|  | 97 | self.assertEqual(d2.default_factory, None) | 
|  | 98 | self.assertEqual(d2, {}) | 
|  | 99 | d1.default_factory = list | 
|  | 100 | d3 = d1.copy() | 
|  | 101 | self.assertEqual(type(d3), defaultdict) | 
|  | 102 | self.assertEqual(d3.default_factory, list) | 
|  | 103 | self.assertEqual(d3, {}) | 
|  | 104 | d1[42] | 
|  | 105 | d4 = d1.copy() | 
|  | 106 | self.assertEqual(type(d4), defaultdict) | 
|  | 107 | self.assertEqual(d4.default_factory, list) | 
|  | 108 | self.assertEqual(d4, {42: []}) | 
|  | 109 | d4[12] | 
|  | 110 | self.assertEqual(d4, {42: [], 12: []}) | 
|  | 111 |  | 
|  | 112 | def test_shallow_copy(self): | 
|  | 113 | d1 = defaultdict(foobar, {1: 1}) | 
|  | 114 | d2 = copy.copy(d1) | 
|  | 115 | self.assertEqual(d2.default_factory, foobar) | 
|  | 116 | self.assertEqual(d2, d1) | 
|  | 117 | d1.default_factory = list | 
|  | 118 | d2 = copy.copy(d1) | 
|  | 119 | self.assertEqual(d2.default_factory, list) | 
|  | 120 | self.assertEqual(d2, d1) | 
|  | 121 |  | 
|  | 122 | def test_deep_copy(self): | 
|  | 123 | d1 = defaultdict(foobar, {1: [1]}) | 
|  | 124 | d2 = copy.deepcopy(d1) | 
|  | 125 | self.assertEqual(d2.default_factory, foobar) | 
|  | 126 | self.assertEqual(d2, d1) | 
|  | 127 | self.assert_(d1[1] is not d2[1]) | 
|  | 128 | d1.default_factory = list | 
|  | 129 | d2 = copy.deepcopy(d1) | 
|  | 130 | self.assertEqual(d2.default_factory, list) | 
|  | 131 | self.assertEqual(d2, d1) | 
|  | 132 |  | 
|  | 133 |  | 
|  | 134 | if __name__ == "__main__": | 
|  | 135 | unittest.main() |