blob: cb493bff52d4d37785236366e51b0b118983ccb2 [file] [log] [blame]
Raymond Hettinger003be522011-05-03 11:01:32 -07001import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02002from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00003import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00004from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02005import sys
6from test import support
7import unittest
8from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
Antoine Pitroub5b37142012-11-13 21:35:40 +010010import functools
11
# Pure-Python functools: the _functools C accelerator is blocked while the
# module is freshly imported.
py_functools = support.import_fresh_module(
    'functools', blocked=('_functools',))
# functools freshly imported together with _functools.  Treated as optional:
# the tests below guard on ``if c_functools`` / ``skipUnless(c_functools, ...)``.
c_functools = support.import_fresh_module(
    'functools', fresh=('_functools',))

# decimal freshly imported together with its _decimal C accelerator.
decimal = support.import_fresh_module('decimal', fresh=('_decimal',))
16
17
def capture(*args, **kw):
    """Return every argument received, as an ``(args, kwargs)`` pair.

    Serves as a transparent call target so tests can see exactly which
    positional and keyword arguments a partial object forwarded.
    """
    received = (args, kw)
    return received
21
Łukasz Langa6f692512013-06-05 12:20:24 +020022
def signature(part):
    """Return the observable state of the partial object *part*.

    The result is the tuple ``(func, args, keywords, __dict__)``; two partial
    objects with equal signatures wrap the same call and carry the same
    instance attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000026
Łukasz Langa6f692512013-06-05 12:20:24 +020027
class TestPartial:
    """Implementation-independent tests for partial objects.

    This is a mixin: concrete subclasses must also derive from
    unittest.TestCase and set ``partial`` to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # The read-only checks below only apply when the implementation
        # under test is a class.
        if not isinstance(self.partial, type):
            return
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either at
        # construction time or at call time.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # Separate assertions so a failure reports the actual values.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000146
Łukasz Langa6f692512013-06-05 12:20:24 +0200147
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific ones, run against _functools."""
    # Guarded because c_functools may be unavailable; the whole class is
    # skipped in that case.
    if c_functools:
        partial = c_functools.partial

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The partial stores its own keywords dict, whose iteration order is
        # not guaranteed to match that of ``kwargs`` (dict ordering is only
        # guaranteed from Python 3.7), so accept every key order.
        kwargs_reprs = [', '.join('%s=%r' % (k, kwargs[k]) for k in order)
                        for order in permutations(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        # Extra instance attributes must survive the round trip as well.
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # Sequence-like object of length 4 that is not a tuple; its items
        # resemble a (func, args, kwds, dict) state.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        # __setstate__ must reject the non-tuple state with this error.
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000203
Łukasz Langa6f692512013-06-05 12:20:24 +0200204
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    # staticmethod stops the implementation from being bound to the test
    # instance when looked up through ``self`` (matters if it is a function).
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000207
Łukasz Langa6f692512013-06-05 12:20:24 +0200208
# Only defined when the C accelerator could be imported.
if c_functools:
    class PartialSubclass(c_functools.partial):
        """Subclass of the C partial type that adds no behavior of its own."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100212
Łukasz Langa6f692512013-06-05 12:20:24 +0200213
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using a trivial subclass of the C type."""

    # PartialSubclass only exists when c_functools is available; the skip
    # decorator alone would not prevent a NameError in the class body.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000218
Łukasz Langa6f692512013-06-05 12:20:24 +0200219
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was correctly derived from *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both functions; every mapping attribute named in *updated* must hold
        each of wrapped's entries; ``__wrapped__`` must refer to *wrapped*.
        """
        # Assigned attributes are copied by identity.
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        # Updated attributes are merged entry by entry.
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for entry in source:
                # __wrapped__ is overwritten by the update code, so the
                # copied __dict__ entry is expected to differ.
                skip = attr_name == "__dict__" and entry == "__wrapped__"
                if not skip:
                    self.assertIs(source[entry], target[entry])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build (wrapper, f) with update_wrapper's default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A bogus __wrapped__ that must not leak onto the wrapper.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # Annotations are replaced wholesale, not merged.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)
        # ...and those attributes must support being updated.
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError,
                          functools.update_wrapper, wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000332
Łukasz Langa6f692512013-06-05 12:20:24 +0200333
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper tests through the functools.wraps decorator.

    check_wrapper and the tests not overridden here are inherited from
    TestUpdateWrapper.
    """

    def _default_update(self):
        """Build (wrapper, f) using @functools.wraps with default arguments."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A bogus __wrapped__ that must not leak onto the wrapper.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # The "updated" attribute must already exist on the wrapper.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
394
Łukasz Langa6f692512013-06-05 12:20:24 +0200395
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily extended sequence of squares 0, 1, 4, ... with *max*
            # items, reachable only through the __getitem__ protocol.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            reduce(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)
        with self.assertRaises(TypeError):
            reduce()
        with self.assertRaises(TypeError):
            reduce(42, 42)
        with self.assertRaises(TypeError):
            reduce(42, 42, 42)
        # With a single item, the function is never called.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))
        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            with self.assertRaises(TypeError):
                reduce(add, empty)
        with self.assertRaises(TypeError):
            reduce(add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        # An initial value is returned as-is for empty input.
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
475
Łukasz Langa6f692512013-06-05 12:20:24 +0200476
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200477class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700478
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000479 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700480 def cmp1(x, y):
481 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100482 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700483 self.assertEqual(key(3), key(3))
484 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100485 self.assertGreaterEqual(key(3), key(3))
486
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700487 def cmp2(x, y):
488 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100489 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700490 self.assertEqual(key(4.0), key('4'))
491 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100492 self.assertLessEqual(key(2), key('35'))
493 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700494
495 def test_cmp_to_key_arguments(self):
496 def cmp1(x, y):
497 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100498 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700499 self.assertEqual(key(obj=3), key(obj=3))
500 self.assertGreater(key(obj=3), key(obj=1))
501 with self.assertRaises((TypeError, AttributeError)):
502 key(3) > 1 # rhs is not a K object
503 with self.assertRaises((TypeError, AttributeError)):
504 1 < key(3) # lhs is not a K object
505 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100506 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700507 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200508 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100509 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700510 with self.assertRaises(TypeError):
511 key() # too few args
512 with self.assertRaises(TypeError):
513 key(None, None) # too many args
514
515 def test_bad_cmp(self):
516 def cmp1(x, y):
517 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100518 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700519 with self.assertRaises(ZeroDivisionError):
520 key(3) > key(1)
521
522 class BadCmp:
523 def __lt__(self, other):
524 raise ZeroDivisionError
525 def cmp1(x, y):
526 return BadCmp()
527 with self.assertRaises(ZeroDivisionError):
528 key(3) > key(1)
529
530 def test_obj_field(self):
531 def cmp1(x, y):
532 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100533 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700534 self.assertEqual(key(50).obj, 50)
535
536 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000537 def mycmp(x, y):
538 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100539 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000540 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000541
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700542 def test_sort_int_str(self):
543 def mycmp(x, y):
544 x, y = int(x), int(y)
545 return (x > y) - (x < y)
546 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100547 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700548 self.assertEqual([int(value) for value in values],
549 [0, 1, 1, 2, 3, 4, 5, 7, 10])
550
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000551 def test_hash(self):
552 def mycmp(x, y):
553 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100554 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000555 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700556 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700557 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000558
Łukasz Langa6f692512013-06-05 12:20:24 +0200559
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C implementation."""

    # Guarded because c_functools may be unavailable; the class is then
    # skipped by the decorator.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100564
Łukasz Langa6f692512013-06-05 12:20:24 +0200565
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python implementation."""

    # staticmethod stops the Python function from being bound to the test
    # instance when looked up through ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
568
Łukasz Langa6f692512013-06-05 12:20:24 +0200569
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000570class TestTotalOrdering(unittest.TestCase):
571
572 def test_total_ordering_lt(self):
573 @functools.total_ordering
574 class A:
575 def __init__(self, value):
576 self.value = value
577 def __lt__(self, other):
578 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000579 def __eq__(self, other):
580 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000581 self.assertTrue(A(1) < A(2))
582 self.assertTrue(A(2) > A(1))
583 self.assertTrue(A(1) <= A(2))
584 self.assertTrue(A(2) >= A(1))
585 self.assertTrue(A(2) <= A(2))
586 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000587 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000588
589 def test_total_ordering_le(self):
590 @functools.total_ordering
591 class A:
592 def __init__(self, value):
593 self.value = value
594 def __le__(self, other):
595 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000596 def __eq__(self, other):
597 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000598 self.assertTrue(A(1) < A(2))
599 self.assertTrue(A(2) > A(1))
600 self.assertTrue(A(1) <= A(2))
601 self.assertTrue(A(2) >= A(1))
602 self.assertTrue(A(2) <= A(2))
603 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000604 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000605
606 def test_total_ordering_gt(self):
607 @functools.total_ordering
608 class A:
609 def __init__(self, value):
610 self.value = value
611 def __gt__(self, other):
612 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000613 def __eq__(self, other):
614 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000615 self.assertTrue(A(1) < A(2))
616 self.assertTrue(A(2) > A(1))
617 self.assertTrue(A(1) <= A(2))
618 self.assertTrue(A(2) >= A(1))
619 self.assertTrue(A(2) <= A(2))
620 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000621 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000622
623 def test_total_ordering_ge(self):
624 @functools.total_ordering
625 class A:
626 def __init__(self, value):
627 self.value = value
628 def __ge__(self, other):
629 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000630 def __eq__(self, other):
631 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000632 self.assertTrue(A(1) < A(2))
633 self.assertTrue(A(2) > A(1))
634 self.assertTrue(A(1) <= A(2))
635 self.assertTrue(A(2) >= A(1))
636 self.assertTrue(A(2) <= A(2))
637 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000638 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000639
640 def test_total_ordering_no_overwrite(self):
641 # new methods should not overwrite existing
642 @functools.total_ordering
643 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000644 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000645 self.assertTrue(A(1) < A(2))
646 self.assertTrue(A(2) > A(1))
647 self.assertTrue(A(1) <= A(2))
648 self.assertTrue(A(2) >= A(1))
649 self.assertTrue(A(2) <= A(2))
650 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000651
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000652 def test_no_operations_defined(self):
653 with self.assertRaises(ValueError):
654 @functools.total_ordering
655 class A:
656 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000657
Nick Coghlanf05d9812013-10-02 00:02:03 +1000658 def test_type_error_when_not_implemented(self):
659 # bug 10042; ensure stack overflow does not occur
660 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000661 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000662 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000663 def __init__(self, value):
664 self.value = value
665 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000666 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000667 return self.value == other.value
668 return False
669 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000670 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000671 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000672 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000673
Nick Coghlanf05d9812013-10-02 00:02:03 +1000674 @functools.total_ordering
675 class ImplementsGreaterThan:
676 def __init__(self, value):
677 self.value = value
678 def __eq__(self, other):
679 if isinstance(other, ImplementsGreaterThan):
680 return self.value == other.value
681 return False
682 def __gt__(self, other):
683 if isinstance(other, ImplementsGreaterThan):
684 return self.value > other.value
685 return NotImplemented
686
687 @functools.total_ordering
688 class ImplementsLessThanEqualTo:
689 def __init__(self, value):
690 self.value = value
691 def __eq__(self, other):
692 if isinstance(other, ImplementsLessThanEqualTo):
693 return self.value == other.value
694 return False
695 def __le__(self, other):
696 if isinstance(other, ImplementsLessThanEqualTo):
697 return self.value <= other.value
698 return NotImplemented
699
700 @functools.total_ordering
701 class ImplementsGreaterThanEqualTo:
702 def __init__(self, value):
703 self.value = value
704 def __eq__(self, other):
705 if isinstance(other, ImplementsGreaterThanEqualTo):
706 return self.value == other.value
707 return False
708 def __ge__(self, other):
709 if isinstance(other, ImplementsGreaterThanEqualTo):
710 return self.value >= other.value
711 return NotImplemented
712
713 @functools.total_ordering
714 class ComparatorNotImplemented:
715 def __init__(self, value):
716 self.value = value
717 def __eq__(self, other):
718 if isinstance(other, ComparatorNotImplemented):
719 return self.value == other.value
720 return False
721 def __lt__(self, other):
722 return NotImplemented
723
724 with self.subTest("LT < 1"), self.assertRaises(TypeError):
725 ImplementsLessThan(-1) < 1
726
727 with self.subTest("LT < LE"), self.assertRaises(TypeError):
728 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
729
730 with self.subTest("LT < GT"), self.assertRaises(TypeError):
731 ImplementsLessThan(1) < ImplementsGreaterThan(1)
732
733 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
734 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
735
736 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
737 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
738
739 with self.subTest("GT > GE"), self.assertRaises(TypeError):
740 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
741
742 with self.subTest("GT > LT"), self.assertRaises(TypeError):
743 ImplementsGreaterThan(5) > ImplementsLessThan(5)
744
745 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
746 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
747
748 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
749 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
750
751 with self.subTest("GE when equal"):
752 a = ComparatorNotImplemented(8)
753 b = ComparatorNotImplemented(8)
754 self.assertEqual(a, b)
755 with self.assertRaises(TypeError):
756 a >= b
757
758 with self.subTest("LE when equal"):
759 a = ComparatorNotImplemented(9)
760 b = ComparatorNotImplemented(9)
761 self.assertEqual(a, b)
762 with self.assertRaises(TypeError):
763 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200764
class TestLRU(unittest.TestCase):
    """Tests for functools.lru_cache: statistics, eviction and reentrancy."""

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = functools.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 1000 random calls over 25 possible (x, y) keys with room for only
        # 20 entries: results must match the uncached function, repeats
        # should mostly hit, and the cache must end up full.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        # assertGreater reports both counts on failure, unlike
        # assertTrue(hits > misses).
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (which evicts 7, the
        # least recently used entry at that point) and the final 7.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # typed=True keys int and float arguments separately, so each of the
        # four call shapes below misses once and then hits once.
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @functools.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @functools.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                # Comparing against a cached key re-enters test_func.
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
955
956
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and its ABC-aware MRO helpers."""

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: g_int is the registered implementation itself, not g;
        # only g is the dispatching wrapper returned by @singledispatch.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                                 c.Iterable, c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        # NOTE(review): this loop passes a fixed list rather than `haystack`,
        # so the permutations are not exercised here; confirm the result is
        # order-independent before switching it to `haystack`.
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(c.defaultdict, [c.Sized, c.Container, str])
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            # Pass each permutation, not the original list, so the ordering
            # claim above is actually checked.
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                                 c.defaultdict, dict, c.MutableMapping,
                                 c.Mapping, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

    def test_register_abc(self):
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_mro_conflicts(self):
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        # Restore the real class even if an assertion below fails; otherwise
        # later code calling functools.WeakKeyDictionary() (as singledispatch
        # does here) would keep receiving this shared TracingDict.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
1428 functools.WeakKeyDictionary = _orig_wkd
1429
1430
def test_main(verbose=None):
    """Run every test class of this module through test.support.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount`` (presumably a debug build), the whole suite is
    run five more times.  The total reference count after each run is then
    printed; a count that keeps growing suggests a reference leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001456
if __name__ == "__main__":
    # verbose=True also enables test_main's reference-count re-runs when
    # sys.gettotalrefcount is available.
    test_main(verbose=True)