blob: 7ecf877b11de22b9acb3eef29f13961ce8d86ed3 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030010try:
11 import threading
12except ImportError:
13 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000014
Antoine Pitroub5b37142012-11-13 21:35:40 +010015import functools
16
# Two independent copies of functools: one forced onto the pure-Python
# implementation (the _functools accelerator is blocked) and one that
# freshly reloads the C accelerator.  c_functools can be falsy when the C
# module is unavailable; the C-specific test classes check for that.
py_functools = support.import_fresh_module(
    'functools', blocked=['_functools'])
c_functools = support.import_fresh_module(
    'functools', fresh=['_functools'])

# A freshly loaded decimal module with the C _decimal accelerator.
decimal = support.import_fresh_module(
    'decimal', fresh=['_decimal'])
21
22
def capture(*args, **kw):
    """Echo back the call: return ``(args, kw)`` exactly as received."""
    received = (args, kw)
    return received
26
Łukasz Langa6f692512013-06-05 12:20:24 +020027
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``.

    Comparing two such tuples checks that the partials wrap equal callables
    with equal frozen arguments and equal instance attributes.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, name) for name in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000031
Łukasz Langa6f692512013-06-05 12:20:24 +020032
class TestPartial:
    """Behavioural tests shared by every partial() implementation.

    This is a mixin.  Concrete subclasses combine it with unittest.TestCase
    and set a ``partial`` class attribute to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions report exactly which half differed.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the only strong reference must kill the proxy's target.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        # A partial of a partial should be flattened into a single level.
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')
158
Łukasz Langa6f692512013-06-05 12:20:24 +0200159
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the C implementation, plus
    checks that only apply to it (read-only attributes, repr, pickling,
    __setstate__ validation)."""
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence-like object that is not a real tuple; __setstate__
        # must reject it rather than mis-handle references.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000234
Łukasz Langa6f692512013-06-05 12:20:24 +0200235
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the pure-Python version."""

    # Wrapped in staticmethod so the implementation is never bound as a
    # method when looked up through the test instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000238
Łukasz Langa6f692512013-06-05 12:20:24 +0200239
if c_functools:
    # A do-nothing Python subclass of the C partial type, so the C tests
    # can be re-run through a subclass.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100243
Łukasz Langa6f692512013-06-05 12:20:24 +0200244
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run every C partial test through PartialSubclass."""

    if c_functools:
        partial = PartialSubclass

    # partial subclasses are not optimized for nested calls, so the
    # inherited flattening test is disabled here.
    test_nested_optimization = None
252
Łukasz Langa6f692512013-06-05 12:20:24 +0200253
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass* here.  Subclassing ABCMeta (as
        # this test previously did) defines a new metaclass instead, so the
        # ABC machinery never inspects the class body.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # The partialmethod wrapping an abstract method must itself count
        # as abstract, so the class cannot be instantiated.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
366
367
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* mirrors *wrapped* the way update_wrapper should.

        Each attribute named in *assigned* must be the identical object on
        both functions.  Each mapping attribute named in *updated* on the
        wrapper must hold every entry of the wrapped one, except
        ``__dict__['__wrapped__']``, which the update code overwrites.
        ``wrapper.__wrapped__`` must refer back to *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # __wrapped__ is replaced with the real wrapped function
                if attr == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # The wrapped function carries an annotation, an extra attribute and
        # a bogus __wrapped__.  The wrapper's own annotation should be lost.
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Nothing to assign or update: the wrapper keeps its own metadata.
        functools.update_wrapper(wrapper, f, assigned=(), updated=())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        only_assign, only_update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, only_assign, only_update)
        self.check_wrapper(wrapper, f, only_assign, only_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        only_assign, only_update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, only_assign, only_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, only_assign, only_update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, only_assign, only_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: builtins can be wrapped too
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000480
Łukasz Langa6f692512013-06-05 12:20:24 +0200481
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, assigned=(), updated=())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # The wrapper needs a dict_attr of its own to be updated.
            func.dict_attr = {}
            return func

        only_assign, only_update = ('attr',), ('dict_attr',)

        @functools.wraps(f, only_assign, only_update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, only_assign, only_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
542
Łukasz Langa6f692512013-06-05 12:20:24 +0200543
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce() over sequences, iterators and bad input."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of n*n for n < max, materialised lazily on indexing.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                self.sofar.extend(n * n for n in range(len(self.sofar), i + 1))
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(self.func(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a','c','d','w'])
        self.assertEqual(self.func(mul, range(2,8), 1), 5040)
        self.assertEqual(self.func(mul, range(2,21), 1), 2432902008176640000)
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)

        # Wrong number of arguments, or a non-iterable second argument.
        for bad_args in [(), (42, 42), (42, 42, 42)]:
            self.assertRaises(TypeError, self.func, *bad_args)
        # func is never called with one item
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, self.func, add, empty)
        self.assertRaises(TypeError, self.func, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, FailingIter())

        # An empty sequence simply yields the initial value.
        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: indexable until IndexError.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        numbers = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, numbers), "".join(numbers.keys()))
623
Łukasz Langa6f692512013-06-05 12:20:24 +0200624
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200625class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700626
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000627 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700628 def cmp1(x, y):
629 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100630 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700631 self.assertEqual(key(3), key(3))
632 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100633 self.assertGreaterEqual(key(3), key(3))
634
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700635 def cmp2(x, y):
636 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100637 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700638 self.assertEqual(key(4.0), key('4'))
639 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100640 self.assertLessEqual(key(2), key('35'))
641 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700642
643 def test_cmp_to_key_arguments(self):
644 def cmp1(x, y):
645 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100646 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700647 self.assertEqual(key(obj=3), key(obj=3))
648 self.assertGreater(key(obj=3), key(obj=1))
649 with self.assertRaises((TypeError, AttributeError)):
650 key(3) > 1 # rhs is not a K object
651 with self.assertRaises((TypeError, AttributeError)):
652 1 < key(3) # lhs is not a K object
653 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100654 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700655 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200656 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100657 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700658 with self.assertRaises(TypeError):
659 key() # too few args
660 with self.assertRaises(TypeError):
661 key(None, None) # too many args
662
663 def test_bad_cmp(self):
664 def cmp1(x, y):
665 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100666 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700667 with self.assertRaises(ZeroDivisionError):
668 key(3) > key(1)
669
670 class BadCmp:
671 def __lt__(self, other):
672 raise ZeroDivisionError
673 def cmp1(x, y):
674 return BadCmp()
675 with self.assertRaises(ZeroDivisionError):
676 key(3) > key(1)
677
678 def test_obj_field(self):
679 def cmp1(x, y):
680 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100681 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700682 self.assertEqual(key(50).obj, 50)
683
684 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000685 def mycmp(x, y):
686 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100687 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000688 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000689
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700690 def test_sort_int_str(self):
691 def mycmp(x, y):
692 x, y = int(x), int(y)
693 return (x > y) - (x < y)
694 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100695 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700696 self.assertEqual([int(value) for value in values],
697 [0, 1, 1, 2, 3, 4, 5, 7, 10])
698
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000699 def test_hash(self):
700 def mycmp(x, y):
701 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100702 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000703 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700704 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700705 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000706
Łukasz Langa6f692512013-06-05 12:20:24 +0200707
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    # Only bind the implementation when the C module was importable; the
    # whole class is skipped otherwise.  A builtin (C) function does not
    # bind to instances, so it can be stored without staticmethod().
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key

class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""

    # staticmethod() keeps the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)

Raymond Hettingerc50846a2010-04-05 18:56:31 +0000718class TestTotalOrdering(unittest.TestCase):
719
720 def test_total_ordering_lt(self):
721 @functools.total_ordering
722 class A:
723 def __init__(self, value):
724 self.value = value
725 def __lt__(self, other):
726 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000727 def __eq__(self, other):
728 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000729 self.assertTrue(A(1) < A(2))
730 self.assertTrue(A(2) > A(1))
731 self.assertTrue(A(1) <= A(2))
732 self.assertTrue(A(2) >= A(1))
733 self.assertTrue(A(2) <= A(2))
734 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000735 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000736
737 def test_total_ordering_le(self):
738 @functools.total_ordering
739 class A:
740 def __init__(self, value):
741 self.value = value
742 def __le__(self, other):
743 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000744 def __eq__(self, other):
745 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000746 self.assertTrue(A(1) < A(2))
747 self.assertTrue(A(2) > A(1))
748 self.assertTrue(A(1) <= A(2))
749 self.assertTrue(A(2) >= A(1))
750 self.assertTrue(A(2) <= A(2))
751 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000752 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000753
754 def test_total_ordering_gt(self):
755 @functools.total_ordering
756 class A:
757 def __init__(self, value):
758 self.value = value
759 def __gt__(self, other):
760 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000761 def __eq__(self, other):
762 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000763 self.assertTrue(A(1) < A(2))
764 self.assertTrue(A(2) > A(1))
765 self.assertTrue(A(1) <= A(2))
766 self.assertTrue(A(2) >= A(1))
767 self.assertTrue(A(2) <= A(2))
768 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000769 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000770
771 def test_total_ordering_ge(self):
772 @functools.total_ordering
773 class A:
774 def __init__(self, value):
775 self.value = value
776 def __ge__(self, other):
777 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000778 def __eq__(self, other):
779 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000780 self.assertTrue(A(1) < A(2))
781 self.assertTrue(A(2) > A(1))
782 self.assertTrue(A(1) <= A(2))
783 self.assertTrue(A(2) >= A(1))
784 self.assertTrue(A(2) <= A(2))
785 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000786 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000787
788 def test_total_ordering_no_overwrite(self):
789 # new methods should not overwrite existing
790 @functools.total_ordering
791 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000792 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000793 self.assertTrue(A(1) < A(2))
794 self.assertTrue(A(2) > A(1))
795 self.assertTrue(A(1) <= A(2))
796 self.assertTrue(A(2) >= A(1))
797 self.assertTrue(A(2) <= A(2))
798 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000799
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000800 def test_no_operations_defined(self):
801 with self.assertRaises(ValueError):
802 @functools.total_ordering
803 class A:
804 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000805
Nick Coghlanf05d9812013-10-02 00:02:03 +1000806 def test_type_error_when_not_implemented(self):
807 # bug 10042; ensure stack overflow does not occur
808 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000809 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000810 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000811 def __init__(self, value):
812 self.value = value
813 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000814 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000815 return self.value == other.value
816 return False
817 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000818 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000819 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000820 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000821
Nick Coghlanf05d9812013-10-02 00:02:03 +1000822 @functools.total_ordering
823 class ImplementsGreaterThan:
824 def __init__(self, value):
825 self.value = value
826 def __eq__(self, other):
827 if isinstance(other, ImplementsGreaterThan):
828 return self.value == other.value
829 return False
830 def __gt__(self, other):
831 if isinstance(other, ImplementsGreaterThan):
832 return self.value > other.value
833 return NotImplemented
834
835 @functools.total_ordering
836 class ImplementsLessThanEqualTo:
837 def __init__(self, value):
838 self.value = value
839 def __eq__(self, other):
840 if isinstance(other, ImplementsLessThanEqualTo):
841 return self.value == other.value
842 return False
843 def __le__(self, other):
844 if isinstance(other, ImplementsLessThanEqualTo):
845 return self.value <= other.value
846 return NotImplemented
847
848 @functools.total_ordering
849 class ImplementsGreaterThanEqualTo:
850 def __init__(self, value):
851 self.value = value
852 def __eq__(self, other):
853 if isinstance(other, ImplementsGreaterThanEqualTo):
854 return self.value == other.value
855 return False
856 def __ge__(self, other):
857 if isinstance(other, ImplementsGreaterThanEqualTo):
858 return self.value >= other.value
859 return NotImplemented
860
861 @functools.total_ordering
862 class ComparatorNotImplemented:
863 def __init__(self, value):
864 self.value = value
865 def __eq__(self, other):
866 if isinstance(other, ComparatorNotImplemented):
867 return self.value == other.value
868 return False
869 def __lt__(self, other):
870 return NotImplemented
871
872 with self.subTest("LT < 1"), self.assertRaises(TypeError):
873 ImplementsLessThan(-1) < 1
874
875 with self.subTest("LT < LE"), self.assertRaises(TypeError):
876 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
877
878 with self.subTest("LT < GT"), self.assertRaises(TypeError):
879 ImplementsLessThan(1) < ImplementsGreaterThan(1)
880
881 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
882 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
883
884 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
885 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
886
887 with self.subTest("GT > GE"), self.assertRaises(TypeError):
888 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
889
890 with self.subTest("GT > LT"), self.assertRaises(TypeError):
891 ImplementsGreaterThan(5) > ImplementsLessThan(5)
892
893 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
894 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
895
896 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
897 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
898
899 with self.subTest("GE when equal"):
900 a = ComparatorNotImplemented(8)
901 b = ComparatorNotImplemented(8)
902 self.assertEqual(a, b)
903 with self.assertRaises(TypeError):
904 a >= b
905
906 with self.subTest("LE when equal"):
907 a = ComparatorNotImplemented(9)
908 b = ComparatorNotImplemented(9)
909 self.assertEqual(a, b)
910 with self.assertRaises(TypeError):
911 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200912
Serhiy Storchaka697a5262015-01-01 15:23:12 +0200913 def test_pickle(self):
914 for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
915 for name in '__lt__', '__gt__', '__le__', '__ge__':
916 with self.subTest(method=name, proto=proto):
917 method = getattr(Orderable_LT, name)
918 method_copy = pickle.loads(pickle.dumps(method, proto))
919 self.assertIs(method_copy, method)
920
@functools.total_ordering
class Orderable_LT:
    """Orders instances by ``value`` from ``__lt__`` and ``__eq__`` alone.

    functools.total_ordering supplies the remaining comparison methods.
    It is defined at module level, presumably so that TestTotalOrdering's
    pickle test can pickle those generated methods by qualified name.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value

Serhiy Storchaka46c56112015-05-24 21:53:49 +0300931class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +0000932
933 def test_lru(self):
934 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100935 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +0300936 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000937 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000938 self.assertEqual(maxsize, 20)
939 self.assertEqual(currsize, 0)
940 self.assertEqual(hits, 0)
941 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000942
943 domain = range(5)
944 for i in range(1000):
945 x, y = choice(domain), choice(domain)
946 actual = f(x, y)
947 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000948 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000949 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000950 self.assertTrue(hits > misses)
951 self.assertEqual(hits + misses, 1000)
952 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000953
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000954 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000955 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000956 self.assertEqual(hits, 0)
957 self.assertEqual(misses, 0)
958 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000959 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000960 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000961 self.assertEqual(hits, 0)
962 self.assertEqual(misses, 1)
963 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000964
Nick Coghlan98876832010-08-17 06:17:18 +0000965 # Test bypassing the cache
966 self.assertIs(f.__wrapped__, orig)
967 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000968 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000969 self.assertEqual(hits, 0)
970 self.assertEqual(misses, 1)
971 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000972
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000973 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +0300974 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000975 def f():
976 nonlocal f_cnt
977 f_cnt += 1
978 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000979 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000980 f_cnt = 0
981 for i in range(5):
982 self.assertEqual(f(), 20)
983 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000984 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000985 self.assertEqual(hits, 0)
986 self.assertEqual(misses, 5)
987 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000988
989 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +0300990 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000991 def f():
992 nonlocal f_cnt
993 f_cnt += 1
994 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000995 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000996 f_cnt = 0
997 for i in range(5):
998 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000999 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001000 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001001 self.assertEqual(hits, 4)
1002 self.assertEqual(misses, 1)
1003 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001004
Raymond Hettingerf3098282010-08-15 03:30:45 +00001005 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001006 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001007 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001008 nonlocal f_cnt
1009 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001010 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001011 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001012 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001013 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1014 # * * * *
1015 self.assertEqual(f(x), x*10)
1016 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001017 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001018 self.assertEqual(hits, 12)
1019 self.assertEqual(misses, 4)
1020 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001021
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001022 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001023 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001024 def fib(n):
1025 if n < 2:
1026 return n
1027 return fib(n-1) + fib(n-2)
1028 self.assertEqual([fib(n) for n in range(16)],
1029 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1030 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001031 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001032 fib.cache_clear()
1033 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001034 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1035
1036 def test_lru_with_maxsize_negative(self):
1037 @self.module.lru_cache(maxsize=-10)
1038 def eq(n):
1039 return n
1040 for i in (0, 1):
1041 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1042 self.assertEqual(eq.cache_info(),
1043 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001044
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001045 def test_lru_with_exceptions(self):
1046 # Verify that user_function exceptions get passed through without
1047 # creating a hard-to-read chained exception.
1048 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001049 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001050 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001051 def func(i):
1052 return 'abc'[i]
1053 self.assertEqual(func(0), 'a')
1054 with self.assertRaises(IndexError) as cm:
1055 func(15)
1056 self.assertIsNone(cm.exception.__context__)
1057 # Verify that the previous exception did not result in a cached entry
1058 with self.assertRaises(IndexError):
1059 func(15)
1060
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001061 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001062 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001063 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001064 def square(x):
1065 return x * x
1066 self.assertEqual(square(3), 9)
1067 self.assertEqual(type(square(3)), type(9))
1068 self.assertEqual(square(3.0), 9.0)
1069 self.assertEqual(type(square(3.0)), type(9.0))
1070 self.assertEqual(square(x=3), 9)
1071 self.assertEqual(type(square(x=3)), type(9))
1072 self.assertEqual(square(x=3.0), 9.0)
1073 self.assertEqual(type(square(x=3.0)), type(9.0))
1074 self.assertEqual(square.cache_info().hits, 4)
1075 self.assertEqual(square.cache_info().misses, 4)
1076
Antoine Pitroub5b37142012-11-13 21:35:40 +01001077 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001078 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001079 def fib(n):
1080 if n < 2:
1081 return n
1082 return fib(n=n-1) + fib(n=n-2)
1083 self.assertEqual(
1084 [fib(n=number) for number in range(16)],
1085 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1086 )
1087 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001088 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001089 fib.cache_clear()
1090 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001091 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001092
1093 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001094 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001095 def fib(n):
1096 if n < 2:
1097 return n
1098 return fib(n=n-1) + fib(n=n-2)
1099 self.assertEqual([fib(n=number) for number in range(16)],
1100 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1101 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001102 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001103 fib.cache_clear()
1104 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001105 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1106
1107 def test_lru_cache_decoration(self):
1108 def f(zomg: 'zomg_annotation'):
1109 """f doc string"""
1110 return 42
1111 g = self.module.lru_cache()(f)
1112 for attr in self.module.WRAPPER_ASSIGNMENTS:
1113 self.assertEqual(getattr(g, attr), getattr(f, attr))
1114
1115 @unittest.skipUnless(threading, 'This test requires threading.')
1116 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001117 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001118 def orig(x, y):
1119 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001120 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001121 hits, misses, maxsize, currsize = f.cache_info()
1122 self.assertEqual(currsize, 0)
1123
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001124 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001125 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001126 start.wait(10)
1127 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001128 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001129
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001130 def clear():
1131 start.wait(10)
1132 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001133 f.cache_clear()
1134
1135 orig_si = sys.getswitchinterval()
1136 sys.setswitchinterval(1e-6)
1137 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001138 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001139 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001140 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001141 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001142 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001143
1144 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001145 if self.module is py_functools:
1146 # XXX: Why can be not equal?
1147 self.assertLessEqual(misses, n)
1148 self.assertLessEqual(hits, m*n - misses)
1149 else:
1150 self.assertEqual(misses, n)
1151 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001152 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001153
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001154 # create n threads in order to fill cache and 1 to clear it
1155 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001156 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001157 for k in range(n)]
1158 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001159 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001160 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001161 finally:
1162 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001163
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001164 @unittest.skipUnless(threading, 'This test requires threading.')
1165 def test_lru_cache_threaded2(self):
1166 # Simultaneous call with the same arguments
1167 n, m = 5, 7
1168 start = threading.Barrier(n+1)
1169 pause = threading.Barrier(n+1)
1170 stop = threading.Barrier(n+1)
1171 @self.module.lru_cache(maxsize=m*n)
1172 def f(x):
1173 pause.wait(10)
1174 return 3 * x
1175 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1176 def test():
1177 for i in range(m):
1178 start.wait(10)
1179 self.assertEqual(f(i), 3 * i)
1180 stop.wait(10)
1181 threads = [threading.Thread(target=test) for k in range(n)]
1182 with support.start_threads(threads):
1183 for i in range(m):
1184 start.wait(10)
1185 stop.reset()
1186 pause.wait(10)
1187 start.reset()
1188 stop.wait(10)
1189 pause.reset()
1190 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1191
Raymond Hettinger03923422013-03-04 02:52:50 -05001192 def test_need_for_rlock(self):
1193 # This will deadlock on an LRU cache that uses a regular lock
1194
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001195 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001196 def test_func(x):
1197 'Used to demonstrate a reentrant lru_cache call within a single thread'
1198 return x
1199
1200 class DoubleEq:
1201 'Demonstrate a reentrant lru_cache call within a single thread'
1202 def __init__(self, x):
1203 self.x = x
1204 def __hash__(self):
1205 return self.x
1206 def __eq__(self, other):
1207 if self.x == 2:
1208 test_func(DoubleEq(1))
1209 return self.x == other.x
1210
1211 test_func(DoubleEq(1)) # Load the cache
1212 test_func(DoubleEq(2)) # Load the cache
1213 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1214 DoubleEq(2)) # Verify the correct return value
1215
Raymond Hettinger4d588972014-08-12 12:44:52 -07001216 def test_early_detection_of_bad_call(self):
1217 # Issue #22184
1218 with self.assertRaises(TypeError):
1219 @functools.lru_cache
1220 def f():
1221 pass
1222
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001223 def test_lru_method(self):
1224 class X(int):
1225 f_cnt = 0
1226 @self.module.lru_cache(2)
1227 def f(self, x):
1228 self.f_cnt += 1
1229 return x*10+self
1230 a = X(5)
1231 b = X(5)
1232 c = X(7)
1233 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1234
1235 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1236 self.assertEqual(a.f(x), x*10 + 5)
1237 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1238 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1239
1240 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1241 self.assertEqual(b.f(x), x*10 + 5)
1242 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1243 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1244
1245 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1246 self.assertEqual(c.f(x), x*10 + 7)
1247 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1248 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1249
1250 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1251 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1252 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1253
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the C-accelerated functools.

    Skipped (rather than erroring on ``None.lru_cache``) when the C
    _functools module could not be imported, matching TestCmpToKeyC.
    """

    module = c_functools

class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the pure-Python functools."""

    module = py_functools

Łukasz Langa6f692512013-06-05 12:20:24 +02001261class TestSingleDispatch(unittest.TestCase):
1262 def test_simple_overloads(self):
1263 @functools.singledispatch
1264 def g(obj):
1265 return "base"
1266 def g_int(i):
1267 return "integer"
1268 g.register(int, g_int)
1269 self.assertEqual(g("str"), "base")
1270 self.assertEqual(g(1), "integer")
1271 self.assertEqual(g([1,2,3]), "base")
1272
1273 def test_mro(self):
1274 @functools.singledispatch
1275 def g(obj):
1276 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001277 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001278 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001279 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001280 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001281 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001282 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001283 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001284 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001285 def g_A(a):
1286 return "A"
1287 def g_B(b):
1288 return "B"
1289 g.register(A, g_A)
1290 g.register(B, g_B)
1291 self.assertEqual(g(A()), "A")
1292 self.assertEqual(g(B()), "B")
1293 self.assertEqual(g(C()), "A")
1294 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001295
1296 def test_register_decorator(self):
1297 @functools.singledispatch
1298 def g(obj):
1299 return "base"
1300 @g.register(int)
1301 def g_int(i):
1302 return "int %s" % (i,)
1303 self.assertEqual(g(""), "base")
1304 self.assertEqual(g(12), "int 12")
1305 self.assertIs(g.dispatch(int), g_int)
1306 self.assertIs(g.dispatch(object), g.dispatch(str))
1307 # Note: in the assert above this is not g.
1308 # @singledispatch returns the wrapper.
1309
1310 def test_wrapping_attributes(self):
1311 @functools.singledispatch
1312 def g(obj):
1313 "Simple test"
1314 return "Test"
1315 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001316 if sys.flags.optimize < 2:
1317 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001318
1319 @unittest.skipUnless(decimal, 'requires _decimal')
1320 @support.cpython_only
1321 def test_c_classes(self):
1322 @functools.singledispatch
1323 def g(obj):
1324 return "base"
1325 @g.register(decimal.DecimalException)
1326 def _(obj):
1327 return obj.args
1328 subn = decimal.Subnormal("Exponent < Emin")
1329 rnd = decimal.Rounded("Number got rounded")
1330 self.assertEqual(g(subn), ("Exponent < Emin",))
1331 self.assertEqual(g(rnd), ("Number got rounded",))
1332 @g.register(decimal.Subnormal)
1333 def _(obj):
1334 return "Too small to care."
1335 self.assertEqual(g(subn), "Too small to care.")
1336 self.assertEqual(g(rnd), ("Number got rounded",))
1337
1338 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001339 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001340 c = collections
1341 mro = functools._compose_mro
1342 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1343 for haystack in permutations(bases):
1344 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001345 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1346 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001347 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1348 for haystack in permutations(bases):
1349 m = mro(c.ChainMap, haystack)
1350 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1351 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001352
1353 # If there's a generic function with implementations registered for
1354 # both Sized and Container, passing a defaultdict to it results in an
1355 # ambiguous dispatch which will cause a RuntimeError (see
1356 # test_mro_conflicts).
1357 bases = [c.Container, c.Sized, str]
1358 for haystack in permutations(bases):
1359 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1360 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1361 object])
1362
1363 # MutableSequence below is registered directly on D. In other words, it
1364 # preceeds MutableMapping which means single dispatch will always
1365 # choose MutableSequence here.
1366 class D(c.defaultdict):
1367 pass
1368 c.MutableSequence.register(D)
1369 bases = [c.MutableSequence, c.MutableMapping]
1370 for haystack in permutations(bases):
1371 m = mro(D, bases)
1372 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1373 c.defaultdict, dict, c.MutableMapping,
1374 c.Mapping, c.Sized, c.Iterable, c.Container,
1375 object])
1376
1377 # Container and Callable are registered on different base classes and
1378 # a generic function supporting both should always pick the Callable
1379 # implementation if a C instance is passed.
1380 class C(c.defaultdict):
1381 def __call__(self):
1382 pass
1383 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1384 for haystack in permutations(bases):
1385 m = mro(C, haystack)
1386 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1387 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001388
1389 def test_register_abc(self):
1390 c = collections
1391 d = {"a": "b"}
1392 l = [1, 2, 3]
1393 s = {object(), None}
1394 f = frozenset(s)
1395 t = (1, 2, 3)
1396 @functools.singledispatch
1397 def g(obj):
1398 return "base"
1399 self.assertEqual(g(d), "base")
1400 self.assertEqual(g(l), "base")
1401 self.assertEqual(g(s), "base")
1402 self.assertEqual(g(f), "base")
1403 self.assertEqual(g(t), "base")
1404 g.register(c.Sized, lambda obj: "sized")
1405 self.assertEqual(g(d), "sized")
1406 self.assertEqual(g(l), "sized")
1407 self.assertEqual(g(s), "sized")
1408 self.assertEqual(g(f), "sized")
1409 self.assertEqual(g(t), "sized")
1410 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1411 self.assertEqual(g(d), "mutablemapping")
1412 self.assertEqual(g(l), "sized")
1413 self.assertEqual(g(s), "sized")
1414 self.assertEqual(g(f), "sized")
1415 self.assertEqual(g(t), "sized")
1416 g.register(c.ChainMap, lambda obj: "chainmap")
1417 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1418 self.assertEqual(g(l), "sized")
1419 self.assertEqual(g(s), "sized")
1420 self.assertEqual(g(f), "sized")
1421 self.assertEqual(g(t), "sized")
1422 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1423 self.assertEqual(g(d), "mutablemapping")
1424 self.assertEqual(g(l), "mutablesequence")
1425 self.assertEqual(g(s), "sized")
1426 self.assertEqual(g(f), "sized")
1427 self.assertEqual(g(t), "sized")
1428 g.register(c.MutableSet, lambda obj: "mutableset")
1429 self.assertEqual(g(d), "mutablemapping")
1430 self.assertEqual(g(l), "mutablesequence")
1431 self.assertEqual(g(s), "mutableset")
1432 self.assertEqual(g(f), "sized")
1433 self.assertEqual(g(t), "sized")
1434 g.register(c.Mapping, lambda obj: "mapping")
1435 self.assertEqual(g(d), "mutablemapping") # not specific enough
1436 self.assertEqual(g(l), "mutablesequence")
1437 self.assertEqual(g(s), "mutableset")
1438 self.assertEqual(g(f), "sized")
1439 self.assertEqual(g(t), "sized")
1440 g.register(c.Sequence, lambda obj: "sequence")
1441 self.assertEqual(g(d), "mutablemapping")
1442 self.assertEqual(g(l), "mutablesequence")
1443 self.assertEqual(g(s), "mutableset")
1444 self.assertEqual(g(f), "sized")
1445 self.assertEqual(g(t), "sequence")
1446 g.register(c.Set, lambda obj: "set")
1447 self.assertEqual(g(d), "mutablemapping")
1448 self.assertEqual(g(l), "mutablesequence")
1449 self.assertEqual(g(s), "mutableset")
1450 self.assertEqual(g(f), "set")
1451 self.assertEqual(g(t), "sequence")
1452 g.register(dict, lambda obj: "dict")
1453 self.assertEqual(g(d), "dict")
1454 self.assertEqual(g(l), "mutablesequence")
1455 self.assertEqual(g(s), "mutableset")
1456 self.assertEqual(g(f), "set")
1457 self.assertEqual(g(t), "sequence")
1458 g.register(list, lambda obj: "list")
1459 self.assertEqual(g(d), "dict")
1460 self.assertEqual(g(l), "list")
1461 self.assertEqual(g(s), "mutableset")
1462 self.assertEqual(g(f), "set")
1463 self.assertEqual(g(t), "sequence")
1464 g.register(set, lambda obj: "concrete-set")
1465 self.assertEqual(g(d), "dict")
1466 self.assertEqual(g(l), "list")
1467 self.assertEqual(g(s), "concrete-set")
1468 self.assertEqual(g(f), "set")
1469 self.assertEqual(g(t), "sequence")
1470 g.register(frozenset, lambda obj: "frozen-set")
1471 self.assertEqual(g(d), "dict")
1472 self.assertEqual(g(l), "list")
1473 self.assertEqual(g(s), "concrete-set")
1474 self.assertEqual(g(f), "frozen-set")
1475 self.assertEqual(g(t), "sequence")
1476 g.register(tuple, lambda obj: "tuple")
1477 self.assertEqual(g(d), "dict")
1478 self.assertEqual(g(l), "list")
1479 self.assertEqual(g(s), "concrete-set")
1480 self.assertEqual(g(f), "frozen-set")
1481 self.assertEqual(g(t), "tuple")
1482
Łukasz Langa3720c772013-07-01 16:00:38 +02001483 def test_c3_abc(self):
1484 c = collections
1485 mro = functools._c3_mro
1486 class A(object):
1487 pass
1488 class B(A):
1489 def __len__(self):
1490 return 0 # implies Sized
1491 @c.Container.register
1492 class C(object):
1493 pass
1494 class D(object):
1495 pass # unrelated
1496 class X(D, C, B):
1497 def __call__(self):
1498 pass # implies Callable
1499 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1500 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1501 self.assertEqual(mro(X, abcs=abcs), expected)
1502 # unrelated ABCs don't appear in the resulting MRO
1503 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1504 self.assertEqual(mro(X, abcs=many_abcs), expected)
1505
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001506 def test_false_meta(self):
1507 # see issue23572
1508 class MetaA(type):
1509 def __len__(self):
1510 return 0
1511 class A(metaclass=MetaA):
1512 pass
1513 class AA(A):
1514 pass
1515 @functools.singledispatch
1516 def fun(a):
1517 return 'base A'
1518 @fun.register(A)
1519 def _(a):
1520 return 'fun A'
1521 aa = AA()
1522 self.assertEqual(fun(aa), 'fun A')
1523
Łukasz Langa6f692512013-06-05 12:20:24 +02001524 def test_mro_conflicts(self):
1525 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001526 @functools.singledispatch
1527 def g(arg):
1528 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001529 class O(c.Sized):
1530 def __len__(self):
1531 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001532 o = O()
1533 self.assertEqual(g(o), "base")
1534 g.register(c.Iterable, lambda arg: "iterable")
1535 g.register(c.Container, lambda arg: "container")
1536 g.register(c.Sized, lambda arg: "sized")
1537 g.register(c.Set, lambda arg: "set")
1538 self.assertEqual(g(o), "sized")
1539 c.Iterable.register(O)
1540 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1541 c.Container.register(O)
1542 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001543 c.Set.register(O)
1544 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1545 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001546 class P:
1547 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001548 p = P()
1549 self.assertEqual(g(p), "base")
1550 c.Iterable.register(P)
1551 self.assertEqual(g(p), "iterable")
1552 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001553 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001554 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001555 self.assertIn(
1556 str(re_one.exception),
1557 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1558 "or <class 'collections.abc.Iterable'>"),
1559 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1560 "or <class 'collections.abc.Container'>")),
1561 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001562 class Q(c.Sized):
1563 def __len__(self):
1564 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001565 q = Q()
1566 self.assertEqual(g(q), "sized")
1567 c.Iterable.register(Q)
1568 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1569 c.Set.register(Q)
1570 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001571 # c.Sized and c.Iterable
1572 @functools.singledispatch
1573 def h(arg):
1574 return "base"
1575 @h.register(c.Sized)
1576 def _(arg):
1577 return "sized"
1578 @h.register(c.Container)
1579 def _(arg):
1580 return "container"
1581 # Even though Sized and Container are explicit bases of MutableMapping,
1582 # this ABC is implicitly registered on defaultdict which makes all of
1583 # MutableMapping's bases implicit as well from defaultdict's
1584 # perspective.
1585 with self.assertRaises(RuntimeError) as re_two:
1586 h(c.defaultdict(lambda: 0))
1587 self.assertIn(
1588 str(re_two.exception),
1589 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1590 "or <class 'collections.abc.Sized'>"),
1591 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1592 "or <class 'collections.abc.Container'>")),
1593 )
1594 class R(c.defaultdict):
1595 pass
1596 c.MutableSequence.register(R)
1597 @functools.singledispatch
1598 def i(arg):
1599 return "base"
1600 @i.register(c.MutableMapping)
1601 def _(arg):
1602 return "mapping"
1603 @i.register(c.MutableSequence)
1604 def _(arg):
1605 return "sequence"
1606 r = R()
1607 self.assertEqual(i(r), "sequence")
1608 class S:
1609 pass
1610 class T(S, c.Sized):
1611 def __len__(self):
1612 return 0
1613 t = T()
1614 self.assertEqual(h(t), "sized")
1615 c.Container.register(T)
1616 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1617 class U:
1618 def __len__(self):
1619 return 0
1620 u = U()
1621 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1622 # from the existence of __len__()
1623 c.Container.register(U)
1624 # There is no preference for registered versus inferred ABCs.
1625 with self.assertRaises(RuntimeError) as re_three:
1626 h(u)
1627 self.assertIn(
1628 str(re_three.exception),
1629 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1630 "or <class 'collections.abc.Sized'>"),
1631 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1632 "or <class 'collections.abc.Container'>")),
1633 )
1634 class V(c.Sized, S):
1635 def __len__(self):
1636 return 0
1637 @functools.singledispatch
1638 def j(arg):
1639 return "base"
1640 @j.register(S)
1641 def _(arg):
1642 return "s"
1643 @j.register(c.Container)
1644 def _(arg):
1645 return "container"
1646 v = V()
1647 self.assertEqual(j(v), "s")
1648 c.Container.register(V)
1649 self.assertEqual(j(v), "container") # because it ends up right after
1650 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001651
1652 def test_cache_invalidation(self):
1653 from collections import UserDict
1654 class TracingDict(UserDict):
1655 def __init__(self, *args, **kwargs):
1656 super(TracingDict, self).__init__(*args, **kwargs)
1657 self.set_ops = []
1658 self.get_ops = []
1659 def __getitem__(self, key):
1660 result = self.data[key]
1661 self.get_ops.append(key)
1662 return result
1663 def __setitem__(self, key, value):
1664 self.set_ops.append(key)
1665 self.data[key] = value
1666 def clear(self):
1667 self.data.clear()
1668 _orig_wkd = functools.WeakKeyDictionary
1669 td = TracingDict()
1670 functools.WeakKeyDictionary = lambda: td
1671 c = collections
1672 @functools.singledispatch
1673 def g(arg):
1674 return "base"
1675 d = {}
1676 l = []
1677 self.assertEqual(len(td), 0)
1678 self.assertEqual(g(d), "base")
1679 self.assertEqual(len(td), 1)
1680 self.assertEqual(td.get_ops, [])
1681 self.assertEqual(td.set_ops, [dict])
1682 self.assertEqual(td.data[dict], g.registry[object])
1683 self.assertEqual(g(l), "base")
1684 self.assertEqual(len(td), 2)
1685 self.assertEqual(td.get_ops, [])
1686 self.assertEqual(td.set_ops, [dict, list])
1687 self.assertEqual(td.data[dict], g.registry[object])
1688 self.assertEqual(td.data[list], g.registry[object])
1689 self.assertEqual(td.data[dict], td.data[list])
1690 self.assertEqual(g(l), "base")
1691 self.assertEqual(g(d), "base")
1692 self.assertEqual(td.get_ops, [list, dict])
1693 self.assertEqual(td.set_ops, [dict, list])
1694 g.register(list, lambda arg: "list")
1695 self.assertEqual(td.get_ops, [list, dict])
1696 self.assertEqual(len(td), 0)
1697 self.assertEqual(g(d), "base")
1698 self.assertEqual(len(td), 1)
1699 self.assertEqual(td.get_ops, [list, dict])
1700 self.assertEqual(td.set_ops, [dict, list, dict])
1701 self.assertEqual(td.data[dict],
1702 functools._find_impl(dict, g.registry))
1703 self.assertEqual(g(l), "list")
1704 self.assertEqual(len(td), 2)
1705 self.assertEqual(td.get_ops, [list, dict])
1706 self.assertEqual(td.set_ops, [dict, list, dict, list])
1707 self.assertEqual(td.data[list],
1708 functools._find_impl(list, g.registry))
1709 class X:
1710 pass
1711 c.MutableMapping.register(X) # Will not invalidate the cache,
1712 # not using ABCs yet.
1713 self.assertEqual(g(d), "base")
1714 self.assertEqual(g(l), "list")
1715 self.assertEqual(td.get_ops, [list, dict, dict, list])
1716 self.assertEqual(td.set_ops, [dict, list, dict, list])
1717 g.register(c.Sized, lambda arg: "sized")
1718 self.assertEqual(len(td), 0)
1719 self.assertEqual(g(d), "sized")
1720 self.assertEqual(len(td), 1)
1721 self.assertEqual(td.get_ops, [list, dict, dict, list])
1722 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1723 self.assertEqual(g(l), "list")
1724 self.assertEqual(len(td), 2)
1725 self.assertEqual(td.get_ops, [list, dict, dict, list])
1726 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1727 self.assertEqual(g(l), "list")
1728 self.assertEqual(g(d), "sized")
1729 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1730 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1731 g.dispatch(list)
1732 g.dispatch(dict)
1733 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1734 list, dict])
1735 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1736 c.MutableSet.register(X) # Will invalidate the cache.
1737 self.assertEqual(len(td), 2) # Stale cache.
1738 self.assertEqual(g(l), "list")
1739 self.assertEqual(len(td), 1)
1740 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1741 self.assertEqual(len(td), 0)
1742 self.assertEqual(g(d), "mutablemapping")
1743 self.assertEqual(len(td), 1)
1744 self.assertEqual(g(l), "list")
1745 self.assertEqual(len(td), 2)
1746 g.register(dict, lambda arg: "dict")
1747 self.assertEqual(g(d), "dict")
1748 self.assertEqual(g(l), "list")
1749 g._clear_cache()
1750 self.assertEqual(len(td), 0)
1751 functools.WeakKeyDictionary = _orig_wkd
1752
1753
if __name__ == '__main__':
    # Executed as a script rather than imported: hand over to the
    # unittest command-line runner.
    unittest.main()