blob: ae929eca99f281b02b0a9e9514bd96bb613732e9 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030010try:
11 import threading
12except ImportError:
13 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000014
Antoine Pitroub5b37142012-11-13 21:35:40 +010015import functools
16
Antoine Pitroub5b37142012-11-13 21:35:40 +010017py_functools = support.import_fresh_module('functools', blocked=['_functools'])
18c_functools = support.import_fresh_module('functools', fresh=['_functools'])
19
Łukasz Langa6f692512013-06-05 12:20:24 +020020decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
21
22
def capture(*args, **kw):
    """Return the call arguments unchanged as an ``(args, kw)`` pair.

    ``args`` is the tuple of positional arguments and ``kw`` is the dict of
    keyword arguments the function was called with.
    """
    received = (args, kw)
    return received
26
Łukasz Langa6f692512013-06-05 12:20:24 +020027
def signature(part):
    """Summarise a partial object as a 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)``, read directly from the
    attributes of *part*.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000031
Łukasz Langa6f692512013-06-05 12:20:24 +020032
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This class is a mixin and does not derive from ``unittest.TestCase``.
    Concrete subclasses combine it with ``TestCase`` and supply the
    implementation under test through a ``partial`` class attribute.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The non-callable may be rejected at construction or at call time;
        # either way a TypeError must surface.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # Separate assertions so a failure shows which part was wrong.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        # exceptions raised by the wrapped callable must reach the caller
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last reference only finalises the object immediately
        # under reference counting; force a collection for other GCs.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        # Only "true" partial is optimized
        if partial.__name__ != 'partial':
            self.skipTest('nested partials are only flattened for '
                          'the real partial type')
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))
Łukasz Langa6f692512013-06-05 12:20:24 +0200150
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests, plus C-specific ones, on the C type."""

    # The attribute lookup is guarded because the class body is executed
    # even when the skip decorator applies.
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # func/args/keywords must reject assignment
        target = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, target, 'func', map)
        self.assertRaises(AttributeError, setattr, target, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, target, 'keywords',
                          dict(a=1, b=2))

        # __dict__ itself must not be deletable
        target = self.partial(hex)
        try:
            del target.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_repr(self):
        positional = (object(), object())
        positional_repr = ', '.join(map(repr, positional))
        keyword = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        keyword_reprs = ['a={a!r}, b={b!r}'.format_map(keyword),
                         'b={b!r}, a={a!r}'.format_map(keyword)]
        if self.partial is not c_functools.partial:
            name = self.partial.__name__
        else:
            name = 'functools.partial'

        # function only
        obj = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture), repr(obj))

        # function + positional arguments
        obj = self.partial(capture, *positional)
        self.assertEqual(
            '{}({!r}, {})'.format(name, capture, positional_repr),
            repr(obj))

        # function + keyword arguments
        obj = self.partial(capture, **keyword)
        candidates = ['{}({!r}, {})'.format(name, capture, kw_repr)
                      for kw_repr in keyword_reprs]
        self.assertIn(repr(obj), candidates)

        # function + positional + keyword arguments
        obj = self.partial(capture, *positional, **keyword)
        candidates = ['{}({!r}, {}, {})'.format(name, capture,
                                                positional_repr, kw_repr)
                      for kw_repr in keyword_reprs]
        self.assertIn(repr(obj), candidates)

    def test_pickle(self):
        # a partial, including its instance __dict__, must round-trip
        # through every pickle protocol
        original = self.partial(signature, 'asdf', bar=True)
        original.add_something_to__dict__ = True
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(original, protocol))
            self.assertEqual(signature(original), signature(restored))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Looks like a 4-item sequence but is not a tuple.
            def __len__(self):
                return 4

            def __getitem__(self, key):
                if key == 0:
                    return max
                if key == 1:
                    return tuple(range(1000000))
                if key in (2, 3):
                    return {}
                raise IndexError

        obj = self.partial(object)
        self.assertRaisesRegex(
            SystemError,
            "new style getargs format but argument is not a tuple",
            obj.__setstate__, BadSequence())
Łukasz Langa6f692512013-06-05 12:20:24 +0200226
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``py_functools.partial``."""

    # Wrapped in staticmethod so that lookup through an instance returns
    # the object unchanged.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000229
Łukasz Langa6f692512013-06-05 12:20:24 +0200230
# Only defined when the C accelerator module could be imported.
if c_functools:
    class PartialSubclass(c_functools.partial):
        # Adds nothing: exists so the partial tests can run on a subclass.
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100234
Łukasz Langa6f692512013-06-05 12:20:24 +0200235
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using ``PartialSubclass``."""

    # PartialSubclass only exists when c_functools is available, so the
    # assignment carries the same guard.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000240
Łukasz Langa6f692512013-06-05 12:20:24 +0200241
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod``."""

    class A(object):
        # Each attribute wraps ``capture`` so a call returns exactly the
        # arguments the underlying callable received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod stacked on another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain functools.partial
        over_partial = functools.partialmethod(
            functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # Shared instance used by the tests below.
    a = A()

    def test_arg_combinations(self):
        a = self.a
        self.assertEqual(a.nothing(), ((a,), {}))
        self.assertEqual(a.nothing(5), ((a, 5), {}))
        self.assertEqual(a.nothing(c=6), ((a,), {'c': 6}))
        self.assertEqual(a.nothing(5, c=6), ((a, 5), {'c': 6}))

        self.assertEqual(a.positional(), ((a, 1), {}))
        self.assertEqual(a.positional(5), ((a, 1, 5), {}))
        self.assertEqual(a.positional(c=6), ((a, 1), {'c': 6}))
        self.assertEqual(a.positional(5, c=6), ((a, 1, 5), {'c': 6}))

        self.assertEqual(a.keywords(), ((a,), {'a': 2}))
        self.assertEqual(a.keywords(5), ((a, 5), {'a': 2}))
        self.assertEqual(a.keywords(c=6), ((a,), {'a': 2, 'c': 6}))
        self.assertEqual(a.keywords(5, c=6), ((a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(a.both(), ((a, 3), {'b': 4}))
        self.assertEqual(a.both(5), ((a, 3, 5), {'b': 4}))
        self.assertEqual(a.both(c=6), ((a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(a.both(5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

        # called through the class with an explicit instance
        self.assertEqual(self.A.both(a, 5, c=6),
                         ((a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        a = self.a
        self.assertEqual(a.nested(), ((a, 1, 5), {}))
        self.assertEqual(a.nested(6), ((a, 1, 5, 6), {}))
        self.assertEqual(a.nested(d=7), ((a, 1, 5), {'d': 7}))
        self.assertEqual(a.nested(6, d=7), ((a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(a, 6, d=7),
                         ((a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        a = self.a
        self.assertEqual(a.over_partial(), ((a, 7), {'c': 6}))
        self.assertEqual(a.over_partial(5), ((a, 7, 5), {'c': 6}))
        self.assertEqual(a.over_partial(d=8), ((a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(a.over_partial(5, d=8),
                         ((a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(a, 5, d=8),
                         ((a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static and class variants behave the same via class or instance
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8),
                                 ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8),
                                 ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        # a keyword given at call time replaces the stored one
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3),
                         ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # a non-callable, non-descriptor first argument is rejected
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta is used as the metaclass (not as a base class) so that
        # Abstract is an abstract class rather than a new metaclass.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        # partialmethods over non-abstract callables are not abstract
        for func in [self.A.static, self.A.cls, self.A.over_partial,
                     self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
355
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* mirrors *wrapped*.

        Each name in *assigned* must refer to the same object on both
        functions.  For each name in *updated*, every entry of the wrapped
        function's mapping must be present, as the same object, in the
        wrapper's mapping.  Finally ``wrapper.__wrapped__`` must be
        *wrapped* itself.
        """
        # Assigned attributes: identical objects on both sides.
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        # Updated attributes: wrapped's entries carried over into wrapper.
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                if (attr_name, key) == ("__dict__", "__wrapped__"):
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(source[key], target[key])
        # __wrapped__ always points at the wrapped function.
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build an ``(updated wrapper, wrapped)`` pair using the defaults."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: nothing is copied over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Custom assigned/updated tuples: only the named attributes move.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        to_assign = ('attr',)
        to_update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        to_assign = ('attr',)
        to_update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)
        # ... and the attribute must support update()
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Łukasz Langa6f692512013-06-05 12:20:24 +0200469
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper tests through the ``functools.wraps`` decorator."""

    def _default_update(self):
        """Build a ``(wrapper, wrapped)`` pair using ``wraps`` with defaults."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: the decorator copies nothing.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Gives the wrapper an empty mapping for wraps() to update.
            func.dict_attr = {}
            return func

        to_assign = ('attr',)
        to_update = ('dict_attr',)

        @functools.wraps(f, to_assign, to_update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
Łukasz Langa6f692512013-06-05 12:20:24 +0200531
class TestReduce(unittest.TestCase):
    """Tests for ``functools.reduce``."""

    # Implementation under test.
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of squares 0, 1, 4, ... computed lazily and cached.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                # extend the cache up to and including index i
                for k in range(len(self.sofar), i + 1):
                    self.sofar.append(k*k)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x*y

        reduce = self.func

        # ordinary folds
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # argument validation
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, reduce, add, [])
        self.assertRaises(TypeError, reduce, add, "")
        self.assertRaises(TypeError, reduce, add, ())
        self.assertRaises(TypeError, reduce, add, object())

        # an error raised while obtaining the iterator propagates
        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # an empty iterable yields the initial value untouched
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        # an error raised while fetching items propagates
        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence: only __getitem__, yielding 0 .. n-1.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # iterating a dict yields its keys
        mapping = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, mapping), "".join(mapping.keys()))
Łukasz Langa6f692512013-06-05 12:20:24 +0200612
class TestCmpToKey:
    """Tests shared by every ``cmp_to_key`` implementation.

    This class is a mixin and does not derive from ``unittest.TestCase``.
    Concrete subclasses combine it with ``TestCase`` and supply the
    implementation under test through a ``cmp_to_key`` class attribute.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def raising_cmp(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(raising_cmp)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the function's *result*
        # against zero propagates as well.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError

        def bad_result_cmp(x, y):
            return BadCmp()
        # The key must be rebuilt, otherwise the first comparator is simply
        # exercised a second time.  ``<`` is used because BadCmp only
        # defines __lt__.
        key = self.cmp_to_key(bad_result_cmp)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        # the wrapped value is exposed as the ``obj`` attribute
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        # a reversed comparison sorts in descending order
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        # mixed ints and numeric strings are ordered by integer value
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # The ABC lives in collections.abc; the alias in ``collections``
        # was removed in Python 3.10.
        from collections.abc import Hashable

        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        # key objects must not be hashable
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Łukasz Langa6f692512013-06-05 12:20:24 +0200695
# Runs the shared TestCmpToKey tests against the C implementation.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Only bind the attribute when the C module could be imported; without
    # it the decorator above skips the whole class.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100700
Łukasz Langa6f692512013-06-05 12:20:24 +0200701
# Runs the shared TestCmpToKey tests against the pure Python implementation.
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # staticmethod() keeps the Python function from being bound to the test
    # instance when it is looked up as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
704
Łukasz Langa6f692512013-06-05 12:20:24 +0200705
class TestTotalOrdering(unittest.TestCase):
    # Tests for the functools.total_ordering class decorator.

    def _check_full_ordering(self, cls):
        # Assertions shared by the tests below: every rich comparison must
        # give the natural answer for instances built from 1 and 2.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        # The class supplies only __lt__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __lt__(self, other):
                return self.value < other.value

        self._check_full_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # The class supplies only __le__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __le__(self, other):
                return self.value <= other.value

        self._check_full_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # The class supplies only __gt__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __gt__(self, other):
                return self.value > other.value

        self._check_full_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # The class supplies only __ge__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __ge__(self, other):
                return self.value >= other.value

        self._check_full_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass

        self._check_full_ordering(A)

    def test_no_operations_defined(self):
        # Decorating a class that defines no ordering method is an error.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value

            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value

            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value

            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                return NotImplemented

        # Each entry pairs a subTest label with a comparison between
        # operands that do not know how to order each other.
        mismatched = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal operands whose only ordering method returns NotImplemented
        # must still raise for the derived operators.
        with self.subTest("GE when equal"):
            left = ComparatorNotImplemented(8)
            right = ComparatorNotImplemented(8)
            self.assertEqual(left, right)
            with self.assertRaises(TypeError):
                left >= right

        with self.subTest("LE when equal"):
            left = ComparatorNotImplemented(9)
            right = ComparatorNotImplemented(9)
            self.assertEqual(left, right)
            with self.assertRaises(TypeError):
                left <= right

    def test_pickle(self):
        # Each comparison method of Orderable_LT must round-trip through
        # pickle to the very same object.
        comparison_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
            for name in comparison_names:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    payload = pickle.dumps(method, proto)
                    self.assertIs(pickle.loads(payload), method)
908
@functools.total_ordering
class Orderable_LT:
    # Small value wrapper that defines only __eq__ and __lt__ itself and
    # leaves the other comparisons to total_ordering.  Looked up by name in
    # TestTotalOrdering.test_pickle.

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
917
918
class TestLRU:
    # Shared lru_cache tests.  Concrete subclasses set ``module`` to the
    # functools implementation under test; every test reaches lru_cache
    # through ``self.module``.

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 possible argument pairs are drawn at random 1000 times.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)           # reuses the last pair drawn by the loop above
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        # Recursive fib through the cache with maxsize=None.
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        # With a negative maxsize every one of the 2 * 150 calls is a miss.
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for _ in range(2):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # With typed=True, 3 and 3.0 (positional or keyword) are separate
        # cache entries and each keeps its own result type.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        # Same fib check as above, but every call passes ``n`` by keyword
        # and the default maxsize is used.
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        # Keyword-argument calls combined with maxsize=None.
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_cache_decoration(self):
        # The cached wrapper must carry over every WRAPPER_ASSIGNMENTS
        # attribute of the function it wraps.
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        # n worker threads each call f(k, 0) m times for their own k.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All threads block on this event so that they start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can these be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            # Always restore the interpreter-wide switch interval.
            sys.setswitchinterval(orig_si)

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # Each barrier is shared by the n workers plus this thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for _ in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # After round i all n workers have missed once per argument,
                # while only one entry per argument is stored.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_early_detection_of_bad_call(self):
        # Issue #22184
        # Goes through self.module, like every other test in this class, so
        # that the implementation selected by the subclass is the one checked.
        with self.assertRaises(TypeError):
            @self.module.lru_cache
            def f():
                pass

    def test_lru_method(self):
        # A cached method shares a single cache between all instances; the
        # instance takes part in the key, so the equal a and b hit each
        # other's entries while c does not.
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1241
# Runs the shared TestLRU tests against the C accelerated functools.
# Skipped when the C module is unavailable (c_functools is falsy), in the
# same way as TestCmpToKeyC; otherwise every test would fail on the
# missing ``self.module`` attributes.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
    module = c_functools
1244
# Runs the shared TestLRU tests against the pure Python functools.
class TestLRUPy(TestLRU, unittest.TestCase):
    module = py_functools
1247
Raymond Hettinger03923422013-03-04 02:52:50 -05001248
Łukasz Langa6f692512013-06-05 12:20:24 +02001249class TestSingleDispatch(unittest.TestCase):
1250 def test_simple_overloads(self):
1251 @functools.singledispatch
1252 def g(obj):
1253 return "base"
1254 def g_int(i):
1255 return "integer"
1256 g.register(int, g_int)
1257 self.assertEqual(g("str"), "base")
1258 self.assertEqual(g(1), "integer")
1259 self.assertEqual(g([1,2,3]), "base")
1260
1261 def test_mro(self):
1262 @functools.singledispatch
1263 def g(obj):
1264 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001265 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001266 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001267 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001268 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001269 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001270 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001271 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001272 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001273 def g_A(a):
1274 return "A"
1275 def g_B(b):
1276 return "B"
1277 g.register(A, g_A)
1278 g.register(B, g_B)
1279 self.assertEqual(g(A()), "A")
1280 self.assertEqual(g(B()), "B")
1281 self.assertEqual(g(C()), "A")
1282 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001283
1284 def test_register_decorator(self):
1285 @functools.singledispatch
1286 def g(obj):
1287 return "base"
1288 @g.register(int)
1289 def g_int(i):
1290 return "int %s" % (i,)
1291 self.assertEqual(g(""), "base")
1292 self.assertEqual(g(12), "int 12")
1293 self.assertIs(g.dispatch(int), g_int)
1294 self.assertIs(g.dispatch(object), g.dispatch(str))
1295 # Note: in the assert above this is not g.
1296 # @singledispatch returns the wrapper.
1297
1298 def test_wrapping_attributes(self):
1299 @functools.singledispatch
1300 def g(obj):
1301 "Simple test"
1302 return "Test"
1303 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001304 if sys.flags.optimize < 2:
1305 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001306
1307 @unittest.skipUnless(decimal, 'requires _decimal')
1308 @support.cpython_only
1309 def test_c_classes(self):
1310 @functools.singledispatch
1311 def g(obj):
1312 return "base"
1313 @g.register(decimal.DecimalException)
1314 def _(obj):
1315 return obj.args
1316 subn = decimal.Subnormal("Exponent < Emin")
1317 rnd = decimal.Rounded("Number got rounded")
1318 self.assertEqual(g(subn), ("Exponent < Emin",))
1319 self.assertEqual(g(rnd), ("Number got rounded",))
1320 @g.register(decimal.Subnormal)
1321 def _(obj):
1322 return "Too small to care."
1323 self.assertEqual(g(subn), "Too small to care.")
1324 self.assertEqual(g(rnd), ("Number got rounded",))
1325
1326 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001327 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001328 c = collections
1329 mro = functools._compose_mro
1330 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1331 for haystack in permutations(bases):
1332 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001333 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1334 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001335 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1336 for haystack in permutations(bases):
1337 m = mro(c.ChainMap, haystack)
1338 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1339 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001340
1341 # If there's a generic function with implementations registered for
1342 # both Sized and Container, passing a defaultdict to it results in an
1343 # ambiguous dispatch which will cause a RuntimeError (see
1344 # test_mro_conflicts).
1345 bases = [c.Container, c.Sized, str]
1346 for haystack in permutations(bases):
1347 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1348 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1349 object])
1350
1351 # MutableSequence below is registered directly on D. In other words, it
1352 # preceeds MutableMapping which means single dispatch will always
1353 # choose MutableSequence here.
1354 class D(c.defaultdict):
1355 pass
1356 c.MutableSequence.register(D)
1357 bases = [c.MutableSequence, c.MutableMapping]
1358 for haystack in permutations(bases):
1359 m = mro(D, bases)
1360 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1361 c.defaultdict, dict, c.MutableMapping,
1362 c.Mapping, c.Sized, c.Iterable, c.Container,
1363 object])
1364
1365 # Container and Callable are registered on different base classes and
1366 # a generic function supporting both should always pick the Callable
1367 # implementation if a C instance is passed.
1368 class C(c.defaultdict):
1369 def __call__(self):
1370 pass
1371 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1372 for haystack in permutations(bases):
1373 m = mro(C, haystack)
1374 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1375 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001376
1377 def test_register_abc(self):
1378 c = collections
1379 d = {"a": "b"}
1380 l = [1, 2, 3]
1381 s = {object(), None}
1382 f = frozenset(s)
1383 t = (1, 2, 3)
1384 @functools.singledispatch
1385 def g(obj):
1386 return "base"
1387 self.assertEqual(g(d), "base")
1388 self.assertEqual(g(l), "base")
1389 self.assertEqual(g(s), "base")
1390 self.assertEqual(g(f), "base")
1391 self.assertEqual(g(t), "base")
1392 g.register(c.Sized, lambda obj: "sized")
1393 self.assertEqual(g(d), "sized")
1394 self.assertEqual(g(l), "sized")
1395 self.assertEqual(g(s), "sized")
1396 self.assertEqual(g(f), "sized")
1397 self.assertEqual(g(t), "sized")
1398 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1399 self.assertEqual(g(d), "mutablemapping")
1400 self.assertEqual(g(l), "sized")
1401 self.assertEqual(g(s), "sized")
1402 self.assertEqual(g(f), "sized")
1403 self.assertEqual(g(t), "sized")
1404 g.register(c.ChainMap, lambda obj: "chainmap")
1405 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1406 self.assertEqual(g(l), "sized")
1407 self.assertEqual(g(s), "sized")
1408 self.assertEqual(g(f), "sized")
1409 self.assertEqual(g(t), "sized")
1410 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1411 self.assertEqual(g(d), "mutablemapping")
1412 self.assertEqual(g(l), "mutablesequence")
1413 self.assertEqual(g(s), "sized")
1414 self.assertEqual(g(f), "sized")
1415 self.assertEqual(g(t), "sized")
1416 g.register(c.MutableSet, lambda obj: "mutableset")
1417 self.assertEqual(g(d), "mutablemapping")
1418 self.assertEqual(g(l), "mutablesequence")
1419 self.assertEqual(g(s), "mutableset")
1420 self.assertEqual(g(f), "sized")
1421 self.assertEqual(g(t), "sized")
1422 g.register(c.Mapping, lambda obj: "mapping")
1423 self.assertEqual(g(d), "mutablemapping") # not specific enough
1424 self.assertEqual(g(l), "mutablesequence")
1425 self.assertEqual(g(s), "mutableset")
1426 self.assertEqual(g(f), "sized")
1427 self.assertEqual(g(t), "sized")
1428 g.register(c.Sequence, lambda obj: "sequence")
1429 self.assertEqual(g(d), "mutablemapping")
1430 self.assertEqual(g(l), "mutablesequence")
1431 self.assertEqual(g(s), "mutableset")
1432 self.assertEqual(g(f), "sized")
1433 self.assertEqual(g(t), "sequence")
1434 g.register(c.Set, lambda obj: "set")
1435 self.assertEqual(g(d), "mutablemapping")
1436 self.assertEqual(g(l), "mutablesequence")
1437 self.assertEqual(g(s), "mutableset")
1438 self.assertEqual(g(f), "set")
1439 self.assertEqual(g(t), "sequence")
1440 g.register(dict, lambda obj: "dict")
1441 self.assertEqual(g(d), "dict")
1442 self.assertEqual(g(l), "mutablesequence")
1443 self.assertEqual(g(s), "mutableset")
1444 self.assertEqual(g(f), "set")
1445 self.assertEqual(g(t), "sequence")
1446 g.register(list, lambda obj: "list")
1447 self.assertEqual(g(d), "dict")
1448 self.assertEqual(g(l), "list")
1449 self.assertEqual(g(s), "mutableset")
1450 self.assertEqual(g(f), "set")
1451 self.assertEqual(g(t), "sequence")
1452 g.register(set, lambda obj: "concrete-set")
1453 self.assertEqual(g(d), "dict")
1454 self.assertEqual(g(l), "list")
1455 self.assertEqual(g(s), "concrete-set")
1456 self.assertEqual(g(f), "set")
1457 self.assertEqual(g(t), "sequence")
1458 g.register(frozenset, lambda obj: "frozen-set")
1459 self.assertEqual(g(d), "dict")
1460 self.assertEqual(g(l), "list")
1461 self.assertEqual(g(s), "concrete-set")
1462 self.assertEqual(g(f), "frozen-set")
1463 self.assertEqual(g(t), "sequence")
1464 g.register(tuple, lambda obj: "tuple")
1465 self.assertEqual(g(d), "dict")
1466 self.assertEqual(g(l), "list")
1467 self.assertEqual(g(s), "concrete-set")
1468 self.assertEqual(g(f), "frozen-set")
1469 self.assertEqual(g(t), "tuple")
1470
Łukasz Langa3720c772013-07-01 16:00:38 +02001471 def test_c3_abc(self):
1472 c = collections
1473 mro = functools._c3_mro
1474 class A(object):
1475 pass
1476 class B(A):
1477 def __len__(self):
1478 return 0 # implies Sized
1479 @c.Container.register
1480 class C(object):
1481 pass
1482 class D(object):
1483 pass # unrelated
1484 class X(D, C, B):
1485 def __call__(self):
1486 pass # implies Callable
1487 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1488 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1489 self.assertEqual(mro(X, abcs=abcs), expected)
1490 # unrelated ABCs don't appear in the resulting MRO
1491 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1492 self.assertEqual(mro(X, abcs=many_abcs), expected)
1493
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001494 def test_false_meta(self):
1495 # see issue23572
1496 class MetaA(type):
1497 def __len__(self):
1498 return 0
1499 class A(metaclass=MetaA):
1500 pass
1501 class AA(A):
1502 pass
1503 @functools.singledispatch
1504 def fun(a):
1505 return 'base A'
1506 @fun.register(A)
1507 def _(a):
1508 return 'fun A'
1509 aa = AA()
1510 self.assertEqual(fun(aa), 'fun A')
1511
Łukasz Langa6f692512013-06-05 12:20:24 +02001512 def test_mro_conflicts(self):
1513 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001514 @functools.singledispatch
1515 def g(arg):
1516 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001517 class O(c.Sized):
1518 def __len__(self):
1519 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001520 o = O()
1521 self.assertEqual(g(o), "base")
1522 g.register(c.Iterable, lambda arg: "iterable")
1523 g.register(c.Container, lambda arg: "container")
1524 g.register(c.Sized, lambda arg: "sized")
1525 g.register(c.Set, lambda arg: "set")
1526 self.assertEqual(g(o), "sized")
1527 c.Iterable.register(O)
1528 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1529 c.Container.register(O)
1530 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001531 c.Set.register(O)
1532 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1533 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001534 class P:
1535 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001536 p = P()
1537 self.assertEqual(g(p), "base")
1538 c.Iterable.register(P)
1539 self.assertEqual(g(p), "iterable")
1540 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001541 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001542 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001543 self.assertIn(
1544 str(re_one.exception),
1545 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1546 "or <class 'collections.abc.Iterable'>"),
1547 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1548 "or <class 'collections.abc.Container'>")),
1549 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001550 class Q(c.Sized):
1551 def __len__(self):
1552 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001553 q = Q()
1554 self.assertEqual(g(q), "sized")
1555 c.Iterable.register(Q)
1556 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1557 c.Set.register(Q)
1558 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001559 # c.Sized and c.Iterable
1560 @functools.singledispatch
1561 def h(arg):
1562 return "base"
1563 @h.register(c.Sized)
1564 def _(arg):
1565 return "sized"
1566 @h.register(c.Container)
1567 def _(arg):
1568 return "container"
1569 # Even though Sized and Container are explicit bases of MutableMapping,
1570 # this ABC is implicitly registered on defaultdict which makes all of
1571 # MutableMapping's bases implicit as well from defaultdict's
1572 # perspective.
1573 with self.assertRaises(RuntimeError) as re_two:
1574 h(c.defaultdict(lambda: 0))
1575 self.assertIn(
1576 str(re_two.exception),
1577 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1578 "or <class 'collections.abc.Sized'>"),
1579 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1580 "or <class 'collections.abc.Container'>")),
1581 )
1582 class R(c.defaultdict):
1583 pass
1584 c.MutableSequence.register(R)
1585 @functools.singledispatch
1586 def i(arg):
1587 return "base"
1588 @i.register(c.MutableMapping)
1589 def _(arg):
1590 return "mapping"
1591 @i.register(c.MutableSequence)
1592 def _(arg):
1593 return "sequence"
1594 r = R()
1595 self.assertEqual(i(r), "sequence")
1596 class S:
1597 pass
1598 class T(S, c.Sized):
1599 def __len__(self):
1600 return 0
1601 t = T()
1602 self.assertEqual(h(t), "sized")
1603 c.Container.register(T)
1604 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1605 class U:
1606 def __len__(self):
1607 return 0
1608 u = U()
1609 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1610 # from the existence of __len__()
1611 c.Container.register(U)
1612 # There is no preference for registered versus inferred ABCs.
1613 with self.assertRaises(RuntimeError) as re_three:
1614 h(u)
1615 self.assertIn(
1616 str(re_three.exception),
1617 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1618 "or <class 'collections.abc.Sized'>"),
1619 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1620 "or <class 'collections.abc.Container'>")),
1621 )
1622 class V(c.Sized, S):
1623 def __len__(self):
1624 return 0
1625 @functools.singledispatch
1626 def j(arg):
1627 return "base"
1628 @j.register(S)
1629 def _(arg):
1630 return "s"
1631 @j.register(c.Container)
1632 def _(arg):
1633 return "container"
1634 v = V()
1635 self.assertEqual(j(v), "s")
1636 c.Container.register(V)
1637 self.assertEqual(j(v), "container") # because it ends up right after
1638 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001639
1640 def test_cache_invalidation(self):
1641 from collections import UserDict
1642 class TracingDict(UserDict):
1643 def __init__(self, *args, **kwargs):
1644 super(TracingDict, self).__init__(*args, **kwargs)
1645 self.set_ops = []
1646 self.get_ops = []
1647 def __getitem__(self, key):
1648 result = self.data[key]
1649 self.get_ops.append(key)
1650 return result
1651 def __setitem__(self, key, value):
1652 self.set_ops.append(key)
1653 self.data[key] = value
1654 def clear(self):
1655 self.data.clear()
1656 _orig_wkd = functools.WeakKeyDictionary
1657 td = TracingDict()
1658 functools.WeakKeyDictionary = lambda: td
1659 c = collections
1660 @functools.singledispatch
1661 def g(arg):
1662 return "base"
1663 d = {}
1664 l = []
1665 self.assertEqual(len(td), 0)
1666 self.assertEqual(g(d), "base")
1667 self.assertEqual(len(td), 1)
1668 self.assertEqual(td.get_ops, [])
1669 self.assertEqual(td.set_ops, [dict])
1670 self.assertEqual(td.data[dict], g.registry[object])
1671 self.assertEqual(g(l), "base")
1672 self.assertEqual(len(td), 2)
1673 self.assertEqual(td.get_ops, [])
1674 self.assertEqual(td.set_ops, [dict, list])
1675 self.assertEqual(td.data[dict], g.registry[object])
1676 self.assertEqual(td.data[list], g.registry[object])
1677 self.assertEqual(td.data[dict], td.data[list])
1678 self.assertEqual(g(l), "base")
1679 self.assertEqual(g(d), "base")
1680 self.assertEqual(td.get_ops, [list, dict])
1681 self.assertEqual(td.set_ops, [dict, list])
1682 g.register(list, lambda arg: "list")
1683 self.assertEqual(td.get_ops, [list, dict])
1684 self.assertEqual(len(td), 0)
1685 self.assertEqual(g(d), "base")
1686 self.assertEqual(len(td), 1)
1687 self.assertEqual(td.get_ops, [list, dict])
1688 self.assertEqual(td.set_ops, [dict, list, dict])
1689 self.assertEqual(td.data[dict],
1690 functools._find_impl(dict, g.registry))
1691 self.assertEqual(g(l), "list")
1692 self.assertEqual(len(td), 2)
1693 self.assertEqual(td.get_ops, [list, dict])
1694 self.assertEqual(td.set_ops, [dict, list, dict, list])
1695 self.assertEqual(td.data[list],
1696 functools._find_impl(list, g.registry))
1697 class X:
1698 pass
1699 c.MutableMapping.register(X) # Will not invalidate the cache,
1700 # not using ABCs yet.
1701 self.assertEqual(g(d), "base")
1702 self.assertEqual(g(l), "list")
1703 self.assertEqual(td.get_ops, [list, dict, dict, list])
1704 self.assertEqual(td.set_ops, [dict, list, dict, list])
1705 g.register(c.Sized, lambda arg: "sized")
1706 self.assertEqual(len(td), 0)
1707 self.assertEqual(g(d), "sized")
1708 self.assertEqual(len(td), 1)
1709 self.assertEqual(td.get_ops, [list, dict, dict, list])
1710 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1711 self.assertEqual(g(l), "list")
1712 self.assertEqual(len(td), 2)
1713 self.assertEqual(td.get_ops, [list, dict, dict, list])
1714 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1715 self.assertEqual(g(l), "list")
1716 self.assertEqual(g(d), "sized")
1717 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1718 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1719 g.dispatch(list)
1720 g.dispatch(dict)
1721 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1722 list, dict])
1723 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1724 c.MutableSet.register(X) # Will invalidate the cache.
1725 self.assertEqual(len(td), 2) # Stale cache.
1726 self.assertEqual(g(l), "list")
1727 self.assertEqual(len(td), 1)
1728 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1729 self.assertEqual(len(td), 0)
1730 self.assertEqual(g(d), "mutablemapping")
1731 self.assertEqual(len(td), 1)
1732 self.assertEqual(g(l), "list")
1733 self.assertEqual(len(td), 2)
1734 g.register(dict, lambda arg: "dict")
1735 self.assertEqual(g(d), "dict")
1736 self.assertEqual(g(l), "list")
1737 g._clear_cache()
1738 self.assertEqual(len(td), 0)
1739 functools.WeakKeyDictionary = _orig_wkd
1740
1741
# Hand control to unittest's command-line runner when this file is executed
# directly rather than imported.
if __name__ == '__main__':
    unittest.main()