| """Unit tests for collections.defaultdict.""" |
| |
| import os |
| import copy |
| import pickle |
| import tempfile |
| import unittest |
| from test import support |
| |
| from collections import defaultdict |
| |
def foobar():
    """Return the ``list`` type; serves as a named, module-level default factory."""
    factory = list
    return factory
| |
class TestDefaultDict(unittest.TestCase):
    """Tests for construction, lookup, repr, copying and pickling of defaultdict."""

    # NOTE: test methods are documented with comments rather than docstrings
    # so that verbose unittest output keeps showing the method names.

    def test_basic(self):
        # A defaultdict built without arguments has no factory.
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        d1.default_factory = list
        d1[12].append(42)
        self.assertEqual(d1, {12: [42]})
        d1[12].append(24)
        self.assertEqual(d1, {12: [42, 24]})
        # Merely looking up a missing key inserts a fresh default value.
        d1[13]
        d1[14]
        self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
        # Every default must be a distinct object; check all three pairs
        # explicitly (a chained "a is not b is not c" skips a-vs-c).
        self.assertIsNot(d1[12], d1[13])
        self.assertIsNot(d1[13], d1[14])
        self.assertIsNot(d1[12], d1[14])

        d2 = defaultdict(list, foo=1, bar=2)
        self.assertIs(d2.default_factory, list)
        self.assertEqual(d2, {"foo": 1, "bar": 2})
        self.assertEqual(d2["foo"], 1)
        self.assertEqual(d2["bar"], 2)
        self.assertEqual(d2[42], [])
        self.assertIn("foo", d2)
        self.assertIn("foo", d2.keys())
        self.assertIn("bar", d2)
        self.assertIn("bar", d2.keys())
        self.assertIn(42, d2)
        self.assertIn(42, d2.keys())
        self.assertNotIn(12, d2)
        self.assertNotIn(12, d2.keys())

        # With the factory removed, missing keys raise KeyError(key).
        d2.default_factory = None
        self.assertIsNone(d2.default_factory)
        with self.assertRaises(KeyError) as caught:
            d2[15]
        self.assertEqual(caught.exception.args, (15,))

        # A non-callable, non-None first argument is rejected.
        self.assertRaises(TypeError, defaultdict, 1)

    def test_missing(self):
        # __missing__ raises without a factory and calls it otherwise.
        d1 = defaultdict()
        self.assertRaises(KeyError, d1.__missing__, 42)
        d1.default_factory = list
        self.assertEqual(d1.__missing__(42), [])

    def test_repr(self):
        d1 = defaultdict()
        self.assertIsNone(d1.default_factory)
        self.assertEqual(repr(d1), "defaultdict(None, {})")
        # The repr of a factory-less defaultdict round-trips through eval();
        # the input is a literal produced right here, not external data.
        self.assertEqual(eval(repr(d1)), d1)
        d1[11] = 41
        self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")

        d2 = defaultdict(int)
        self.assertIs(d2.default_factory, int)
        d2[12] = 42
        self.assertEqual(repr(d2), "defaultdict(<class 'int'>, {12: 42})")

        def foo():
            return 43

        d3 = defaultdict(foo)
        self.assertIs(d3.default_factory, foo)
        d3[13]
        self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))

    def test_print(self):
        d1 = defaultdict()

        def foo():
            return 42

        d2 = defaultdict(foo, {1: 2})
        # NOTE: The test deliberately prints to a *real* on-disk file rather
        # than a tempfile.[Named]TemporaryFile wrapper (historically to
        # exercise the tp_print C code).  mkstemp() keeps that property
        # while avoiding the file-name race of the deprecated mktemp().
        fd, tfn = tempfile.mkstemp()
        try:
            # The with-block closes the file before it is removed.
            with os.fdopen(fd, "w+") as f:
                print(d1, file=f)
                print(d2, file=f)
                f.seek(0)
                self.assertEqual(f.readline(), repr(d1) + "\n")
                self.assertEqual(f.readline(), repr(d2) + "\n")
        finally:
            os.remove(tfn)

    def test_copy(self):
        # .copy() preserves both the type and the default factory.
        d1 = defaultdict()
        d2 = d1.copy()
        self.assertIs(type(d2), defaultdict)
        self.assertIsNone(d2.default_factory)
        self.assertEqual(d2, {})

        d1.default_factory = list
        d3 = d1.copy()
        self.assertIs(type(d3), defaultdict)
        self.assertIs(d3.default_factory, list)
        self.assertEqual(d3, {})

        d1[42]
        d4 = d1.copy()
        self.assertIs(type(d4), defaultdict)
        self.assertIs(d4.default_factory, list)
        self.assertEqual(d4, {42: []})
        d4[12]
        self.assertEqual(d4, {42: [], 12: []})

        # Issue 6637: Copy fails for empty default dict
        d = defaultdict()
        d['a'] = 42
        e = d.copy()
        self.assertEqual(e['a'], 42)

    def test_shallow_copy(self):
        # copy.copy() keeps the factory, whether user-defined or builtin.
        d1 = defaultdict(foobar, {1: 1})
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        d1.default_factory = list
        d2 = copy.copy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_deep_copy(self):
        # copy.deepcopy() keeps the factory and duplicates the values.
        d1 = defaultdict(foobar, {1: [1]})
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, foobar)
        self.assertEqual(d2, d1)
        self.assertIsNot(d1[1], d2[1])
        d1.default_factory = list
        d2 = copy.deepcopy(d1)
        self.assertEqual(d2.default_factory, list)
        self.assertEqual(d2, d1)

    def test_keyerror_without_factory(self):
        # A tuple key must reach KeyError intact as args[0], not unpacked.
        d1 = defaultdict()
        with self.assertRaises(KeyError) as caught:
            d1[(1,)]
        self.assertEqual(caught.exception.args[0], (1,))

    def test_recursive_repr(self):
        # Issue2045: stack overflow when default_factory is a bound method
        class sub(defaultdict):
            def __init__(self):
                self.default_factory = self._factory
            def _factory(self):
                return []
        d = sub()
        # Older interpreters print the base-class name for subclasses while
        # Python 3.7+ prints the subclass name, so accept either spelling.
        self.assertRegex(repr(d),
            r"(?:defaultdict|sub)\(<bound method .*sub\._factory "
            r"of (?:defaultdict|sub)\(\.\.\., \{\}\)>, \{\}\)")

        # NOTE: printing a subclass of a builtin type does not call its
        # tp_print slot. So this part is essentially the same test as above.
        # It only checks that printing to a real file does not blow up.
        fd, tfn = tempfile.mkstemp()
        try:
            with os.fdopen(fd, "w+") as f:
                print(d, file=f)
        finally:
            os.remove(tfn)

    def test_callable_arg(self):
        # A dict is not callable, so it cannot serve as default_factory.
        self.assertRaises(TypeError, defaultdict, {})

    def test_pickleing(self):
        # Round-trip through every pickle protocol the interpreter supports.
        d = defaultdict(int)
        d[1]
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            s = pickle.dumps(d, proto)
            o = pickle.loads(s)
            self.assertEqual(d, o)
            # The factory must survive the round trip as well.
            self.assertIs(o.default_factory, int)
| |
def test_main():
    """Run the TestDefaultDict case through test.support's unittest runner."""
    support.run_unittest(TestDefaultDict)


# Allow this module to be executed directly as a script.
if __name__ == "__main__":
    test_main()