blob: 63fe83e5dbd6028ecb4d68bbf74a4f43deac14e9 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080012import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020013from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100014import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030015try:
16 import threading
17except ImportError:
18 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000019
Antoine Pitroub5b37142012-11-13 21:35:40 +010020import functools
21
Antoine Pitroub5b37142012-11-13 21:35:40 +010022py_functools = support.import_fresh_module('functools', blocked=['_functools'])
23c_functools = support.import_fresh_module('functools', fresh=['_functools'])
24
Łukasz Langa6f692512013-06-05 12:20:24 +020025decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
26
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily bind sys.modules[name] to *replacement*.

    The entry that was registered under *name* on entry is put back on
    exit, whether or not the managed block raised.  *name* must already
    be present in sys.modules, otherwise KeyError is raised on entry.
    """
    registry = sys.modules
    saved = registry[name]
    registry[name] = replacement
    try:
        yield
    finally:
        # Always restore, even when the body of the with-block failed.
        registry[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020035
def capture(*args, **kw):
    """Return the call's arguments unchanged as an ``(args, kw)`` pair.

    ``args`` is the tuple of positional arguments and ``kw`` the dict of
    keyword arguments the function was called with.
    """
    received = (args, kw)
    return received
39
Łukasz Langa6f692512013-06-05 12:20:24 +020040
def signature(part):
    """Summarise a partial object as a comparable 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)``, read from *part*
    in that order.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000044
class MyTuple(tuple):
    """A tuple subclass that adds nothing to ``tuple``."""
47
class BadTuple(tuple):
    """A tuple subclass whose ``+`` operator returns a list, not a tuple."""

    def __add__(self, other):
        # Concatenate the items of both operands into a plain list.
        return [*self, *other]
51
class MyDict(dict):
    """A dict subclass that adds nothing to ``dict``."""
54
Łukasz Langa6f692512013-06-05 12:20:24 +020055
class TestPartial:
    """Implementation-independent tests for ``partial`` objects.

    This mixin is not a TestCase by itself.  Concrete test classes combine
    it with ``unittest.TestCase`` and provide two attributes:

    * ``partial`` -- the partial implementation under test.
    * ``AllowPickle`` -- a callable returning a context manager that the
      pickling tests run inside.
    """

    def _expected_repr_name(self):
        # Name that repr() is expected to show for self.partial.
        # c_functools is None when the C accelerator could not be imported,
        # so it must not be dereferenced unconditionally.
        stock = [py_functools.partial]
        if c_functools:
            stock.append(c_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the offending value.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Without reference counting the referent is only reclaimed by an
        # explicit collection, so force one before probing the proxy.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword ordering is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then resets the state so the self-reference does not linger.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this check was already disabled in the original; the
        # reason is not recorded here -- confirm before re-enabling.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # State given as tuple/dict subclasses must end up stored and passed
        # on as exact tuple/dict instances.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial that is its own func cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through args must survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-reference through keywords must survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000370
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    # Runs the shared TestPartial tests against c_functools.partial.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager: nothing is changed on entry or exit and
        # exceptions raised in the managed block are not suppressed.
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200396
class TestPartialPy(TestPartial, unittest.TestCase):
    # Runs the shared TestPartial tests against py_functools.partial.
    partial = py_functools.partial

    class AllowPickle:
        # Context manager that installs py_functools as
        # sys.modules["functools"] for the duration of the block
        # (presumably so pickle resolves the pure-Python partial -- confirm).
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._cm
            return swap.__enter__()

        def __exit__(self, exc_type, exc_value, exc_tb):
            swap = self._cm
            return swap.__exit__(exc_type, exc_value, exc_tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200407
# Only defined when the C implementation was importable.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100411
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200414
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-runs the TestPartialC tests with a subclass of the C partial type.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the
    # inherited test is disabled by replacing it with None
    test_nested_optimization = None
422
class TestPartialPySubclass(TestPartialPy):
    # Re-runs the TestPartialPy tests with a subclass of the Python partial.
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200425
class TestPartialMethod(unittest.TestCase):
    # Tests for functools.partialmethod, driven by the fixture class A below.

    class A:
        # Every attribute wraps capture(), so a call returns exactly the
        # (args, kwargs) that reached the underlying callable.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod layered on another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered on a plain partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod layered on other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        a = self.a

        self.assertEqual(a.nothing(), ((a,), {}))
        self.assertEqual(a.nothing(5), ((a, 5), {}))
        self.assertEqual(a.nothing(c=6), ((a,), {'c': 6}))
        self.assertEqual(a.nothing(5, c=6), ((a, 5), {'c': 6}))

        self.assertEqual(a.positional(), ((a, 1), {}))
        self.assertEqual(a.positional(5), ((a, 1, 5), {}))
        self.assertEqual(a.positional(c=6), ((a, 1), {'c': 6}))
        self.assertEqual(a.positional(5, c=6), ((a, 1, 5), {'c': 6}))

        self.assertEqual(a.keywords(), ((a,), {'a': 2}))
        self.assertEqual(a.keywords(5), ((a, 5), {'a': 2}))
        self.assertEqual(a.keywords(c=6), ((a,), {'a': 2, 'c': 6}))
        self.assertEqual(a.keywords(5, c=6), ((a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(a.both(), ((a, 3), {'b': 4}))
        self.assertEqual(a.both(5), ((a, 3, 5), {'b': 4}))
        self.assertEqual(a.both(c=6), ((a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(a.both(5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

        # Access through the class requires the instance to be passed in.
        self.assertEqual(self.A.both(a, 5, c=6), ((a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        a = self.a

        self.assertEqual(a.nested(), ((a, 1, 5), {}))
        self.assertEqual(a.nested(6), ((a, 1, 5, 6), {}))
        self.assertEqual(a.nested(d=7), ((a, 1, 5), {'d': 7}))
        self.assertEqual(a.nested(6, d=7), ((a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(a, 6, d=7), ((a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        a = self.a

        self.assertEqual(a.over_partial(), ((a, 7), {'c': 6}))
        self.assertEqual(a.over_partial(5), ((a, 7, 5), {'c': 6}))
        self.assertEqual(a.over_partial(d=8), ((a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(a.over_partial(5, d=8), ((a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(a, 5, d=8),
                         ((a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        instance = self.a
        self.assertIs(instance.both.__self__, instance)
        self.assertIs(instance.nested.__self__, instance)
        self.assertIs(instance.over_partial.__self__, instance)
        self.assertIs(instance.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        klass = self.A
        self.assertFalse(hasattr(klass.both, "__self__"))
        self.assertFalse(hasattr(klass.nested, "__self__"))
        self.assertFalse(hasattr(klass.over_partial, "__self__"))
        self.assertFalse(hasattr(klass.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # The static/class variants behave the same via class and instance.
        for obj in (self.A, self.a):
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        a = self.a
        self.assertEqual(a.keywords(a=3), ((a,), {'a': 3}))
        self.assertEqual(self.A.keywords(a, a=3), ((a,), {'a': 3}))

    def test_invalid_args(self):
        # The TypeError is expected while the class body executes.
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # vars() is used to reach the raw descriptor without binding it.
        descriptor = vars(self.A)['both']
        expected = 'functools.partialmethod({}, 3, b=4)'.format(capture)
        self.assertEqual(repr(descriptor), expected)

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        concrete = [self.A.static, self.A.cls, self.A.over_partial,
                    self.A.nested, self.A.both]
        for func in concrete:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
538
539
class TestUpdateWrapper(unittest.TestCase):
    # Tests for functools.update_wrapper.
    #
    # The inner functions are deliberately named ``f`` and ``wrapper`` and
    # only ``f`` carries a docstring: the assertions depend on both facts.

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Every assigned attribute must be the very same object on both.
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        # Every entry of an updated mapping must have been carried over.
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                if (attr, key) == ("__dict__", "__wrapped__"):
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(source[key], target[key])
        # __wrapped__ always refers to the wrapped callable itself.
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using the default arguments.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated sequences: nothing is copied across.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ... and an attribute of the wrong kind is rejected the same way
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000652
Łukasz Langa6f692512013-06-05 12:20:24 +0200653
class TestWraps(TestUpdateWrapper):
    # Tests for the functools.wraps() decorator.  check_wrapper() is not
    # defined here; it is inherited from the base class.

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its default
        # arguments.  The inner names matter: callers assert on __name__.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty *assigned* and *updated* tuples: nothing is copied over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass
        wrapper = functools.wraps(f, (), ())(wrapper)

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        assign, update = ('attr',), ('dict_attr',)

        def wrapper():
            pass
        # The wrapper is given an empty dict_attr before wraps() runs.
        wrapper.dict_attr = {}
        wrapper = functools.wraps(f, assign, update)(wrapper)

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
714
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # Tests for reduce() as exported by c_functools.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence, via __getitem__, of the first *limit* squares.
            # Values are computed on demand and remembered in self.sofar.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n * n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.func

        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        with self.assertRaises(TypeError):
            reduce()
        with self.assertRaises(TypeError):
            reduce(42, 42)
        with self.assertRaises(TypeError):
            reduce(42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            with self.assertRaises(TypeError):
                reduce(add, empty)
        with self.assertRaises(TypeError):
            reduce(add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Yields 0 .. n-1 through __getitem__ only.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict walks its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
796
Łukasz Langa6f692512013-06-05 12:20:24 +0200797
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200798class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700799
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000800 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700801 def cmp1(x, y):
802 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100803 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700804 self.assertEqual(key(3), key(3))
805 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100806 self.assertGreaterEqual(key(3), key(3))
807
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700808 def cmp2(x, y):
809 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100810 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700811 self.assertEqual(key(4.0), key('4'))
812 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100813 self.assertLessEqual(key(2), key('35'))
814 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700815
816 def test_cmp_to_key_arguments(self):
817 def cmp1(x, y):
818 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100819 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700820 self.assertEqual(key(obj=3), key(obj=3))
821 self.assertGreater(key(obj=3), key(obj=1))
822 with self.assertRaises((TypeError, AttributeError)):
823 key(3) > 1 # rhs is not a K object
824 with self.assertRaises((TypeError, AttributeError)):
825 1 < key(3) # lhs is not a K object
826 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100827 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700828 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200829 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100830 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700831 with self.assertRaises(TypeError):
832 key() # too few args
833 with self.assertRaises(TypeError):
834 key(None, None) # too many args
835
836 def test_bad_cmp(self):
837 def cmp1(x, y):
838 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100839 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700840 with self.assertRaises(ZeroDivisionError):
841 key(3) > key(1)
842
843 class BadCmp:
844 def __lt__(self, other):
845 raise ZeroDivisionError
846 def cmp1(x, y):
847 return BadCmp()
848 with self.assertRaises(ZeroDivisionError):
849 key(3) > key(1)
850
851 def test_obj_field(self):
852 def cmp1(x, y):
853 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100854 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700855 self.assertEqual(key(50).obj, 50)
856
857 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000858 def mycmp(x, y):
859 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100860 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000861 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000862
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 def test_sort_int_str(self):
864 def mycmp(x, y):
865 x, y = int(x), int(y)
866 return (x > y) - (x < y)
867 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 self.assertEqual([int(value) for value in values],
870 [0, 1, 1, 2, 3, 4, 5, 7, 10])
871
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000872 def test_hash(self):
873 def mycmp(x, y):
874 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000876 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700878 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000879
Łukasz Langa6f692512013-06-05 12:20:24 +0200880
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey tests against c_functools.cmp_to_key.  The
    # attribute is only bound when c_functools is truthy; otherwise the
    # whole class is skipped by the decorator above.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100885
Łukasz Langa6f692512013-06-05 12:20:24 +0200886
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey tests against py_functools.cmp_to_key.
    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
889
Łukasz Langa6f692512013-06-05 12:20:24 +0200890
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000891class TestTotalOrdering(unittest.TestCase):
892
893 def test_total_ordering_lt(self):
894 @functools.total_ordering
895 class A:
896 def __init__(self, value):
897 self.value = value
898 def __lt__(self, other):
899 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000900 def __eq__(self, other):
901 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000902 self.assertTrue(A(1) < A(2))
903 self.assertTrue(A(2) > A(1))
904 self.assertTrue(A(1) <= A(2))
905 self.assertTrue(A(2) >= A(1))
906 self.assertTrue(A(2) <= A(2))
907 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000908 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000909
910 def test_total_ordering_le(self):
911 @functools.total_ordering
912 class A:
913 def __init__(self, value):
914 self.value = value
915 def __le__(self, other):
916 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000917 def __eq__(self, other):
918 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000919 self.assertTrue(A(1) < A(2))
920 self.assertTrue(A(2) > A(1))
921 self.assertTrue(A(1) <= A(2))
922 self.assertTrue(A(2) >= A(1))
923 self.assertTrue(A(2) <= A(2))
924 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000925 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926
927 def test_total_ordering_gt(self):
928 @functools.total_ordering
929 class A:
930 def __init__(self, value):
931 self.value = value
932 def __gt__(self, other):
933 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000934 def __eq__(self, other):
935 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000936 self.assertTrue(A(1) < A(2))
937 self.assertTrue(A(2) > A(1))
938 self.assertTrue(A(1) <= A(2))
939 self.assertTrue(A(2) >= A(1))
940 self.assertTrue(A(2) <= A(2))
941 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000942 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000943
944 def test_total_ordering_ge(self):
945 @functools.total_ordering
946 class A:
947 def __init__(self, value):
948 self.value = value
949 def __ge__(self, other):
950 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000951 def __eq__(self, other):
952 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000953 self.assertTrue(A(1) < A(2))
954 self.assertTrue(A(2) > A(1))
955 self.assertTrue(A(1) <= A(2))
956 self.assertTrue(A(2) >= A(1))
957 self.assertTrue(A(2) <= A(2))
958 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000959 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000960
961 def test_total_ordering_no_overwrite(self):
962 # new methods should not overwrite existing
963 @functools.total_ordering
964 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000965 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000966 self.assertTrue(A(1) < A(2))
967 self.assertTrue(A(2) > A(1))
968 self.assertTrue(A(1) <= A(2))
969 self.assertTrue(A(2) >= A(1))
970 self.assertTrue(A(2) <= A(2))
971 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000972
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000973 def test_no_operations_defined(self):
974 with self.assertRaises(ValueError):
975 @functools.total_ordering
976 class A:
977 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000978
Nick Coghlanf05d9812013-10-02 00:02:03 +1000979 def test_type_error_when_not_implemented(self):
980 # bug 10042; ensure stack overflow does not occur
981 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000982 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000983 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000984 def __init__(self, value):
985 self.value = value
986 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000987 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000988 return self.value == other.value
989 return False
990 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000991 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000992 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000993 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000994
Nick Coghlanf05d9812013-10-02 00:02:03 +1000995 @functools.total_ordering
996 class ImplementsGreaterThan:
997 def __init__(self, value):
998 self.value = value
999 def __eq__(self, other):
1000 if isinstance(other, ImplementsGreaterThan):
1001 return self.value == other.value
1002 return False
1003 def __gt__(self, other):
1004 if isinstance(other, ImplementsGreaterThan):
1005 return self.value > other.value
1006 return NotImplemented
1007
1008 @functools.total_ordering
1009 class ImplementsLessThanEqualTo:
1010 def __init__(self, value):
1011 self.value = value
1012 def __eq__(self, other):
1013 if isinstance(other, ImplementsLessThanEqualTo):
1014 return self.value == other.value
1015 return False
1016 def __le__(self, other):
1017 if isinstance(other, ImplementsLessThanEqualTo):
1018 return self.value <= other.value
1019 return NotImplemented
1020
1021 @functools.total_ordering
1022 class ImplementsGreaterThanEqualTo:
1023 def __init__(self, value):
1024 self.value = value
1025 def __eq__(self, other):
1026 if isinstance(other, ImplementsGreaterThanEqualTo):
1027 return self.value == other.value
1028 return False
1029 def __ge__(self, other):
1030 if isinstance(other, ImplementsGreaterThanEqualTo):
1031 return self.value >= other.value
1032 return NotImplemented
1033
1034 @functools.total_ordering
1035 class ComparatorNotImplemented:
1036 def __init__(self, value):
1037 self.value = value
1038 def __eq__(self, other):
1039 if isinstance(other, ComparatorNotImplemented):
1040 return self.value == other.value
1041 return False
1042 def __lt__(self, other):
1043 return NotImplemented
1044
1045 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1046 ImplementsLessThan(-1) < 1
1047
1048 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1049 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1050
1051 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1052 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1053
1054 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1055 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1056
1057 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1058 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1059
1060 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1061 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1062
1063 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1064 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1065
1066 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1067 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1068
1069 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1070 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1071
1072 with self.subTest("GE when equal"):
1073 a = ComparatorNotImplemented(8)
1074 b = ComparatorNotImplemented(8)
1075 self.assertEqual(a, b)
1076 with self.assertRaises(TypeError):
1077 a >= b
1078
1079 with self.subTest("LE when equal"):
1080 a = ComparatorNotImplemented(9)
1081 b = ComparatorNotImplemented(9)
1082 self.assertEqual(a, b)
1083 with self.assertRaises(TypeError):
1084 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001085
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001086 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001087 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001088 for name in '__lt__', '__gt__', '__le__', '__ge__':
1089 with self.subTest(method=name, proto=proto):
1090 method = getattr(Orderable_LT, name)
1091 method_copy = pickle.loads(pickle.dumps(method, proto))
1092 self.assertIs(method_copy, method)
1093
@functools.total_ordering
class Orderable_LT:
    # Minimal orderable type: __eq__ and __lt__ are written by hand and
    # total_ordering supplies the rest.  Used by
    # TestTotalOrdering.test_pickle.

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1102
1103
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001104class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001105
1106 def test_lru(self):
1107 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001108 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001109 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001110 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001111 self.assertEqual(maxsize, 20)
1112 self.assertEqual(currsize, 0)
1113 self.assertEqual(hits, 0)
1114 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001115
1116 domain = range(5)
1117 for i in range(1000):
1118 x, y = choice(domain), choice(domain)
1119 actual = f(x, y)
1120 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001121 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001122 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001123 self.assertTrue(hits > misses)
1124 self.assertEqual(hits + misses, 1000)
1125 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001126
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001127 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001128 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001129 self.assertEqual(hits, 0)
1130 self.assertEqual(misses, 0)
1131 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001132 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001133 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001134 self.assertEqual(hits, 0)
1135 self.assertEqual(misses, 1)
1136 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001137
Nick Coghlan98876832010-08-17 06:17:18 +00001138 # Test bypassing the cache
1139 self.assertIs(f.__wrapped__, orig)
1140 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001141 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001142 self.assertEqual(hits, 0)
1143 self.assertEqual(misses, 1)
1144 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001145
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001146 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001147 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001148 def f():
1149 nonlocal f_cnt
1150 f_cnt += 1
1151 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001152 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001153 f_cnt = 0
1154 for i in range(5):
1155 self.assertEqual(f(), 20)
1156 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001157 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001158 self.assertEqual(hits, 0)
1159 self.assertEqual(misses, 5)
1160 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001161
1162 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001163 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001164 def f():
1165 nonlocal f_cnt
1166 f_cnt += 1
1167 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001168 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001169 f_cnt = 0
1170 for i in range(5):
1171 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001172 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001173 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001174 self.assertEqual(hits, 4)
1175 self.assertEqual(misses, 1)
1176 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001177
Raymond Hettingerf3098282010-08-15 03:30:45 +00001178 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001179 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001180 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001181 nonlocal f_cnt
1182 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001183 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001184 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001185 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001186 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1187 # * * * *
1188 self.assertEqual(f(x), x*10)
1189 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001190 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001191 self.assertEqual(hits, 12)
1192 self.assertEqual(misses, 4)
1193 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001194
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001195 def test_lru_hash_only_once(self):
1196 # To protect against weird reentrancy bugs and to improve
1197 # efficiency when faced with slow __hash__ methods, the
1198 # LRU cache guarantees that it will only call __hash__
1199 # only once per use as an argument to the cached function.
1200
1201 @self.module.lru_cache(maxsize=1)
1202 def f(x, y):
1203 return x * 3 + y
1204
1205 # Simulate the integer 5
1206 mock_int = unittest.mock.Mock()
1207 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1208 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1209
1210 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001211 self.assertEqual(f(mock_int, 1), 16)
1212 self.assertEqual(mock_int.__hash__.call_count, 1)
1213 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001214
1215 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001216 self.assertEqual(f(mock_int, 1), 16)
1217 self.assertEqual(mock_int.__hash__.call_count, 2)
1218 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001219
1220 # Cache eviction: No use as an argument gives no additonal call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001221 self.assertEqual(f(6, 2), 20)
1222 self.assertEqual(mock_int.__hash__.call_count, 2)
1223 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001224
1225 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001226 self.assertEqual(f(mock_int, 1), 16)
1227 self.assertEqual(mock_int.__hash__.call_count, 3)
1228 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001229
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001230 def test_lru_reentrancy_with_len(self):
1231 # Test to make sure the LRU cache code isn't thrown-off by
1232 # caching the built-in len() function. Since len() can be
1233 # cached, we shouldn't use it inside the lru code itself.
1234 old_len = builtins.len
1235 try:
1236 builtins.len = self.module.lru_cache(4)(len)
1237 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1238 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1239 finally:
1240 builtins.len = old_len
1241
Raymond Hettinger605a4472017-01-09 07:50:19 -08001242 def test_lru_star_arg_handling(self):
1243 # Test regression that arose in ea064ff3c10f
1244 @functools.lru_cache()
1245 def f(*args):
1246 return args
1247
1248 self.assertEqual(f(1, 2), (1, 2))
1249 self.assertEqual(f((1, 2)), ((1, 2),))
1250
Yury Selivanov46a02db2016-11-09 18:55:45 -05001251 def test_lru_type_error(self):
1252 # Regression test for issue #28653.
1253 # lru_cache was leaking when one of the arguments
1254 # wasn't cacheable.
1255
1256 @functools.lru_cache(maxsize=None)
1257 def infinite_cache(o):
1258 pass
1259
1260 @functools.lru_cache(maxsize=10)
1261 def limited_cache(o):
1262 pass
1263
1264 with self.assertRaises(TypeError):
1265 infinite_cache([])
1266
1267 with self.assertRaises(TypeError):
1268 limited_cache([])
1269
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001270 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001271 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001272 def fib(n):
1273 if n < 2:
1274 return n
1275 return fib(n-1) + fib(n-2)
1276 self.assertEqual([fib(n) for n in range(16)],
1277 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1278 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001279 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001280 fib.cache_clear()
1281 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001282 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1283
1284 def test_lru_with_maxsize_negative(self):
1285 @self.module.lru_cache(maxsize=-10)
1286 def eq(n):
1287 return n
1288 for i in (0, 1):
1289 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1290 self.assertEqual(eq.cache_info(),
1291 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001292
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001293 def test_lru_with_exceptions(self):
1294 # Verify that user_function exceptions get passed through without
1295 # creating a hard-to-read chained exception.
1296 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001297 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001298 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001299 def func(i):
1300 return 'abc'[i]
1301 self.assertEqual(func(0), 'a')
1302 with self.assertRaises(IndexError) as cm:
1303 func(15)
1304 self.assertIsNone(cm.exception.__context__)
1305 # Verify that the previous exception did not result in a cached entry
1306 with self.assertRaises(IndexError):
1307 func(15)
1308
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001309 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001310 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001311 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001312 def square(x):
1313 return x * x
1314 self.assertEqual(square(3), 9)
1315 self.assertEqual(type(square(3)), type(9))
1316 self.assertEqual(square(3.0), 9.0)
1317 self.assertEqual(type(square(3.0)), type(9.0))
1318 self.assertEqual(square(x=3), 9)
1319 self.assertEqual(type(square(x=3)), type(9))
1320 self.assertEqual(square(x=3.0), 9.0)
1321 self.assertEqual(type(square(x=3.0)), type(9.0))
1322 self.assertEqual(square.cache_info().hits, 4)
1323 self.assertEqual(square.cache_info().misses, 4)
1324
Antoine Pitroub5b37142012-11-13 21:35:40 +01001325 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001326 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001327 def fib(n):
1328 if n < 2:
1329 return n
1330 return fib(n=n-1) + fib(n=n-2)
1331 self.assertEqual(
1332 [fib(n=number) for number in range(16)],
1333 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1334 )
1335 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001336 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001337 fib.cache_clear()
1338 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001339 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001340
1341 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001342 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001343 def fib(n):
1344 if n < 2:
1345 return n
1346 return fib(n=n-1) + fib(n=n-2)
1347 self.assertEqual([fib(n=number) for number in range(16)],
1348 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1349 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001350 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001351 fib.cache_clear()
1352 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001353 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1354
def test_kwargs_order(self):
    # PEP 468: Preserving Keyword Argument Order.
    # The same keywords given in a different order are separate calls:
    # each one returns its own ordering and each one is a cache miss.
    @self.module.lru_cache(maxsize=10)
    def f(**kwargs):
        return list(kwargs.items())

    a_then_b = f(a=1, b=2)
    self.assertEqual(a_then_b, [('a', 1), ('b', 2)])
    b_then_a = f(b=2, a=1)
    self.assertEqual(b_then_a, [('b', 2), ('a', 1)])

    expected_info = self.module._CacheInfo(
        hits=0, misses=2, maxsize=10, currsize=2)
    self.assertEqual(f.cache_info(), expected_info)
1364
def test_lru_cache_decoration(self):
    # The cached wrapper must expose the same value as the wrapped
    # function for every attribute named in WRAPPER_ASSIGNMENTS.
    def f(zomg: 'zomg_annotation'):
        """f doc string"""
        return 42

    wrapped = self.module.lru_cache()(f)
    for name in self.module.WRAPPER_ASSIGNMENTS:
        self.assertEqual(getattr(wrapped, name), getattr(f, name))
1372
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    # Exercise a single cache from several threads: first with threads
    # that only fill it, then with the same fillers racing against a
    # thread that keeps clearing it.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n*m)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # All workers block on this event so that they start together.
    go = threading.Event()

    def fill(k):
        go.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def wipe():
        go.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    def make_fillers():
        return [threading.Thread(target=fill, args=[k]) for k in range(n)]

    saved_interval = sys.getswitchinterval()
    # A very small switch interval is requested for the duration of the
    # test (presumably to provoke frequent thread switches).
    support.setswitchinterval(1e-6)
    try:
        # Phase 1: n threads that only fill the cache.
        workers = make_fillers()
        with support.start_threads(workers):
            go.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: the pure-Python implementation is only held to upper
            # bounds here; why its counts can fall short is unexplained.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: n filling threads plus one thread clearing the cache.
        workers = [threading.Thread(target=wipe)]
        workers += make_fillers()
        go.clear()
        with support.start_threads(workers):
            go.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001421
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments.
    # n worker threads and the main thread meet at three barriers, so in
    # every round all workers are inside f() with the same argument
    # before any of them returns.
    n, m = 5, 7
    parties = n + 1
    start = threading.Barrier(parties)
    pause = threading.Barrier(parties)
    stop = threading.Barrier(parties)

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for arg in range(m):
            start.wait(10)
            self.assertEqual(f(arg), 3 * arg)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(n)]
    with support.start_threads(threads):
        for rounds_done in range(1, m + 1):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # Each round is expected to add n misses (no hits) but only
            # one new cache entry.
            self.assertEqual(f.cache_info(),
                             (0, rounds_done*n, m*n, rounds_done))
1449
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded3(self):
    # Concurrent calls, several with a repeated argument, against a
    # cache holding only two entries: every thread must still receive
    # the correct result.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(i, x):
        with self.subTest(thread=i):
            self.assertEqual(f(x), 3 * x, i)

    arguments = [1, 2, 2, 3, 2]
    threads = [threading.Thread(target=check, args=(index, value))
               for index, value in enumerate(arguments)]
    # Entering the context starts the threads; leaving it waits for them.
    with support.start_threads(threads):
        pass
1463
def test_need_for_rlock(self):
    # This will deadlock on an LRU cache that uses a regular lock

    @self.module.lru_cache(maxsize=10)
    def cached_identity(x):
        'Used to demonstrate a reentrant lru_cache call within a single thread'
        return x

    class DoubleEq:
        'Demonstrate a reentrant lru_cache call within a single thread'
        def __init__(self, x):
            self.x = x

        def __hash__(self):
            return self.x

        def __eq__(self, other):
            # Comparing the key with x == 2 calls back into the cached
            # function while a cache operation is already in progress.
            if self.x == 2:
                cached_identity(DoubleEq(1))
            return self.x == other.x

    # Load the cache with both keys.
    cached_identity(DoubleEq(1))
    cached_identity(DoubleEq(2))
    # The lookup below triggers a re-entrant __eq__ call and must still
    # produce the correct return value.
    result = cached_identity(DoubleEq(2))
    self.assertEqual(result, DoubleEq(2))
1487
def test_early_detection_of_bad_call(self):
    # Issue #22184: applying lru_cache directly to a function, without
    # calling it first, hands the function over as *maxsize*.  That
    # misuse must be rejected with TypeError at decoration time.
    #
    # The decorator is taken from self.module (not the globally imported
    # functools) so that each concrete subclass checks the implementation
    # it is meant to test, consistent with the other tests of this class.
    with self.assertRaises(TypeError):
        @self.module.lru_cache
        def f():
            pass
1494
def test_lru_method(self):
    # A cache applied to a method belongs to the function, so it is
    # shared by every instance: statistics seen through any instance
    # equal those seen through the class.
    class X(int):
        f_cnt = 0
        @self.module.lru_cache(2)
        def f(self, x):
            self.f_cnt += 1
            return x*10+self

    a, b, c = X(5), X(5), X(7)
    self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

    # Each row: instance to call, arguments, value the instance adds to
    # x*10, expected (a, b, c) call counters afterwards, expected
    # cumulative cache_info afterwards.
    rounds = (
        (a, (1, 2, 2, 3, 1, 1, 1, 2, 3, 3), 5, (6, 0, 0), (4, 6, 2, 2)),
        (b, (1, 2, 1, 1, 1, 1, 3, 2, 2, 2), 5, (6, 4, 0), (10, 10, 2, 2)),
        (c, (2, 1, 1, 1, 1, 2, 1, 3, 2, 1), 7, (6, 4, 5), (15, 15, 2, 2)),
    )
    for instance, arguments, offset, counters, info in rounds:
        for x in arguments:
            self.assertEqual(instance.f(x), x*10 + offset)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), counters)
        self.assertEqual(X.f.cache_info(), info)

    for instance in (a, b, c):
        self.assertEqual(instance.f.cache_info(), X.f.cache_info())
1525
def test_pickle(self):
    # A cached function, method or static method must unpickle to the
    # very same object, whatever the pickle protocol.
    owner = self.__class__
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    for f in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=f):
                restored = pickle.loads(pickle.dumps(f, proto))
                self.assertIs(restored, f)
1533
def test_copy(self):
    # copy.copy() of a cached callable must return the object itself.
    # Besides the class-level fixtures, a cache wrapped around a
    # partial object is checked as well.
    owner = self.__class__

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for f in candidates:
        with self.subTest(func=f):
            duplicate = copy.copy(f)
            self.assertIs(duplicate, f)
1545
def test_deepcopy(self):
    # copy.deepcopy() of a cached callable must return the object
    # itself.  Besides the class-level fixtures, a cache wrapped around
    # a partial object is checked as well.
    owner = self.__class__

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for f in candidates:
        with self.subTest(func=f):
            duplicate = copy.deepcopy(f)
            self.assertIs(duplicate, f)
1557
1558
# Module-level fixture cached with the py_functools implementation; the
# LRU test classes reference it through their ``cached_func`` attribute.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return (3 * x) + y
1562
# Module-level fixture cached with the c_functools implementation; the
# LRU test classes reference it through their ``cached_func`` attribute.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return (3 * x) + y
1566
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001567
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against py_functools (functools
    # imported with the _functools accelerator blocked).
    module = py_functools

    # Kept inside a one-element tuple; the tests read it as
    # ``cached_func[0]``.
    cached_func = (py_cached_func,)

    # Fixtures for the pickle/copy tests: a cached method ...
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # ... and a cached static method.
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1580
1581
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against c_functools (functools
    # imported together with a fresh _functools accelerator).
    module = c_functools

    # Kept inside a one-element tuple; the tests read it as
    # ``cached_func[0]``.
    cached_func = (c_cached_func,)

    # Fixtures for the pickle/copy tests: a cached method ...
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # ... and a cached static method.
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001594
Raymond Hettinger03923422013-03-04 02:52:50 -05001595
Łukasz Langa6f692512013-06-05 12:20:24 +02001596class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
    # Register an implementation through the two-argument form of
    # register() and check that only ints are routed to it.
    @functools.singledispatch
    def g(obj):
        return "base"

    def g_int(i):
        return "integer"

    g.register(int, g_int)

    cases = (("str", "base"), (1, "integer"), ([1,2,3], "base"))
    for argument, expected in cases:
        self.assertEqual(g(argument), expected)
1607
def test_mro(self):
    # Dispatch follows the MRO of the argument's class: D(C, B) reaches
    # the implementation registered for B before the one for A.
    @functools.singledispatch
    def g(obj):
        return "base"

    class A:
        pass

    class C(A):
        pass

    class B(A):
        pass

    class D(C, B):
        pass

    def g_A(a):
        return "A"

    def g_B(b):
        return "B"

    g.register(A, g_A)
    g.register(B, g_B)

    expectations = ((A, "A"), (B, "B"), (C, "A"), (D, "B"))
    for cls, expected in expectations:
        self.assertEqual(g(cls()), expected)
Łukasz Langa6f692512013-06-05 12:20:24 +02001630
def test_register_decorator(self):
    # register(cls) used as a decorator must register the function and
    # hand the undecorated function back.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(int)
    def g_int(i):
        return "int %s" % (i,)

    self.assertEqual(g(""), "base")
    self.assertEqual(g(12), "int 12")
    self.assertIs(g.dispatch(int), g_int)
    # Note: what dispatch() returns for object/str below is not g itself;
    # @singledispatch returns the wrapper.
    fallback = g.dispatch(object)
    self.assertIs(fallback, g.dispatch(str))
1644
def test_wrapping_attributes(self):
    # The generic function must keep the name and docstring of the
    # function it was created from.
    @functools.singledispatch
    def g(obj):
        "Simple test"
        return "Test"

    self.assertEqual(g.__name__, "g")
    # The docstring is only checked below optimization level 2.
    docstrings_kept = sys.flags.optimize < 2
    if docstrings_kept:
        self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001653
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes provided by the decimal module
    # (loaded with a fresh _decimal).
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    # A more specific registration takes over for Subnormal only.
    @g.register(decimal.Subnormal)
    def _(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1672
def test_compose_mro(self):
    # None of the examples in this test depend on haystack ordering, so
    # each expectation is checked against every permutation of the
    # candidate types.  Every loop passes the permutation itself to
    # _compose_mro(); passing a fixed list instead would run the same
    # check repeatedly without ever varying the order.
    c = collections
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
    for haystack in permutations(bases):
        m = mro(c.ChainMap, haystack)
        self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        m = mro(c.defaultdict, haystack)
        self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                             object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(c.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                             c.defaultdict, dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(c.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001726
def test_register_abc(self):
    # Register progressively more specific ABCs and concrete types and,
    # after every registration, check which implementation each of five
    # sample objects is dispatched to.
    c = collections
    d = {"a": "b"}
    l = [1, 2, 3]
    s = {object(), None}
    f = frozenset(s)
    t = (1, 2, 3)
    samples = (d, l, s, f, t)

    @functools.singledispatch
    def g(obj):
        return "base"

    def answering(label):
        # Build an implementation that always returns *label*.
        return lambda obj: label

    def check(expected):
        # *expected* lists the results for d, l, s, f and t, in that order.
        for sample, label in zip(samples, expected):
            self.assertEqual(g(sample), label)

    check(("base", "base", "base", "base", "base"))

    # Each row: type to register, label of its implementation, expected
    # results for (d, l, s, f, t) right after the registration.
    steps = (
        (c.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # irrelevant ABCs registered
        (c.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        # not specific enough
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set",
          "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    )
    for registered_type, label, expected in steps:
        g.register(registered_type, answering(label))
        check(expected)
1820
def test_c3_abc(self):
    # _c3_mro() must insert each relevant ABC into the linearization at
    # the same position, whatever order the ABCs are supplied in.
    c = collections
    c3 = functools._c3_mro

    class A:
        pass

    class B(A):
        def __len__(self):
            return 0   # implies Sized

    @c.Container.register
    class C:
        pass

    class D:
        pass   # unrelated

    class X(D, C, B):
        def __call__(self):
            pass   # implies Callable

    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    for ordering in permutations([c.Sized, c.Callable, c.Container]):
        self.assertEqual(c3(X, abcs=ordering), expected)

    # unrelated ABCs don't appear in the resulting MRO
    with_unrelated = [c.Mapping, c.Sized, c.Callable, c.Container,
                      c.Iterable]
    self.assertEqual(c3(X, abcs=with_unrelated), expected)
1843
def test_false_meta(self):
    # see issue23572
    # The metaclass defines __len__ returning 0, so the classes it
    # creates are falsy; dispatch for an AA instance must still reach
    # the implementation registered for its base class A.
    class MetaA(type):
        def __len__(self):
            return 0

    class A(metaclass=MetaA):
        pass

    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    @fun.register(A)
    def _(a):
        return 'fun A'

    instance = AA()
    self.assertEqual(fun(instance), 'fun A')
1861
def test_mro_conflicts(self):
    # Walk through cases where explicit bases, registered ABCs and
    # inferred ABCs compete, checking which implementation wins and
    # which combinations are reported as ambiguous.
    c = collections

    # The two accepted spellings of the Container/Sized ambiguity.
    container_vs_sized = (
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Sized'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
         "or <class 'collections.abc.Container'>"),
    )

    @functools.singledispatch
    def g(arg):
        return "base"

    class O(c.Sized):
        def __len__(self):
            return 0

    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    # Still "sized", because Sized is explicitly in __mro__.
    self.assertEqual(g(o), "sized")
    c.Container.register(O)
    # See above: Sized is in __mro__.
    self.assertEqual(g(o), "sized")
    c.Set.register(O)
    # Now "set", because c.Set is a subclass of c.Sized and c.Container.
    self.assertEqual(g(o), "set")

    class P:
        pass

    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )

    class Q(c.Sized):
        def __len__(self):
            return 0

    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    # Still "sized", because Sized is explicitly in __mro__.
    self.assertEqual(g(q), "sized")
    c.Set.register(Q)
    # Now "set", because c.Set is a subclass of c.Sized and c.Iterable.
    self.assertEqual(g(q), "set")

    @functools.singledispatch
    def h(arg):
        return "base"

    @h.register(c.Sized)
    def _(arg):
        return "sized"

    @h.register(c.Container)
    def _(arg):
        return "container"

    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(c.defaultdict(lambda: 0))
    self.assertIn(str(re_two.exception), container_vs_sized)

    class R(c.defaultdict):
        pass

    c.MutableSequence.register(R)

    @functools.singledispatch
    def i(arg):
        return "base"

    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"

    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"

    r = R()
    self.assertEqual(i(r), "sequence")

    class S:
        pass

    class T(S, c.Sized):
        def __len__(self):
            return 0

    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    # Still "sized", because Sized is explicitly in the MRO.
    self.assertEqual(h(t), "sized")

    class U:
        def __len__(self):
            return 0

    u = U()
    # Implicit Sized subclass inferred from the existence of __len__().
    self.assertEqual(h(u), "sized")
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(str(re_three.exception), container_vs_sized)

    class V(c.Sized, S):
        def __len__(self):
            return 0

    @functools.singledispatch
    def j(arg):
        return "base"

    @j.register(S)
    def _(arg):
        return "s"

    @j.register(c.Container)
    def _(arg):
        return "container"

    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    # "container", because it ends up right after Sized in the MRO.
    self.assertEqual(j(v), "container")
Łukasz Langa6f692512013-06-05 12:20:24 +02001989
def test_cache_invalidation(self):
    # Replace the dictionary type used by singledispatch for its dispatch
    # cache with a tracing dictionary, then follow every cache read,
    # write and invalidation caused by dispatching and registering.
    from collections import UserDict

    class TracingDict(UserDict):
        # Records the keys passed to __getitem__ and __setitem__.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            self.data.clear()

    td = TracingDict()
    # Register the restoration before patching: the original attribute is
    # put back even when one of the assertions below fails, so the patch
    # can never leak into other tests.
    self.addCleanup(setattr, functools, 'WeakKeyDictionary',
                    functools.WeakKeyDictionary)
    functools.WeakKeyDictionary = lambda: td
    c = collections
    @functools.singledispatch
    def g(arg):
        return "base"
    d = {}
    l = []
    self.assertEqual(len(td), 0)
    self.assertEqual(g(d), "base")
    self.assertEqual(len(td), 1)
    self.assertEqual(td.get_ops, [])
    self.assertEqual(td.set_ops, [dict])
    self.assertEqual(td.data[dict], g.registry[object])
    self.assertEqual(g(l), "base")
    self.assertEqual(len(td), 2)
    self.assertEqual(td.get_ops, [])
    self.assertEqual(td.set_ops, [dict, list])
    self.assertEqual(td.data[dict], g.registry[object])
    self.assertEqual(td.data[list], g.registry[object])
    self.assertEqual(td.data[dict], td.data[list])
    self.assertEqual(g(l), "base")
    self.assertEqual(g(d), "base")
    self.assertEqual(td.get_ops, [list, dict])
    self.assertEqual(td.set_ops, [dict, list])
    g.register(list, lambda arg: "list")
    self.assertEqual(td.get_ops, [list, dict])
    self.assertEqual(len(td), 0)
    self.assertEqual(g(d), "base")
    self.assertEqual(len(td), 1)
    self.assertEqual(td.get_ops, [list, dict])
    self.assertEqual(td.set_ops, [dict, list, dict])
    self.assertEqual(td.data[dict],
                     functools._find_impl(dict, g.registry))
    self.assertEqual(g(l), "list")
    self.assertEqual(len(td), 2)
    self.assertEqual(td.get_ops, [list, dict])
    self.assertEqual(td.set_ops, [dict, list, dict, list])
    self.assertEqual(td.data[list],
                     functools._find_impl(list, g.registry))
    class X:
        pass
    c.MutableMapping.register(X)   # Will not invalidate the cache,
                                   # not using ABCs yet.
    self.assertEqual(g(d), "base")
    self.assertEqual(g(l), "list")
    self.assertEqual(td.get_ops, [list, dict, dict, list])
    self.assertEqual(td.set_ops, [dict, list, dict, list])
    g.register(c.Sized, lambda arg: "sized")
    self.assertEqual(len(td), 0)
    self.assertEqual(g(d), "sized")
    self.assertEqual(len(td), 1)
    self.assertEqual(td.get_ops, [list, dict, dict, list])
    self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
    self.assertEqual(g(l), "list")
    self.assertEqual(len(td), 2)
    self.assertEqual(td.get_ops, [list, dict, dict, list])
    self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
    self.assertEqual(g(l), "list")
    self.assertEqual(g(d), "sized")
    self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
    self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
    g.dispatch(list)
    g.dispatch(dict)
    self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                  list, dict])
    self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
    c.MutableSet.register(X)       # Will invalidate the cache.
    self.assertEqual(len(td), 2)   # Stale cache.
    self.assertEqual(g(l), "list")
    self.assertEqual(len(td), 1)
    g.register(c.MutableMapping, lambda arg: "mutablemapping")
    self.assertEqual(len(td), 0)
    self.assertEqual(g(d), "mutablemapping")
    self.assertEqual(len(td), 1)
    self.assertEqual(g(l), "list")
    self.assertEqual(len(td), 2)
    g.register(dict, lambda arg: "dict")
    self.assertEqual(g(d), "dict")
    self.assertEqual(g(l), "list")
    g._clear_cache()
    self.assertEqual(len(td), 0)
2090
2091
# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()