blob: 271d655dd09098776af82d3a190da1f0b5459638 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Serhiy Storchaka1c858c32015-05-23 22:42:49 +030010try:
11 import threading
12except ImportError:
13 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000014
Antoine Pitroub5b37142012-11-13 21:35:40 +010015import functools
16
Antoine Pitroub5b37142012-11-13 21:35:40 +010017py_functools = support.import_fresh_module('functools', blocked=['_functools'])
18c_functools = support.import_fresh_module('functools', fresh=['_functools'])
19
Łukasz Langa6f692512013-06-05 12:20:24 +020020decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
21
22
def capture(*args, **kw):
    """Echo a call back to the caller.

    Returns a two-item tuple: the positional arguments this function was
    invoked with (as a tuple) followed by the keyword arguments (as a dict).
    """
    positional, keywords = args, kw
    return positional, keywords
26
Łukasz Langa6f692512013-06-05 12:20:24 +020027
def signature(part):
    """Summarise a partial object as a comparable 4-tuple.

    The tuple holds, in order, the wrapped callable, the stored positional
    arguments, the stored keyword arguments and the instance ``__dict__``.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, name) for name in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000031
Łukasz Langa6f692512013-06-05 12:20:24 +020032
class TestPartial:
    """Behavioural tests shared by every ``partial`` implementation.

    This is a mixin rather than a ``TestCase``: concrete subclasses combine
    it with ``unittest.TestCase`` and bind the implementation under test to
    the ``partial`` class attribute, which every test reaches via
    ``self.partial``.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        # Call-time keywords are merged with, and override, stored ones.
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure shows which part differed.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        # Only "true" partial is optimized; report anything else (e.g. a
        # subclass) as skipped instead of letting it pass silently.
        if partial.__name__ != 'partial':
            self.skipTest('nested partial flattening only applies to '
                          'partial itself')
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))
149
Łukasz Langa6f692512013-06-05 12:20:24 +0200150
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus checks specific to the C implementation."""

    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # The instance __dict__ may not be deleted either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000225
Łukasz Langa6f692512013-06-05 12:20:24 +0200226
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    # staticmethod() keeps the implementation from being bound as a method
    # when it is looked up through ``self`` (needed if it is a plain function).
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000229
Łukasz Langa6f692512013-06-05 12:20:24 +0200230
if c_functools:
    class PartialSubclass(c_functools.partial):
        """Empty subclass of the C partial type, used to re-run its tests."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100234
Łukasz Langa6f692512013-06-05 12:20:24 +0200235
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run every C partial test against a user-defined subclass."""

    # PartialSubclass is only defined when the C module imported; the
    # skipUnless decorator keeps these tests from running otherwise.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000240
Łukasz Langa6f692512013-06-05 12:20:24 +0200241
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod as a class-level descriptor."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the *metaclass* so Abstract is a genuine abstract
        # class (subclassing ABCMeta would only define another metaclass).
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must also record the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
354
355
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (TestWraps reuses them for wraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* mirrors *wrapped*.

        Each attribute named in *assigned* must be the identical object on
        both functions. Each mapping attribute named in *updated* must hold
        every entry of the wrapped function's mapping, except
        ``__dict__['__wrapped__']``. Finally, ``wrapper.__wrapped__`` must
        be *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for entry in source:
                # __wrapped__ is overwritten by the update code, so the
                # wrapped function's own value is not expected to survive.
                overwritten = attr_name == "__dict__" and entry == "__wrapped__"
                if not overwritten:
                    self.assertIs(source[entry], target[entry])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Apply update_wrapper with default settings; return (wrapper, f)."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        def wrapper(b:'This is the prior annotation'):
            pass
        f.attr = 'This is also a test'
        # A bogus __wrapped__ on f must not end up on the wrapper.
        f.__wrapped__ = "This is a bald faced lie"
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        def wrapper():
            pass
        f.attr = 'This is also a test'
        # Empty assigned/updated tuples: name, doc, annotations and attrs
        # must all stay the wrapper's own.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        def wrapper():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        wrapper.dict_attr = {}
        assigned, updated = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned, updated = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must own each attribute named in *updated*...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)
        # ...and a non-mapping value (here an int) is rejected too.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000468
Łukasz Langa6f692512013-06-05 12:20:24 +0200469
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the functools.wraps decorator.

    Inherited tests that are not overridden here (test_missing_attributes,
    test_builtin_update) still call update_wrapper directly.
    """

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        # wraps() copies from f when it is applied, so f is prepared first.
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expectations = (('__name__', 'f'),
                        ('__qualname__', f.__qualname__),
                        ('attr', 'This is also a test'))
        for attr_name, expected in expectations:
            self.assertEqual(getattr(wrapper, attr_name), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # Give the wrapper the mapping that wraps() will update.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
530
Łukasz Langa6f692512013-06-05 12:20:24 +0200531
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily filled sequence 0, 1, 4, 9, ... of length *limit*. It
            # defines no __iter__, so iteration goes through __getitem__.
            def __init__(self, limit):
                self._limit = limit
                self._computed = []

            def __len__(self):
                return len(self._computed)

            def __getitem__(self, i):
                if not 0 <= i < self._limit:
                    raise IndexError
                computed = self._computed
                # Extend the cache up to and including index i.
                computed.extend(n * n for n in range(len(computed), i + 1))
                return computed[i]

        def add(x, y):
            return x + y

        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(self.func(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(self.func(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(self.func(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)

        # Too few arguments, or a non-iterable second argument.
        for bad_args in [(), (42, 42), (42, 42, 42)]:
            self.assertRaises(TypeError, self.func, *bad_args)
        # func is never called with one item, so it need not be callable...
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        # ...but it must be once two items are combined.
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        # arg 2 must not be an empty sequence when there is no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, self.func, add, empty)
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertIsNone(self.func(add, [], None))
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence yielding 0 .. n-1 via __getitem__.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        # Iterating a dict yields its keys.
        self.assertEqual(self.func(add, d), "".join(d))
611
Łukasz Langa6f692512013-06-05 12:20:24 +0200612
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200613class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700614
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000615 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700616 def cmp1(x, y):
617 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100618 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700619 self.assertEqual(key(3), key(3))
620 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100621 self.assertGreaterEqual(key(3), key(3))
622
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700623 def cmp2(x, y):
624 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100625 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700626 self.assertEqual(key(4.0), key('4'))
627 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100628 self.assertLessEqual(key(2), key('35'))
629 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700630
631 def test_cmp_to_key_arguments(self):
632 def cmp1(x, y):
633 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100634 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700635 self.assertEqual(key(obj=3), key(obj=3))
636 self.assertGreater(key(obj=3), key(obj=1))
637 with self.assertRaises((TypeError, AttributeError)):
638 key(3) > 1 # rhs is not a K object
639 with self.assertRaises((TypeError, AttributeError)):
640 1 < key(3) # lhs is not a K object
641 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100642 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700643 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200644 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100645 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700646 with self.assertRaises(TypeError):
647 key() # too few args
648 with self.assertRaises(TypeError):
649 key(None, None) # too many args
650
651 def test_bad_cmp(self):
652 def cmp1(x, y):
653 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100654 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700655 with self.assertRaises(ZeroDivisionError):
656 key(3) > key(1)
657
658 class BadCmp:
659 def __lt__(self, other):
660 raise ZeroDivisionError
661 def cmp1(x, y):
662 return BadCmp()
663 with self.assertRaises(ZeroDivisionError):
664 key(3) > key(1)
665
666 def test_obj_field(self):
667 def cmp1(x, y):
668 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100669 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700670 self.assertEqual(key(50).obj, 50)
671
672 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000673 def mycmp(x, y):
674 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100675 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000676 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000677
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700678 def test_sort_int_str(self):
679 def mycmp(x, y):
680 x, y = int(x), int(y)
681 return (x > y) - (x < y)
682 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100683 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700684 self.assertEqual([int(value) for value in values],
685 [0, 1, 1, 2, 3, 4, 5, 7, 10])
686
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000687 def test_hash(self):
688 def mycmp(x, y):
689 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100690 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000691 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700692 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700693 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000694
Łukasz Langa6f692512013-06-05 12:20:24 +0200695
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""
    # Guarded so the class body still evaluates when _functools is missing;
    # the skip decorator above then disables every test.  The C function is
    # a builtin, which does not bind ``self``, so no staticmethod is needed.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100700
Łukasz Langa6f692512013-06-05 12:20:24 +0200701
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""
    # A plain Python function stored on the class would bind ``self`` on
    # attribute access; staticmethod keeps ``self.cmp_to_key(f)`` working.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
704
Łukasz Langa6f692512013-06-05 12:20:24 +0200705
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000706class TestTotalOrdering(unittest.TestCase):
707
708 def test_total_ordering_lt(self):
709 @functools.total_ordering
710 class A:
711 def __init__(self, value):
712 self.value = value
713 def __lt__(self, other):
714 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000715 def __eq__(self, other):
716 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000717 self.assertTrue(A(1) < A(2))
718 self.assertTrue(A(2) > A(1))
719 self.assertTrue(A(1) <= A(2))
720 self.assertTrue(A(2) >= A(1))
721 self.assertTrue(A(2) <= A(2))
722 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000723 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000724
725 def test_total_ordering_le(self):
726 @functools.total_ordering
727 class A:
728 def __init__(self, value):
729 self.value = value
730 def __le__(self, other):
731 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000732 def __eq__(self, other):
733 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000734 self.assertTrue(A(1) < A(2))
735 self.assertTrue(A(2) > A(1))
736 self.assertTrue(A(1) <= A(2))
737 self.assertTrue(A(2) >= A(1))
738 self.assertTrue(A(2) <= A(2))
739 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000740 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000741
742 def test_total_ordering_gt(self):
743 @functools.total_ordering
744 class A:
745 def __init__(self, value):
746 self.value = value
747 def __gt__(self, other):
748 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000749 def __eq__(self, other):
750 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000751 self.assertTrue(A(1) < A(2))
752 self.assertTrue(A(2) > A(1))
753 self.assertTrue(A(1) <= A(2))
754 self.assertTrue(A(2) >= A(1))
755 self.assertTrue(A(2) <= A(2))
756 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000757 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000758
759 def test_total_ordering_ge(self):
760 @functools.total_ordering
761 class A:
762 def __init__(self, value):
763 self.value = value
764 def __ge__(self, other):
765 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000766 def __eq__(self, other):
767 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000768 self.assertTrue(A(1) < A(2))
769 self.assertTrue(A(2) > A(1))
770 self.assertTrue(A(1) <= A(2))
771 self.assertTrue(A(2) >= A(1))
772 self.assertTrue(A(2) <= A(2))
773 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000774 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000775
776 def test_total_ordering_no_overwrite(self):
777 # new methods should not overwrite existing
778 @functools.total_ordering
779 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000780 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000781 self.assertTrue(A(1) < A(2))
782 self.assertTrue(A(2) > A(1))
783 self.assertTrue(A(1) <= A(2))
784 self.assertTrue(A(2) >= A(1))
785 self.assertTrue(A(2) <= A(2))
786 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000787
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000788 def test_no_operations_defined(self):
789 with self.assertRaises(ValueError):
790 @functools.total_ordering
791 class A:
792 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000793
Nick Coghlanf05d9812013-10-02 00:02:03 +1000794 def test_type_error_when_not_implemented(self):
795 # bug 10042; ensure stack overflow does not occur
796 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000797 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000798 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000799 def __init__(self, value):
800 self.value = value
801 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000802 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000803 return self.value == other.value
804 return False
805 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000806 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000807 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000808 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000809
Nick Coghlanf05d9812013-10-02 00:02:03 +1000810 @functools.total_ordering
811 class ImplementsGreaterThan:
812 def __init__(self, value):
813 self.value = value
814 def __eq__(self, other):
815 if isinstance(other, ImplementsGreaterThan):
816 return self.value == other.value
817 return False
818 def __gt__(self, other):
819 if isinstance(other, ImplementsGreaterThan):
820 return self.value > other.value
821 return NotImplemented
822
823 @functools.total_ordering
824 class ImplementsLessThanEqualTo:
825 def __init__(self, value):
826 self.value = value
827 def __eq__(self, other):
828 if isinstance(other, ImplementsLessThanEqualTo):
829 return self.value == other.value
830 return False
831 def __le__(self, other):
832 if isinstance(other, ImplementsLessThanEqualTo):
833 return self.value <= other.value
834 return NotImplemented
835
836 @functools.total_ordering
837 class ImplementsGreaterThanEqualTo:
838 def __init__(self, value):
839 self.value = value
840 def __eq__(self, other):
841 if isinstance(other, ImplementsGreaterThanEqualTo):
842 return self.value == other.value
843 return False
844 def __ge__(self, other):
845 if isinstance(other, ImplementsGreaterThanEqualTo):
846 return self.value >= other.value
847 return NotImplemented
848
849 @functools.total_ordering
850 class ComparatorNotImplemented:
851 def __init__(self, value):
852 self.value = value
853 def __eq__(self, other):
854 if isinstance(other, ComparatorNotImplemented):
855 return self.value == other.value
856 return False
857 def __lt__(self, other):
858 return NotImplemented
859
860 with self.subTest("LT < 1"), self.assertRaises(TypeError):
861 ImplementsLessThan(-1) < 1
862
863 with self.subTest("LT < LE"), self.assertRaises(TypeError):
864 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
865
866 with self.subTest("LT < GT"), self.assertRaises(TypeError):
867 ImplementsLessThan(1) < ImplementsGreaterThan(1)
868
869 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
870 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
871
872 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
873 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
874
875 with self.subTest("GT > GE"), self.assertRaises(TypeError):
876 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
877
878 with self.subTest("GT > LT"), self.assertRaises(TypeError):
879 ImplementsGreaterThan(5) > ImplementsLessThan(5)
880
881 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
882 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
883
884 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
885 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
886
887 with self.subTest("GE when equal"):
888 a = ComparatorNotImplemented(8)
889 b = ComparatorNotImplemented(8)
890 self.assertEqual(a, b)
891 with self.assertRaises(TypeError):
892 a >= b
893
894 with self.subTest("LE when equal"):
895 a = ComparatorNotImplemented(9)
896 b = ComparatorNotImplemented(9)
897 self.assertEqual(a, b)
898 with self.assertRaises(TypeError):
899 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200900
Serhiy Storchaka697a5262015-01-01 15:23:12 +0200901 def test_pickle(self):
902 for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
903 for name in '__lt__', '__gt__', '__le__', '__ge__':
904 with self.subTest(method=name, proto=proto):
905 method = getattr(Orderable_LT, name)
906 method_copy = pickle.loads(pickle.dumps(method, proto))
907 self.assertIs(method_copy, method)
908
@functools.total_ordering
class Orderable_LT:
    """Order by ``value`` via __lt__/__eq__; used by TestTotalOrdering.test_pickle.

    Defined at module level so its methods can be pickled by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
917
918
Serhiy Storchaka1c858c32015-05-23 22:42:49 +0300919class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +0000920
921 def test_lru(self):
922 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100923 return 3 * x + y
Serhiy Storchaka1c858c32015-05-23 22:42:49 +0300924 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000925 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000926 self.assertEqual(maxsize, 20)
927 self.assertEqual(currsize, 0)
928 self.assertEqual(hits, 0)
929 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000930
931 domain = range(5)
932 for i in range(1000):
933 x, y = choice(domain), choice(domain)
934 actual = f(x, y)
935 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000936 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000937 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000938 self.assertTrue(hits > misses)
939 self.assertEqual(hits + misses, 1000)
940 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000941
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000942 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000943 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000944 self.assertEqual(hits, 0)
945 self.assertEqual(misses, 0)
946 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000947 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000948 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000949 self.assertEqual(hits, 0)
950 self.assertEqual(misses, 1)
951 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000952
Nick Coghlan98876832010-08-17 06:17:18 +0000953 # Test bypassing the cache
954 self.assertIs(f.__wrapped__, orig)
955 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000956 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000957 self.assertEqual(hits, 0)
958 self.assertEqual(misses, 1)
959 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000960
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000961 # test size zero (which means "never-cache")
Serhiy Storchaka1c858c32015-05-23 22:42:49 +0300962 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000963 def f():
964 nonlocal f_cnt
965 f_cnt += 1
966 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000967 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000968 f_cnt = 0
969 for i in range(5):
970 self.assertEqual(f(), 20)
971 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000972 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000973 self.assertEqual(hits, 0)
974 self.assertEqual(misses, 5)
975 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000976
977 # test size one
Serhiy Storchaka1c858c32015-05-23 22:42:49 +0300978 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000979 def f():
980 nonlocal f_cnt
981 f_cnt += 1
982 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000983 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000984 f_cnt = 0
985 for i in range(5):
986 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000987 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000988 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000989 self.assertEqual(hits, 4)
990 self.assertEqual(misses, 1)
991 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000992
Raymond Hettingerf3098282010-08-15 03:30:45 +0000993 # test size two
Serhiy Storchaka1c858c32015-05-23 22:42:49 +0300994 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000995 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000996 nonlocal f_cnt
997 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000998 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000999 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001000 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001001 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1002 # * * * *
1003 self.assertEqual(f(x), x*10)
1004 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001005 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001006 self.assertEqual(hits, 12)
1007 self.assertEqual(misses, 4)
1008 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001009
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001010 def test_lru_with_maxsize_none(self):
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001011 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001012 def fib(n):
1013 if n < 2:
1014 return n
1015 return fib(n-1) + fib(n-2)
1016 self.assertEqual([fib(n) for n in range(16)],
1017 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1018 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001019 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001020 fib.cache_clear()
1021 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001022 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1023
1024 def test_lru_with_maxsize_negative(self):
1025 @self.module.lru_cache(maxsize=-10)
1026 def eq(n):
1027 return n
1028 for i in (0, 1):
1029 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1030 self.assertEqual(eq.cache_info(),
1031 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001032
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001033 def test_lru_with_exceptions(self):
1034 # Verify that user_function exceptions get passed through without
1035 # creating a hard-to-read chained exception.
1036 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001037 for maxsize in (None, 128):
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001038 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001039 def func(i):
1040 return 'abc'[i]
1041 self.assertEqual(func(0), 'a')
1042 with self.assertRaises(IndexError) as cm:
1043 func(15)
1044 self.assertIsNone(cm.exception.__context__)
1045 # Verify that the previous exception did not result in a cached entry
1046 with self.assertRaises(IndexError):
1047 func(15)
1048
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001049 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001050 for maxsize in (None, 128):
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001051 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001052 def square(x):
1053 return x * x
1054 self.assertEqual(square(3), 9)
1055 self.assertEqual(type(square(3)), type(9))
1056 self.assertEqual(square(3.0), 9.0)
1057 self.assertEqual(type(square(3.0)), type(9.0))
1058 self.assertEqual(square(x=3), 9)
1059 self.assertEqual(type(square(x=3)), type(9))
1060 self.assertEqual(square(x=3.0), 9.0)
1061 self.assertEqual(type(square(x=3.0)), type(9.0))
1062 self.assertEqual(square.cache_info().hits, 4)
1063 self.assertEqual(square.cache_info().misses, 4)
1064
Antoine Pitroub5b37142012-11-13 21:35:40 +01001065 def test_lru_with_keyword_args(self):
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001066 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001067 def fib(n):
1068 if n < 2:
1069 return n
1070 return fib(n=n-1) + fib(n=n-2)
1071 self.assertEqual(
1072 [fib(n=number) for number in range(16)],
1073 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1074 )
1075 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001076 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001077 fib.cache_clear()
1078 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001079 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001080
1081 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001082 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001083 def fib(n):
1084 if n < 2:
1085 return n
1086 return fib(n=n-1) + fib(n=n-2)
1087 self.assertEqual([fib(n=number) for number in range(16)],
1088 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1089 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001090 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001091 fib.cache_clear()
1092 self.assertEqual(fib.cache_info(),
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001093 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1094
1095 def test_lru_cache_decoration(self):
1096 def f(zomg: 'zomg_annotation'):
1097 """f doc string"""
1098 return 42
1099 g = self.module.lru_cache()(f)
1100 for attr in self.module.WRAPPER_ASSIGNMENTS:
1101 self.assertEqual(getattr(g, attr), getattr(f, attr))
1102
1103 @unittest.skipUnless(threading, 'This test requires threading.')
1104 def test_lru_cache_threaded(self):
1105 def orig(x, y):
1106 return 3 * x + y
1107 f = self.module.lru_cache(maxsize=20)(orig)
1108 hits, misses, maxsize, currsize = f.cache_info()
1109 self.assertEqual(currsize, 0)
1110
1111 def full(f, *args):
1112 for _ in range(10):
1113 f(*args)
1114
1115 def clear(f):
1116 for _ in range(10):
1117 f.cache_clear()
1118
1119 orig_si = sys.getswitchinterval()
1120 sys.setswitchinterval(1e-6)
1121 try:
1122 # create 5 threads in order to fill cache
1123 threads = []
1124 for k in range(5):
1125 t = threading.Thread(target=full, args=[f, k, k])
1126 t.start()
1127 threads.append(t)
1128
1129 for t in threads:
1130 t.join()
1131
1132 hits, misses, maxsize, currsize = f.cache_info()
1133 self.assertEqual(hits, 45)
1134 self.assertEqual(misses, 5)
1135 self.assertEqual(currsize, 5)
1136
1137 # create 5 threads in order to fill cache and 1 to clear it
1138 cleaner = threading.Thread(target=clear, args=[f])
1139 cleaner.start()
1140 threads = [cleaner]
1141 for k in range(5):
1142 t = threading.Thread(target=full, args=[f, k, k])
1143 t.start()
1144 threads.append(t)
1145
1146 for t in threads:
1147 t.join()
1148 finally:
1149 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001150
Raymond Hettinger03923422013-03-04 02:52:50 -05001151 def test_need_for_rlock(self):
1152 # This will deadlock on an LRU cache that uses a regular lock
1153
Serhiy Storchaka1c858c32015-05-23 22:42:49 +03001154 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001155 def test_func(x):
1156 'Used to demonstrate a reentrant lru_cache call within a single thread'
1157 return x
1158
1159 class DoubleEq:
1160 'Demonstrate a reentrant lru_cache call within a single thread'
1161 def __init__(self, x):
1162 self.x = x
1163 def __hash__(self):
1164 return self.x
1165 def __eq__(self, other):
1166 if self.x == 2:
1167 test_func(DoubleEq(1))
1168 return self.x == other.x
1169
1170 test_func(DoubleEq(1)) # Load the cache
1171 test_func(DoubleEq(2)) # Load the cache
1172 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1173 DoubleEq(2)) # Verify the correct return value
1174
Raymond Hettinger4d588972014-08-12 12:44:52 -07001175 def test_early_detection_of_bad_call(self):
1176 # Issue #22184
1177 with self.assertRaises(TypeError):
1178 @functools.lru_cache
1179 def f():
1180 pass
1181
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the C-accelerated functools.

    Skipped when _functools is unavailable; ``c_functools`` is then None and
    every ``self.module`` lookup would otherwise fail with AttributeError.
    """
    module = c_functools
1184
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the pure-Python functools."""
    # Every shared test reaches lru_cache and friends through this module.
    module = py_functools
1187
Raymond Hettinger03923422013-03-04 02:52:50 -05001188
Łukasz Langa6f692512013-06-05 12:20:24 +02001189class TestSingleDispatch(unittest.TestCase):
1190 def test_simple_overloads(self):
1191 @functools.singledispatch
1192 def g(obj):
1193 return "base"
1194 def g_int(i):
1195 return "integer"
1196 g.register(int, g_int)
1197 self.assertEqual(g("str"), "base")
1198 self.assertEqual(g(1), "integer")
1199 self.assertEqual(g([1,2,3]), "base")
1200
1201 def test_mro(self):
1202 @functools.singledispatch
1203 def g(obj):
1204 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001205 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001206 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001207 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001208 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001209 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001210 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001211 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001212 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001213 def g_A(a):
1214 return "A"
1215 def g_B(b):
1216 return "B"
1217 g.register(A, g_A)
1218 g.register(B, g_B)
1219 self.assertEqual(g(A()), "A")
1220 self.assertEqual(g(B()), "B")
1221 self.assertEqual(g(C()), "A")
1222 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001223
1224 def test_register_decorator(self):
1225 @functools.singledispatch
1226 def g(obj):
1227 return "base"
1228 @g.register(int)
1229 def g_int(i):
1230 return "int %s" % (i,)
1231 self.assertEqual(g(""), "base")
1232 self.assertEqual(g(12), "int 12")
1233 self.assertIs(g.dispatch(int), g_int)
1234 self.assertIs(g.dispatch(object), g.dispatch(str))
1235 # Note: in the assert above this is not g.
1236 # @singledispatch returns the wrapper.
1237
1238 def test_wrapping_attributes(self):
1239 @functools.singledispatch
1240 def g(obj):
1241 "Simple test"
1242 return "Test"
1243 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001244 if sys.flags.optimize < 2:
1245 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001246
1247 @unittest.skipUnless(decimal, 'requires _decimal')
1248 @support.cpython_only
1249 def test_c_classes(self):
1250 @functools.singledispatch
1251 def g(obj):
1252 return "base"
1253 @g.register(decimal.DecimalException)
1254 def _(obj):
1255 return obj.args
1256 subn = decimal.Subnormal("Exponent < Emin")
1257 rnd = decimal.Rounded("Number got rounded")
1258 self.assertEqual(g(subn), ("Exponent < Emin",))
1259 self.assertEqual(g(rnd), ("Number got rounded",))
1260 @g.register(decimal.Subnormal)
1261 def _(obj):
1262 return "Too small to care."
1263 self.assertEqual(g(subn), "Too small to care.")
1264 self.assertEqual(g(rnd), ("Number got rounded",))
1265
1266 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001267 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001268 c = collections
1269 mro = functools._compose_mro
1270 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1271 for haystack in permutations(bases):
1272 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001273 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1274 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001275 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1276 for haystack in permutations(bases):
1277 m = mro(c.ChainMap, haystack)
1278 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1279 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001280
1281 # If there's a generic function with implementations registered for
1282 # both Sized and Container, passing a defaultdict to it results in an
1283 # ambiguous dispatch which will cause a RuntimeError (see
1284 # test_mro_conflicts).
1285 bases = [c.Container, c.Sized, str]
1286 for haystack in permutations(bases):
1287 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1288 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1289 object])
1290
1291 # MutableSequence below is registered directly on D. In other words, it
1292 # preceeds MutableMapping which means single dispatch will always
1293 # choose MutableSequence here.
1294 class D(c.defaultdict):
1295 pass
1296 c.MutableSequence.register(D)
1297 bases = [c.MutableSequence, c.MutableMapping]
1298 for haystack in permutations(bases):
1299 m = mro(D, bases)
1300 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1301 c.defaultdict, dict, c.MutableMapping,
1302 c.Mapping, c.Sized, c.Iterable, c.Container,
1303 object])
1304
1305 # Container and Callable are registered on different base classes and
1306 # a generic function supporting both should always pick the Callable
1307 # implementation if a C instance is passed.
1308 class C(c.defaultdict):
1309 def __call__(self):
1310 pass
1311 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1312 for haystack in permutations(bases):
1313 m = mro(C, haystack)
1314 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1315 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001316
1317 def test_register_abc(self):
1318 c = collections
1319 d = {"a": "b"}
1320 l = [1, 2, 3]
1321 s = {object(), None}
1322 f = frozenset(s)
1323 t = (1, 2, 3)
1324 @functools.singledispatch
1325 def g(obj):
1326 return "base"
1327 self.assertEqual(g(d), "base")
1328 self.assertEqual(g(l), "base")
1329 self.assertEqual(g(s), "base")
1330 self.assertEqual(g(f), "base")
1331 self.assertEqual(g(t), "base")
1332 g.register(c.Sized, lambda obj: "sized")
1333 self.assertEqual(g(d), "sized")
1334 self.assertEqual(g(l), "sized")
1335 self.assertEqual(g(s), "sized")
1336 self.assertEqual(g(f), "sized")
1337 self.assertEqual(g(t), "sized")
1338 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1339 self.assertEqual(g(d), "mutablemapping")
1340 self.assertEqual(g(l), "sized")
1341 self.assertEqual(g(s), "sized")
1342 self.assertEqual(g(f), "sized")
1343 self.assertEqual(g(t), "sized")
1344 g.register(c.ChainMap, lambda obj: "chainmap")
1345 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1346 self.assertEqual(g(l), "sized")
1347 self.assertEqual(g(s), "sized")
1348 self.assertEqual(g(f), "sized")
1349 self.assertEqual(g(t), "sized")
1350 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1351 self.assertEqual(g(d), "mutablemapping")
1352 self.assertEqual(g(l), "mutablesequence")
1353 self.assertEqual(g(s), "sized")
1354 self.assertEqual(g(f), "sized")
1355 self.assertEqual(g(t), "sized")
1356 g.register(c.MutableSet, lambda obj: "mutableset")
1357 self.assertEqual(g(d), "mutablemapping")
1358 self.assertEqual(g(l), "mutablesequence")
1359 self.assertEqual(g(s), "mutableset")
1360 self.assertEqual(g(f), "sized")
1361 self.assertEqual(g(t), "sized")
1362 g.register(c.Mapping, lambda obj: "mapping")
1363 self.assertEqual(g(d), "mutablemapping") # not specific enough
1364 self.assertEqual(g(l), "mutablesequence")
1365 self.assertEqual(g(s), "mutableset")
1366 self.assertEqual(g(f), "sized")
1367 self.assertEqual(g(t), "sized")
1368 g.register(c.Sequence, lambda obj: "sequence")
1369 self.assertEqual(g(d), "mutablemapping")
1370 self.assertEqual(g(l), "mutablesequence")
1371 self.assertEqual(g(s), "mutableset")
1372 self.assertEqual(g(f), "sized")
1373 self.assertEqual(g(t), "sequence")
1374 g.register(c.Set, lambda obj: "set")
1375 self.assertEqual(g(d), "mutablemapping")
1376 self.assertEqual(g(l), "mutablesequence")
1377 self.assertEqual(g(s), "mutableset")
1378 self.assertEqual(g(f), "set")
1379 self.assertEqual(g(t), "sequence")
1380 g.register(dict, lambda obj: "dict")
1381 self.assertEqual(g(d), "dict")
1382 self.assertEqual(g(l), "mutablesequence")
1383 self.assertEqual(g(s), "mutableset")
1384 self.assertEqual(g(f), "set")
1385 self.assertEqual(g(t), "sequence")
1386 g.register(list, lambda obj: "list")
1387 self.assertEqual(g(d), "dict")
1388 self.assertEqual(g(l), "list")
1389 self.assertEqual(g(s), "mutableset")
1390 self.assertEqual(g(f), "set")
1391 self.assertEqual(g(t), "sequence")
1392 g.register(set, lambda obj: "concrete-set")
1393 self.assertEqual(g(d), "dict")
1394 self.assertEqual(g(l), "list")
1395 self.assertEqual(g(s), "concrete-set")
1396 self.assertEqual(g(f), "set")
1397 self.assertEqual(g(t), "sequence")
1398 g.register(frozenset, lambda obj: "frozen-set")
1399 self.assertEqual(g(d), "dict")
1400 self.assertEqual(g(l), "list")
1401 self.assertEqual(g(s), "concrete-set")
1402 self.assertEqual(g(f), "frozen-set")
1403 self.assertEqual(g(t), "sequence")
1404 g.register(tuple, lambda obj: "tuple")
1405 self.assertEqual(g(d), "dict")
1406 self.assertEqual(g(l), "list")
1407 self.assertEqual(g(s), "concrete-set")
1408 self.assertEqual(g(f), "frozen-set")
1409 self.assertEqual(g(t), "tuple")
1410
Łukasz Langa3720c772013-07-01 16:00:38 +02001411 def test_c3_abc(self):
1412 c = collections
1413 mro = functools._c3_mro
1414 class A(object):
1415 pass
1416 class B(A):
1417 def __len__(self):
1418 return 0 # implies Sized
1419 @c.Container.register
1420 class C(object):
1421 pass
1422 class D(object):
1423 pass # unrelated
1424 class X(D, C, B):
1425 def __call__(self):
1426 pass # implies Callable
1427 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1428 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1429 self.assertEqual(mro(X, abcs=abcs), expected)
1430 # unrelated ABCs don't appear in the resulting MRO
1431 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1432 self.assertEqual(mro(X, abcs=many_abcs), expected)
1433
Łukasz Langa6f692512013-06-05 12:20:24 +02001434 def test_mro_conflicts(self):
1435 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001436 @functools.singledispatch
1437 def g(arg):
1438 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001439 class O(c.Sized):
1440 def __len__(self):
1441 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001442 o = O()
1443 self.assertEqual(g(o), "base")
1444 g.register(c.Iterable, lambda arg: "iterable")
1445 g.register(c.Container, lambda arg: "container")
1446 g.register(c.Sized, lambda arg: "sized")
1447 g.register(c.Set, lambda arg: "set")
1448 self.assertEqual(g(o), "sized")
1449 c.Iterable.register(O)
1450 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1451 c.Container.register(O)
1452 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001453 c.Set.register(O)
1454 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1455 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001456 class P:
1457 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001458 p = P()
1459 self.assertEqual(g(p), "base")
1460 c.Iterable.register(P)
1461 self.assertEqual(g(p), "iterable")
1462 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001463 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001464 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001465 self.assertIn(
1466 str(re_one.exception),
1467 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1468 "or <class 'collections.abc.Iterable'>"),
1469 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1470 "or <class 'collections.abc.Container'>")),
1471 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001472 class Q(c.Sized):
1473 def __len__(self):
1474 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001475 q = Q()
1476 self.assertEqual(g(q), "sized")
1477 c.Iterable.register(Q)
1478 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1479 c.Set.register(Q)
1480 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001481 # c.Sized and c.Iterable
1482 @functools.singledispatch
1483 def h(arg):
1484 return "base"
1485 @h.register(c.Sized)
1486 def _(arg):
1487 return "sized"
1488 @h.register(c.Container)
1489 def _(arg):
1490 return "container"
1491 # Even though Sized and Container are explicit bases of MutableMapping,
1492 # this ABC is implicitly registered on defaultdict which makes all of
1493 # MutableMapping's bases implicit as well from defaultdict's
1494 # perspective.
1495 with self.assertRaises(RuntimeError) as re_two:
1496 h(c.defaultdict(lambda: 0))
1497 self.assertIn(
1498 str(re_two.exception),
1499 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1500 "or <class 'collections.abc.Sized'>"),
1501 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1502 "or <class 'collections.abc.Container'>")),
1503 )
1504 class R(c.defaultdict):
1505 pass
1506 c.MutableSequence.register(R)
1507 @functools.singledispatch
1508 def i(arg):
1509 return "base"
1510 @i.register(c.MutableMapping)
1511 def _(arg):
1512 return "mapping"
1513 @i.register(c.MutableSequence)
1514 def _(arg):
1515 return "sequence"
1516 r = R()
1517 self.assertEqual(i(r), "sequence")
1518 class S:
1519 pass
1520 class T(S, c.Sized):
1521 def __len__(self):
1522 return 0
1523 t = T()
1524 self.assertEqual(h(t), "sized")
1525 c.Container.register(T)
1526 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1527 class U:
1528 def __len__(self):
1529 return 0
1530 u = U()
1531 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1532 # from the existence of __len__()
1533 c.Container.register(U)
1534 # There is no preference for registered versus inferred ABCs.
1535 with self.assertRaises(RuntimeError) as re_three:
1536 h(u)
1537 self.assertIn(
1538 str(re_three.exception),
1539 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1540 "or <class 'collections.abc.Sized'>"),
1541 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1542 "or <class 'collections.abc.Container'>")),
1543 )
1544 class V(c.Sized, S):
1545 def __len__(self):
1546 return 0
1547 @functools.singledispatch
1548 def j(arg):
1549 return "base"
1550 @j.register(S)
1551 def _(arg):
1552 return "s"
1553 @j.register(c.Container)
1554 def _(arg):
1555 return "container"
1556 v = V()
1557 self.assertEqual(j(v), "s")
1558 c.Container.register(V)
1559 self.assertEqual(j(v), "container") # because it ends up right after
1560 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001561
1562 def test_cache_invalidation(self):
1563 from collections import UserDict
1564 class TracingDict(UserDict):
1565 def __init__(self, *args, **kwargs):
1566 super(TracingDict, self).__init__(*args, **kwargs)
1567 self.set_ops = []
1568 self.get_ops = []
1569 def __getitem__(self, key):
1570 result = self.data[key]
1571 self.get_ops.append(key)
1572 return result
1573 def __setitem__(self, key, value):
1574 self.set_ops.append(key)
1575 self.data[key] = value
1576 def clear(self):
1577 self.data.clear()
1578 _orig_wkd = functools.WeakKeyDictionary
1579 td = TracingDict()
1580 functools.WeakKeyDictionary = lambda: td
1581 c = collections
1582 @functools.singledispatch
1583 def g(arg):
1584 return "base"
1585 d = {}
1586 l = []
1587 self.assertEqual(len(td), 0)
1588 self.assertEqual(g(d), "base")
1589 self.assertEqual(len(td), 1)
1590 self.assertEqual(td.get_ops, [])
1591 self.assertEqual(td.set_ops, [dict])
1592 self.assertEqual(td.data[dict], g.registry[object])
1593 self.assertEqual(g(l), "base")
1594 self.assertEqual(len(td), 2)
1595 self.assertEqual(td.get_ops, [])
1596 self.assertEqual(td.set_ops, [dict, list])
1597 self.assertEqual(td.data[dict], g.registry[object])
1598 self.assertEqual(td.data[list], g.registry[object])
1599 self.assertEqual(td.data[dict], td.data[list])
1600 self.assertEqual(g(l), "base")
1601 self.assertEqual(g(d), "base")
1602 self.assertEqual(td.get_ops, [list, dict])
1603 self.assertEqual(td.set_ops, [dict, list])
1604 g.register(list, lambda arg: "list")
1605 self.assertEqual(td.get_ops, [list, dict])
1606 self.assertEqual(len(td), 0)
1607 self.assertEqual(g(d), "base")
1608 self.assertEqual(len(td), 1)
1609 self.assertEqual(td.get_ops, [list, dict])
1610 self.assertEqual(td.set_ops, [dict, list, dict])
1611 self.assertEqual(td.data[dict],
1612 functools._find_impl(dict, g.registry))
1613 self.assertEqual(g(l), "list")
1614 self.assertEqual(len(td), 2)
1615 self.assertEqual(td.get_ops, [list, dict])
1616 self.assertEqual(td.set_ops, [dict, list, dict, list])
1617 self.assertEqual(td.data[list],
1618 functools._find_impl(list, g.registry))
1619 class X:
1620 pass
1621 c.MutableMapping.register(X) # Will not invalidate the cache,
1622 # not using ABCs yet.
1623 self.assertEqual(g(d), "base")
1624 self.assertEqual(g(l), "list")
1625 self.assertEqual(td.get_ops, [list, dict, dict, list])
1626 self.assertEqual(td.set_ops, [dict, list, dict, list])
1627 g.register(c.Sized, lambda arg: "sized")
1628 self.assertEqual(len(td), 0)
1629 self.assertEqual(g(d), "sized")
1630 self.assertEqual(len(td), 1)
1631 self.assertEqual(td.get_ops, [list, dict, dict, list])
1632 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1633 self.assertEqual(g(l), "list")
1634 self.assertEqual(len(td), 2)
1635 self.assertEqual(td.get_ops, [list, dict, dict, list])
1636 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1637 self.assertEqual(g(l), "list")
1638 self.assertEqual(g(d), "sized")
1639 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1640 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1641 g.dispatch(list)
1642 g.dispatch(dict)
1643 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1644 list, dict])
1645 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1646 c.MutableSet.register(X) # Will invalidate the cache.
1647 self.assertEqual(len(td), 2) # Stale cache.
1648 self.assertEqual(g(l), "list")
1649 self.assertEqual(len(td), 1)
1650 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1651 self.assertEqual(len(td), 0)
1652 self.assertEqual(g(d), "mutablemapping")
1653 self.assertEqual(len(td), 1)
1654 self.assertEqual(g(l), "list")
1655 self.assertEqual(len(td), 2)
1656 g.register(dict, lambda arg: "dict")
1657 self.assertEqual(g(d), "dict")
1658 self.assertEqual(g(l), "list")
1659 g._clear_cache()
1660 self.assertEqual(len(td), 0)
1661 functools.WeakKeyDictionary = _orig_wkd
1662
1663
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()