blob: 49c807d059241a47539e7c3fc3d70c165f9d3bea [file] [log] [blame]
Raymond Hettinger003be522011-05-03 11:01:32 -07001import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02002from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00003import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00004from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02005import sys
6from test import support
7import unittest
8from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
Antoine Pitroub5b37142012-11-13 21:35:40 +010010import functools
11
# Two independently imported copies of functools: one with the _functools
# C accelerator blocked (so the pure-Python code is exercised) and one that
# freshly imports _functools.  The ``if c_functools`` guards in this file
# treat a falsy value as "C module unavailable".
py_functools = support.import_fresh_module('functools',
                                           blocked=['_functools'])
c_functools = support.import_fresh_module('functools',
                                          fresh=['_functools'])

# A freshly imported decimal module that also freshly imports _decimal.
decimal = support.import_fresh_module('decimal',
                                      fresh=['_decimal'])
16
17
def capture(*args, **kw):
    """Echo back every argument received.

    Returns a ``(args, kw)`` pair: the positional arguments as a tuple and
    the keyword arguments as a dict.
    """
    captured = (args, kw)
    return captured
21
Łukasz Langa6f692512013-06-05 12:20:24 +020022
def signature(part):
    """Describe a partial object as a comparable tuple.

    The result is ``(func, args, keywords, __dict__)`` read straight from
    *part*, so two partials compare equal here exactly when those four
    attributes compare equal.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000026
Łukasz Langa6f692512013-06-05 12:20:24 +020027
class TestPartial:
    """Behavioural tests shared by every ``partial`` implementation.

    This is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and set ``partial`` to the implementation under
    test (the C type, the pure-Python version, or a subclass).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # The read-only checks below only apply when partial is a class.
        if not isinstance(self.partial, type):
            return
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000146
Łukasz Langa6f692512013-06-05 12:20:24 +0200147
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests against the C implementation, plus checks of
    its repr, pickling and __setstate__ argument validation."""

    if c_functools:
        partial = c_functools.partial

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The partial renders keywords in the order its own internal dict
        # yields them, which is not guaranteed to match the iteration order
        # of ``kwargs`` here, so every ordering is accepted.
        kwargs_reprs = [', '.join("%s=%r" % (k, v) for k, v in ordering)
                        for ordering in permutations(kwargs.items())]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kw_repr)
                       for kw_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kw_repr)
                       for kw_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # Round-trip through every available protocol, not only the default.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy),
                             'pickle protocol %d' % proto)

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000203
Łukasz Langa6f692512013-06-05 12:20:24 +0200204
class TestPartialPy(TestPartial, unittest.TestCase):
    # The shared partial tests, run against the pure-Python functools.
    # Wrapping in staticmethod() ensures that looking the implementation up
    # through ``self.partial`` never binds it as a method.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000207
Łukasz Langa6f692512013-06-05 12:20:24 +0200208
if c_functools:
    # Trivial subclass of the C partial type, exercised by
    # TestPartialCSubclass.  Defined at module level so that pickle can find
    # it by name when partial instances of it are pickled.
    class PartialSubclass(c_functools.partial):
        ...
Antoine Pitrou33543272012-11-13 21:36:21 +0100212
Łukasz Langa6f692512013-06-05 12:20:24 +0200213
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Every C-partial test, re-run against a do-nothing subclass of the type.
    # The guard mirrors the skip decorator: PartialSubclass only exists when
    # the C module could be imported.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000218
Łukasz Langa6f692512013-06-05 12:20:24 +0200219
class TestUpdateWrapper(unittest.TestCase):

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was updated from *wrapped*.

        Each attribute named in *assigned* must be the identical object on
        both functions.  For each mapping attribute named in *updated*, every
        key of the wrapped function's mapping must map to the identical
        object in the wrapper's mapping.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name),
                          name)
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key],
                              '%s[%r]' % (name, key))

    def _default_update(self):
        # Apply update_wrapper() with default settings to an annotated,
        # documented function and return (wrapper, wrapped).
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000325
Łukasz Langa6f692512013-06-05 12:20:24 +0200326
class TestWraps(TestUpdateWrapper):
    # Repeats the update_wrapper checks through the functools.wraps
    # decorator; tests not overridden here are inherited unchanged.

    def _default_update(self):
        # Decorate a documented function's wrapper with wraps() using the
        # default settings and return (wrapper, wrapped).
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        # wraps() must record the wrapped function, like update_wrapper().
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
386
Łukasz Langa6f692512013-06-05 12:20:24 +0200387
class TestReduce(unittest.TestCase):
    # The reduce implementation under test, reached through self.func.
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the first ``limit`` squares (0, 1, 4, ...), served
            # through __getitem__ and computed on demand into ``sofar``.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n * n)
                return self.sofar[i]

        def plus(x, y):
            return x + y

        def times(x, y):
            return x * y

        reduce = self.func

        # Ordinary folds over lists, ranges and __getitem__ sequences.
        self.assertEqual(reduce(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(times, range(2, 8), 1), 5040)
        self.assertEqual(reduce(times, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(plus, Squares(10)), 285)
        self.assertEqual(reduce(plus, Squares(10), 0), 285)
        self.assertEqual(reduce(plus, Squares(0), 0), 0)

        # Wrong argument count or kind.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the (non-callable) function is never called.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, plus, empty)
        self.assertRaises(TypeError, reduce, plus, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, plus, TestFailingIter())

        # An empty sequence with an initial value yields that value.
        self.assertIsNone(reduce(plus, [], None))
        self.assertEqual(reduce(plus, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            # Yields 0 .. n-1 through the __getitem__ sequence protocol.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
467
Łukasz Langa6f692512013-06-05 12:20:24 +0200468
class TestCmpToKey:
    """Tests shared by every ``cmp_to_key`` implementation.

    This is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and set ``cmp_to_key`` to the implementation
    under test.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                # too few args
        with self.assertRaises(TypeError):
            key(None, None)                      # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # result against zero propagates too.  The key is rebuilt around
        # cmp2, and ``<`` is used because BadCmp only defines __lt__ (a
        # ``>`` comparison would never reach it).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # collections.Hashable was only a deprecated alias (removed in 3.10);
        # the ABC lives in collections.abc.
        from collections.abc import Hashable
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000550
Łukasz Langa6f692512013-06-05 12:20:24 +0200551
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # The shared cmp_to_key tests against the C implementation.  The guard
    # keeps the class body importable without _functools; the decorator then
    # skips the whole case.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100556
Łukasz Langa6f692512013-06-05 12:20:24 +0200557
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # The shared cmp_to_key tests against the pure-Python functools.
    # staticmethod() keeps ``self.cmp_to_key`` from binding the
    # implementation to the test instance should it be a plain function.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
560
Łukasz Langa6f692512013-06-05 12:20:24 +0200561
class TestTotalOrdering(unittest.TestCase):

    def _check_total_order(self, cls):
        """Check all four ordering operators on instances of *cls*.

        *cls* must accept one integer and order instances by it.  Each
        comparison uses fresh instances.  Comparisons that must be false are
        asserted as well as those that must be true, so an operator that
        always answers True is caught.
        """
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))
        self.assertFalse(cls(2) < cls(1))
        self.assertFalse(cls(2) < cls(2))
        self.assertFalse(cls(1) > cls(2))
        self.assertFalse(cls(2) > cls(2))
        self.assertFalse(cls(2) <= cls(1))
        self.assertFalse(cls(1) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_total_order(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
661
Łukasz Langa6f692512013-06-05 12:20:24 +0200662
class TestLRU(unittest.TestCase):
    # Tests for functools.lru_cache: cache_info() accounting, the
    # maxsize 0/1/2/None edge cases, typed keys, keyword arguments,
    # exception pass-through and reentrant calls.

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = functools.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 5 * 5 = 25 distinct argument pairs compete for 20 slots, so the
        # cache fills up completely and most random calls should hit.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache: calling __wrapped__ directly must not
        # touch the statistics.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (which evicts the
        # least recently used 7) and the final 7; all other calls hit.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @functools.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @functools.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                # Comparing the x == 2 key re-enters the cached function
                # while the cache is in the middle of a lookup.
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
853
854
Łukasz Langa6f692512013-06-05 12:20:24 +0200855class TestSingleDispatch(unittest.TestCase):
856 def test_simple_overloads(self):
857 @functools.singledispatch
858 def g(obj):
859 return "base"
860 def g_int(i):
861 return "integer"
862 g.register(int, g_int)
863 self.assertEqual(g("str"), "base")
864 self.assertEqual(g(1), "integer")
865 self.assertEqual(g([1,2,3]), "base")
866
867 def test_mro(self):
868 @functools.singledispatch
869 def g(obj):
870 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +0200871 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +0200872 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +0200873 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +0200874 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +0200875 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +0200876 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +0200877 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +0200878 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +0200879 def g_A(a):
880 return "A"
881 def g_B(b):
882 return "B"
883 g.register(A, g_A)
884 g.register(B, g_B)
885 self.assertEqual(g(A()), "A")
886 self.assertEqual(g(B()), "B")
887 self.assertEqual(g(C()), "A")
888 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +0200889
890 def test_register_decorator(self):
891 @functools.singledispatch
892 def g(obj):
893 return "base"
894 @g.register(int)
895 def g_int(i):
896 return "int %s" % (i,)
897 self.assertEqual(g(""), "base")
898 self.assertEqual(g(12), "int 12")
899 self.assertIs(g.dispatch(int), g_int)
900 self.assertIs(g.dispatch(object), g.dispatch(str))
901 # Note: in the assert above this is not g.
902 # @singledispatch returns the wrapper.
903
904 def test_wrapping_attributes(self):
905 @functools.singledispatch
906 def g(obj):
907 "Simple test"
908 return "Test"
909 self.assertEqual(g.__name__, "g")
910 self.assertEqual(g.__doc__, "Simple test")
911
912 @unittest.skipUnless(decimal, 'requires _decimal')
913 @support.cpython_only
914 def test_c_classes(self):
915 @functools.singledispatch
916 def g(obj):
917 return "base"
918 @g.register(decimal.DecimalException)
919 def _(obj):
920 return obj.args
921 subn = decimal.Subnormal("Exponent < Emin")
922 rnd = decimal.Rounded("Number got rounded")
923 self.assertEqual(g(subn), ("Exponent < Emin",))
924 self.assertEqual(g(rnd), ("Number got rounded",))
925 @g.register(decimal.Subnormal)
926 def _(obj):
927 return "Too small to care."
928 self.assertEqual(g(subn), "Too small to care.")
929 self.assertEqual(g(rnd), ("Number got rounded",))
930
931 def test_compose_mro(self):
932 c = collections
933 mro = functools._compose_mro
934 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
935 for haystack in permutations(bases):
936 m = mro(dict, haystack)
937 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, object])
938 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
939 for haystack in permutations(bases):
940 m = mro(c.ChainMap, haystack)
941 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
942 c.Sized, c.Iterable, c.Container, object])
943 # Note: The MRO order below depends on haystack ordering.
944 m = mro(c.defaultdict, [c.Sized, c.Container, str])
945 self.assertEqual(m, [c.defaultdict, dict, c.Container, c.Sized, object])
946 m = mro(c.defaultdict, [c.Container, c.Sized, str])
947 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, object])
948
949 def test_register_abc(self):
950 c = collections
951 d = {"a": "b"}
952 l = [1, 2, 3]
953 s = {object(), None}
954 f = frozenset(s)
955 t = (1, 2, 3)
956 @functools.singledispatch
957 def g(obj):
958 return "base"
959 self.assertEqual(g(d), "base")
960 self.assertEqual(g(l), "base")
961 self.assertEqual(g(s), "base")
962 self.assertEqual(g(f), "base")
963 self.assertEqual(g(t), "base")
964 g.register(c.Sized, lambda obj: "sized")
965 self.assertEqual(g(d), "sized")
966 self.assertEqual(g(l), "sized")
967 self.assertEqual(g(s), "sized")
968 self.assertEqual(g(f), "sized")
969 self.assertEqual(g(t), "sized")
970 g.register(c.MutableMapping, lambda obj: "mutablemapping")
971 self.assertEqual(g(d), "mutablemapping")
972 self.assertEqual(g(l), "sized")
973 self.assertEqual(g(s), "sized")
974 self.assertEqual(g(f), "sized")
975 self.assertEqual(g(t), "sized")
976 g.register(c.ChainMap, lambda obj: "chainmap")
977 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
978 self.assertEqual(g(l), "sized")
979 self.assertEqual(g(s), "sized")
980 self.assertEqual(g(f), "sized")
981 self.assertEqual(g(t), "sized")
982 g.register(c.MutableSequence, lambda obj: "mutablesequence")
983 self.assertEqual(g(d), "mutablemapping")
984 self.assertEqual(g(l), "mutablesequence")
985 self.assertEqual(g(s), "sized")
986 self.assertEqual(g(f), "sized")
987 self.assertEqual(g(t), "sized")
988 g.register(c.MutableSet, lambda obj: "mutableset")
989 self.assertEqual(g(d), "mutablemapping")
990 self.assertEqual(g(l), "mutablesequence")
991 self.assertEqual(g(s), "mutableset")
992 self.assertEqual(g(f), "sized")
993 self.assertEqual(g(t), "sized")
994 g.register(c.Mapping, lambda obj: "mapping")
995 self.assertEqual(g(d), "mutablemapping") # not specific enough
996 self.assertEqual(g(l), "mutablesequence")
997 self.assertEqual(g(s), "mutableset")
998 self.assertEqual(g(f), "sized")
999 self.assertEqual(g(t), "sized")
1000 g.register(c.Sequence, lambda obj: "sequence")
1001 self.assertEqual(g(d), "mutablemapping")
1002 self.assertEqual(g(l), "mutablesequence")
1003 self.assertEqual(g(s), "mutableset")
1004 self.assertEqual(g(f), "sized")
1005 self.assertEqual(g(t), "sequence")
1006 g.register(c.Set, lambda obj: "set")
1007 self.assertEqual(g(d), "mutablemapping")
1008 self.assertEqual(g(l), "mutablesequence")
1009 self.assertEqual(g(s), "mutableset")
1010 self.assertEqual(g(f), "set")
1011 self.assertEqual(g(t), "sequence")
1012 g.register(dict, lambda obj: "dict")
1013 self.assertEqual(g(d), "dict")
1014 self.assertEqual(g(l), "mutablesequence")
1015 self.assertEqual(g(s), "mutableset")
1016 self.assertEqual(g(f), "set")
1017 self.assertEqual(g(t), "sequence")
1018 g.register(list, lambda obj: "list")
1019 self.assertEqual(g(d), "dict")
1020 self.assertEqual(g(l), "list")
1021 self.assertEqual(g(s), "mutableset")
1022 self.assertEqual(g(f), "set")
1023 self.assertEqual(g(t), "sequence")
1024 g.register(set, lambda obj: "concrete-set")
1025 self.assertEqual(g(d), "dict")
1026 self.assertEqual(g(l), "list")
1027 self.assertEqual(g(s), "concrete-set")
1028 self.assertEqual(g(f), "set")
1029 self.assertEqual(g(t), "sequence")
1030 g.register(frozenset, lambda obj: "frozen-set")
1031 self.assertEqual(g(d), "dict")
1032 self.assertEqual(g(l), "list")
1033 self.assertEqual(g(s), "concrete-set")
1034 self.assertEqual(g(f), "frozen-set")
1035 self.assertEqual(g(t), "sequence")
1036 g.register(tuple, lambda obj: "tuple")
1037 self.assertEqual(g(d), "dict")
1038 self.assertEqual(g(l), "list")
1039 self.assertEqual(g(s), "concrete-set")
1040 self.assertEqual(g(f), "frozen-set")
1041 self.assertEqual(g(t), "tuple")
1042
1043 def test_mro_conflicts(self):
1044 c = collections
1045
1046 @functools.singledispatch
1047 def g(arg):
1048 return "base"
1049
1050 class O(c.Sized):
1051 def __len__(self):
1052 return 0
1053
1054 o = O()
1055 self.assertEqual(g(o), "base")
1056 g.register(c.Iterable, lambda arg: "iterable")
1057 g.register(c.Container, lambda arg: "container")
1058 g.register(c.Sized, lambda arg: "sized")
1059 g.register(c.Set, lambda arg: "set")
1060 self.assertEqual(g(o), "sized")
1061 c.Iterable.register(O)
1062 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1063 c.Container.register(O)
1064 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
1065
1066 class P:
1067 pass
1068
1069 p = P()
1070 self.assertEqual(g(p), "base")
1071 c.Iterable.register(P)
1072 self.assertEqual(g(p), "iterable")
1073 c.Container.register(P)
1074 with self.assertRaises(RuntimeError) as re:
1075 g(p)
1076 self.assertEqual(
1077 str(re),
1078 ("Ambiguous dispatch: <class 'collections.abc.Container'> "
1079 "or <class 'collections.abc.Iterable'>"),
1080 )
1081
1082 class Q(c.Sized):
1083 def __len__(self):
1084 return 0
1085
1086 q = Q()
1087 self.assertEqual(g(q), "sized")
1088 c.Iterable.register(Q)
1089 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1090 c.Set.register(Q)
1091 self.assertEqual(g(q), "set") # because c.Set is a subclass of
1092 # c.Sized which is explicitly in
1093 # __mro__
1094
1095 def test_cache_invalidation(self):
1096 from collections import UserDict
1097 class TracingDict(UserDict):
1098 def __init__(self, *args, **kwargs):
1099 super(TracingDict, self).__init__(*args, **kwargs)
1100 self.set_ops = []
1101 self.get_ops = []
1102 def __getitem__(self, key):
1103 result = self.data[key]
1104 self.get_ops.append(key)
1105 return result
1106 def __setitem__(self, key, value):
1107 self.set_ops.append(key)
1108 self.data[key] = value
1109 def clear(self):
1110 self.data.clear()
1111 _orig_wkd = functools.WeakKeyDictionary
1112 td = TracingDict()
1113 functools.WeakKeyDictionary = lambda: td
1114 c = collections
1115 @functools.singledispatch
1116 def g(arg):
1117 return "base"
1118 d = {}
1119 l = []
1120 self.assertEqual(len(td), 0)
1121 self.assertEqual(g(d), "base")
1122 self.assertEqual(len(td), 1)
1123 self.assertEqual(td.get_ops, [])
1124 self.assertEqual(td.set_ops, [dict])
1125 self.assertEqual(td.data[dict], g.registry[object])
1126 self.assertEqual(g(l), "base")
1127 self.assertEqual(len(td), 2)
1128 self.assertEqual(td.get_ops, [])
1129 self.assertEqual(td.set_ops, [dict, list])
1130 self.assertEqual(td.data[dict], g.registry[object])
1131 self.assertEqual(td.data[list], g.registry[object])
1132 self.assertEqual(td.data[dict], td.data[list])
1133 self.assertEqual(g(l), "base")
1134 self.assertEqual(g(d), "base")
1135 self.assertEqual(td.get_ops, [list, dict])
1136 self.assertEqual(td.set_ops, [dict, list])
1137 g.register(list, lambda arg: "list")
1138 self.assertEqual(td.get_ops, [list, dict])
1139 self.assertEqual(len(td), 0)
1140 self.assertEqual(g(d), "base")
1141 self.assertEqual(len(td), 1)
1142 self.assertEqual(td.get_ops, [list, dict])
1143 self.assertEqual(td.set_ops, [dict, list, dict])
1144 self.assertEqual(td.data[dict],
1145 functools._find_impl(dict, g.registry))
1146 self.assertEqual(g(l), "list")
1147 self.assertEqual(len(td), 2)
1148 self.assertEqual(td.get_ops, [list, dict])
1149 self.assertEqual(td.set_ops, [dict, list, dict, list])
1150 self.assertEqual(td.data[list],
1151 functools._find_impl(list, g.registry))
1152 class X:
1153 pass
1154 c.MutableMapping.register(X) # Will not invalidate the cache,
1155 # not using ABCs yet.
1156 self.assertEqual(g(d), "base")
1157 self.assertEqual(g(l), "list")
1158 self.assertEqual(td.get_ops, [list, dict, dict, list])
1159 self.assertEqual(td.set_ops, [dict, list, dict, list])
1160 g.register(c.Sized, lambda arg: "sized")
1161 self.assertEqual(len(td), 0)
1162 self.assertEqual(g(d), "sized")
1163 self.assertEqual(len(td), 1)
1164 self.assertEqual(td.get_ops, [list, dict, dict, list])
1165 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1166 self.assertEqual(g(l), "list")
1167 self.assertEqual(len(td), 2)
1168 self.assertEqual(td.get_ops, [list, dict, dict, list])
1169 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1170 self.assertEqual(g(l), "list")
1171 self.assertEqual(g(d), "sized")
1172 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1173 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1174 g.dispatch(list)
1175 g.dispatch(dict)
1176 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1177 list, dict])
1178 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1179 c.MutableSet.register(X) # Will invalidate the cache.
1180 self.assertEqual(len(td), 2) # Stale cache.
1181 self.assertEqual(g(l), "list")
1182 self.assertEqual(len(td), 1)
1183 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1184 self.assertEqual(len(td), 0)
1185 self.assertEqual(g(d), "mutablemapping")
1186 self.assertEqual(len(td), 1)
1187 self.assertEqual(g(l), "list")
1188 self.assertEqual(len(td), 2)
1189 g.register(dict, lambda arg: "dict")
1190 self.assertEqual(g(d), "dict")
1191 self.assertEqual(g(l), "list")
1192 g._clear_cache()
1193 self.assertEqual(len(td), 0)
1194 functools.WeakKeyDictionary = _orig_wkd
1195
1196
def test_main(verbose=None):
    # Run every test class in this module once; on interpreters that
    # expose sys.gettotalrefcount, a verbose run then repeats the suite
    # five times and prints the total refcount after each pass so that
    # steady growth (a leak) is visible.
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001222
if __name__ == '__main__':
    # Direct execution: run the suite verbosely. A truthy `verbose` also
    # enables test_main's repeated-run refcount check when
    # sys.gettotalrefcount is available.
    test_main(verbose=True)