blob: edd5773e13d549314bb5fc551e2b77e3388d3363 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shi3ddc6342020-06-30 21:46:06 +080022from test.support import import_helper
Hai Shie80697d2020-05-28 06:10:27 +080023from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000024from test.support.script_helper import assert_python_ok
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026import functools
27
Hai Shi3ddc6342020-06-30 21:46:06 +080028py_functools = import_helper.import_fresh_module('functools',
29 blocked=['_functools'])
30c_functools = import_helper.import_fresh_module('functools',
31 fresh=['_functools'])
Antoine Pitroub5b37142012-11-13 21:35:40 +010032
Hai Shi3ddc6342020-06-30 21:46:06 +080033decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
Łukasz Langa6f692512013-06-05 12:20:24 +020034
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily bind ``sys.modules[name]`` to *replacement*.

    The entry found on entry is put back when the ``with`` block exits,
    whether it exits normally or through an exception.  *name* must already
    be present in ``sys.modules``; otherwise the initial lookup raises
    ``KeyError``.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = original_module
Łukasz Langa6f692512013-06-05 12:20:24 +020043
def capture(*args, **kw):
    """capture all positional and keyword arguments

    Returns the pair ``(args, kw)``: the positional arguments as a tuple and
    the keyword arguments as a dict.
    """
    return args, kw
47
Łukasz Langa6f692512013-06-05 12:20:24 +020048
def signature(part):
    """return the signature of a partial object

    The result is the tuple ``(func, args, keywords, __dict__)`` read from
    *part*, which lets two partial objects be compared field by field.
    """
    return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000052
class MyTuple(tuple):
    # Plain tuple subclass with no overrides; used as a __setstate__ argument.
    pass
55
class BadTuple(tuple):
    # Tuple subclass whose ``+`` deliberately returns a list, not a tuple.
    def __add__(self, other):
        return list(self) + list(other)
59
class MyDict(dict):
    # Plain dict subclass with no overrides; used as a __setstate__ argument.
    pass
62
Łukasz Langa6f692512013-06-05 12:20:24 +020063
class TestPartial:
    # Shared tests for ``partial``.  Concrete subclasses provide:
    #   * ``partial``     -- the callable under test, and
    #   * ``AllowPickle`` -- a context-manager class wrapped around pickling,
    # and mix in unittest.TestCase for the assertion methods used here.

    def _repr_name(self):
        # Name expected at the start of repr(partial_object).
        # c_functools is guarded elsewhere in this file as possibly falsy, so
        # only read ``c_functools.partial`` when the module is present.
        plain_types = [py_functools.partial]
        if c_functools:
            plain_types.append(c_functools.partial)
        if self.partial in plain_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertions so a failure reports the actual values.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Do not rely on reference counting to finalise the partial at once.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword ordering is accepted.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then breaks the cycle again in ``finally``.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the following check is disabled in the original; the
        # reason is not visible here, so it is left disabled.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # The stored args/keywords must be exact tuple/dict, not subclasses.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # func refers to the partial itself: pickling must fail cleanly.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # args refer to the partial itself: the cycle must round-trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # keywords refer to the partial itself: same expectation.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000387
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    # Runs the shared TestPartial tests against c_functools.partial.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager: nothing is swapped for this implementation.
        def __enter__(self):
            return self
        def __exit__(self, exc_type, value, tb):
            # False: never suppress an exception raised in the body.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
438
439
class TestPartialPy(TestPartial, unittest.TestCase):
    # Runs the shared TestPartial tests against py_functools.partial.
    partial = py_functools.partial

    class AllowPickle:
        # While active, sys.modules["functools"] is bound to py_functools
        # (via replaced_module) and restored on exit.
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, exc_type, value, tb):
            return self._cm.__exit__(exc_type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200450
# Only defined when c_functools is truthy, since the base class comes from it.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100454
class PyPartialSubclass(py_functools.partial):
    # Trivial subclass of py_functools.partial, used by TestPartialPySubclass.
    pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200457
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-runs the TestPartialC tests with CPartialSubclass as ``partial``.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
465
class TestPartialPySubclass(TestPartialPy):
    # Re-runs the TestPartialPy tests with PyPartialSubclass as ``partial``.
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200468
class TestPartialMethod(unittest.TestCase):
    # Tests for functools.partialmethod, driven by the fixture class ``A``.

    class A(object):
        # Every attribute wraps ``capture``, so a call returns exactly the
        # positional and keyword arguments the underlying callable received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Each construction below is expected to fail with TypeError while
        # the class body is being executed.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial,
                     self.A.nested, self.A.both]:
            # subTest so that a failure names the offending attribute.
            with self.subTest(func=func):
                self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
597
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000598
class TestUpdateWrapper(unittest.TestCase):
    # Tests for functools.update_wrapper.

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Verify that update_wrapper(wrapper, wrapped, assigned, updated)
        # left *wrapper* in the expected state.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair updated with the default arguments.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty ``assigned`` and ``updated``: nothing is copied across.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000709 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000710 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000711
Łukasz Langa6f692512013-06-05 12:20:24 +0200712
class TestWraps(TestUpdateWrapper):
    # Counterpart of the TestUpdateWrapper scenarios that applies the
    # metadata through the functools.wraps() decorator form.

    def _default_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # NOTE(review): presumably a decoy so a stale __wrapped__ on the
        # wrapped function can be detected by check_wrapper() -- confirm
        # against the base class.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        decorated, original = self._default_update()
        self.check_wrapper(decorated, original)
        self.assertEqual(decorated.__name__, 'f')
        self.assertEqual(decorated.__qualname__, original.__qualname__)
        self.assertEqual(decorated.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        decorated, _ = self._default_update()
        self.assertEqual(decorated.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty "assigned"/"updated" tuples: wraps() must copy nothing.
        nothing = ()

        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        copied = ('attr',)
        merged = ('dict_attr',)

        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def with_empty_dict_attr(func):
            # Runs before wraps() so that there is a dict to update.
            func.dict_attr = {}
            return func

        @functools.wraps(f, copied, merged)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, copied, merged)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
773
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000774
class TestReduce:
    # Mixin: concrete subclasses supply the implementation under test as
    # the ``reduce`` attribute together with unittest.TestCase.

    def test_reduce(self):
        class Squares:
            # Sequence of the first ``limit`` squares, computed on demand
            # and reachable through __getitem__ only.
            def __init__(self, limit):
                self.limit = limit
                self.known = []

            def __len__(self):
                return len(self.known)

            def __getitem__(self, index):
                if index < 0 or index >= self.limit:
                    raise IndexError
                while len(self.known) <= index:
                    self.known.append(len(self.known) ** 2)
                return self.known[index]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.reduce

        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        with self.assertRaises(TypeError):
            reduce()
        with self.assertRaises(TypeError):
            reduce(42, 42)
        with self.assertRaises(TypeError):
            reduce(42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            with self.assertRaises(TypeError):
                reduce(add, empty)
        with self.assertRaises(TypeError):
            reduce(add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError

        # An error raised while obtaining the iterator must propagate.
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        # With an empty iterable the initial value is returned as is.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError

        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence yielding 0 .. size-1 via __getitem__.
            def __init__(self, size):
                self.size = size

            def __getitem__(self, index):
                if not 0 <= index < self.size:
                    raise IndexError
                return index

        from operator import add

        reduce = self.reduce
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        numbers = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, numbers), "".join(numbers))
852
853
# Runs the shared TestReduce checks against the C implementation.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # The guard keeps the class body importable when c_functools is falsy;
    # the whole class is skipped in that case by the decorator above.
    if c_functools:
        reduce = c_functools.reduce
858
859
# Runs the shared TestReduce checks against py_functools.reduce.
class TestReducePy(TestReduce, unittest.TestCase):
    # staticmethod() stops the function from being bound to the test
    # instance when it is looked up as self.reduce.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000862
Łukasz Langa6f692512013-06-05 12:20:24 +0200863
class TestCmpToKey:
    # Mixin shared by the C and pure-Python test cases; subclasses provide
    # the implementation under test as the ``cmp_to_key`` attribute.

    def test_cmp_to_key(self):
        # Keys built from a three-way comparison support the rich
        # comparison operators exercised below.
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        # The comparison function alone decides the ordering, so mixed
        # types compare fine once it normalises them with int().
        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        # Both the factory ("mycmp") and the key object ("obj") accept
        # their argument by keyword and reject wrong argument counts.
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                         # too few args
        with self.assertRaises(TypeError):
            key(None, None)               # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself must
        # propagate out of the key comparison.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # So must an exception raised while the *result* of the comparison
        # function is compared with zero.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError

        def cmp2(x, y):
            return BadCmp()
        # A fresh key is required: reusing the one built from cmp1 would
        # only repeat the first check and never touch BadCmp.
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            # "<" is used because BadCmp only defines __lt__; with ">"
            # neither operand offers a usable method and Python raises
            # TypeError before BadCmp gets a chance to fail.
            key(3) < key(1)

    def test_obj_field(self):
        # The wrapped value is exposed as the ``obj`` attribute.
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        # A reversed comparison yields a descending sort.
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        # Ints and numeric strings sort together by integer value.
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # Key objects are unhashable, both at call time and according to
        # the collections.abc.Hashable check.
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000945
Łukasz Langa6f692512013-06-05 12:20:24 +0200946
# Runs the shared TestCmpToKey checks against the C implementation.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # The guard keeps the class body importable when c_functools is falsy;
    # the whole class is skipped in that case by the decorator above.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100951
Łukasz Langa6f692512013-06-05 12:20:24 +0200952
# Runs the shared TestCmpToKey checks against py_functools.cmp_to_key.
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # staticmethod() stops the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
955
Łukasz Langa6f692512013-06-05 12:20:24 +0200956
class TestTotalOrdering(unittest.TestCase):

    def _assert_ordered(self, cls):
        # Checks shared by the per-operator tests: cls(1) sorts strictly
        # before cls(2), and cls(2) is both <= and >= an equal instance.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        # Only __lt__ (plus __eq__) is hand-written.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __lt__(self, other):
                return self.value < other.value
        self._assert_ordered(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # Only __le__ (plus __eq__) is hand-written.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __le__(self, other):
                return self.value <= other.value
        self._assert_ordered(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # Only __gt__ (plus __eq__) is hand-written.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __gt__(self, other):
                return self.value > other.value
        self._assert_ordered(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # Only __ge__ (plus __eq__) is hand-written.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __ge__(self, other):
                return self.value >= other.value
        self._assert_ordered(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._assert_ordered(A)

    def test_no_operations_defined(self):
        # A class without any ordering method cannot be completed.
        class A:
            pass
        with self.assertRaises(ValueError):
            functools.total_ordering(A)

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            # __lt__ never produces an answer, not even for its own type.
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value
            def __lt__(self, other):
                return NotImplemented

        # Comparisons across unrelated types must end in TypeError.
        with self.subTest("LT < 1"):
            with self.assertRaises(TypeError):
                ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"):
            with self.assertRaises(TypeError):
                ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"):
            with self.assertRaises(TypeError):
                ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"):
            with self.assertRaises(TypeError):
                ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"):
            with self.assertRaises(TypeError):
                ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"):
            with self.assertRaises(TypeError):
                ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"):
            with self.assertRaises(TypeError):
                ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"):
            with self.assertRaises(TypeError):
                ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"):
            with self.assertRaises(TypeError):
                ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        # Equal instances still cannot be ordered when the only ordering
        # method always returns NotImplemented.
        with self.subTest("GE when equal"):
            left = ComparatorNotImplemented(8)
            right = ComparatorNotImplemented(8)
            self.assertEqual(left, right)
            with self.assertRaises(TypeError):
                left >= right

        with self.subTest("LE when equal"):
            left = ComparatorNotImplemented(9)
            right = ComparatorNotImplemented(9)
            self.assertEqual(left, right)
            with self.assertRaises(TypeError):
                left <= right

    def test_pickle(self):
        # Every comparison method of a decorated class must round-trip
        # through pickle to the very same object, for each protocol.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(original, proto))
                    self.assertIs(restored, original)
1159
# Module-level helper whose comparison methods are pickled by
# TestTotalOrdering.test_pickle; only __lt__ and __eq__ are hand-written,
# the remaining ordering methods come from total_ordering.
@functools.total_ordering
class Orderable_LT:
    def __init__(self, value):
        self.value = value

    def __eq__(self, rhs):
        return self.value == rhs.value

    def __lt__(self, rhs):
        return self.value < rhs.value
1168
1169
class TestCache:
    # This tests that the pass-through is working as designed.
    # The underlying functionality is tested in TestLRU.

    def test_cache(self):
        info = self.module._CacheInfo

        # The recursion goes through the decorated name, so inner calls
        # are served from the cache as well.
        @self.module.cache
        def fib(n):
            return n if n < 2 else fib(n-1) + fib(n-2)

        first_sixteen = [fib(n) for n in range(16)]
        self.assertEqual(
            first_sixteen,
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(
            fib.cache_info(),
            info(hits=28, misses=16, maxsize=None, currsize=16))

        # Clearing resets the statistics along with the stored results.
        fib.cache_clear()
        self.assertEqual(
            fib.cache_info(),
            info(hits=0, misses=0, maxsize=None, currsize=0))
1187
1188
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001189class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001190
1191 def test_lru(self):
1192 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001193 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001194 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001195 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001196 self.assertEqual(maxsize, 20)
1197 self.assertEqual(currsize, 0)
1198 self.assertEqual(hits, 0)
1199 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001200
1201 domain = range(5)
1202 for i in range(1000):
1203 x, y = choice(domain), choice(domain)
1204 actual = f(x, y)
1205 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001206 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001207 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001208 self.assertTrue(hits > misses)
1209 self.assertEqual(hits + misses, 1000)
1210 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001211
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001212 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001213 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001214 self.assertEqual(hits, 0)
1215 self.assertEqual(misses, 0)
1216 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001217 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001218 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001219 self.assertEqual(hits, 0)
1220 self.assertEqual(misses, 1)
1221 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001222
Nick Coghlan98876832010-08-17 06:17:18 +00001223 # Test bypassing the cache
1224 self.assertIs(f.__wrapped__, orig)
1225 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001226 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001227 self.assertEqual(hits, 0)
1228 self.assertEqual(misses, 1)
1229 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001230
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001231 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001232 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001233 def f():
1234 nonlocal f_cnt
1235 f_cnt += 1
1236 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001237 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001238 f_cnt = 0
1239 for i in range(5):
1240 self.assertEqual(f(), 20)
1241 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001242 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001243 self.assertEqual(hits, 0)
1244 self.assertEqual(misses, 5)
1245 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001246
1247 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001248 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001249 def f():
1250 nonlocal f_cnt
1251 f_cnt += 1
1252 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001253 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001254 f_cnt = 0
1255 for i in range(5):
1256 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001257 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001258 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001259 self.assertEqual(hits, 4)
1260 self.assertEqual(misses, 1)
1261 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001262
Raymond Hettingerf3098282010-08-15 03:30:45 +00001263 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001264 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001265 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001266 nonlocal f_cnt
1267 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001268 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001269 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001270 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001271 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1272 # * * * *
1273 self.assertEqual(f(x), x*10)
1274 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001275 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001276 self.assertEqual(hits, 12)
1277 self.assertEqual(misses, 4)
1278 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001279
Raymond Hettingerb8218682019-05-26 11:27:35 -07001280 def test_lru_no_args(self):
1281 @self.module.lru_cache
1282 def square(x):
1283 return x ** 2
1284
1285 self.assertEqual(list(map(square, [10, 20, 10])),
1286 [100, 400, 100])
1287 self.assertEqual(square.cache_info().hits, 1)
1288 self.assertEqual(square.cache_info().misses, 2)
1289 self.assertEqual(square.cache_info().maxsize, 128)
1290 self.assertEqual(square.cache_info().currsize, 2)
1291
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001292 def test_lru_bug_35780(self):
1293 # C version of the lru_cache was not checking to see if
1294 # the user function call has already modified the cache
1295 # (this arises in recursive calls and in multi-threading).
1296 # This cause the cache to have orphan links not referenced
1297 # by the cache dictionary.
1298
1299 once = True # Modified by f(x) below
1300
1301 @self.module.lru_cache(maxsize=10)
1302 def f(x):
1303 nonlocal once
1304 rv = f'.{x}.'
1305 if x == 20 and once:
1306 once = False
1307 rv = f(x)
1308 return rv
1309
1310 # Fill the cache
1311 for x in range(15):
1312 self.assertEqual(f(x), f'.{x}.')
1313 self.assertEqual(f.cache_info().currsize, 10)
1314
1315 # Make a recursive call and make sure the cache remains full
1316 self.assertEqual(f(20), '.20.')
1317 self.assertEqual(f.cache_info().currsize, 10)
1318
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001319 def test_lru_bug_36650(self):
1320 # C version of lru_cache was treating a call with an empty **kwargs
1321 # dictionary as being distinct from a call with no keywords at all.
1322 # This did not result in an incorrect answer, but it did trigger
1323 # an unexpected cache miss.
1324
1325 @self.module.lru_cache()
1326 def f(x):
1327 pass
1328
1329 f(0)
1330 f(0, **{})
1331 self.assertEqual(f.cache_info().hits, 1)
1332
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001333 def test_lru_hash_only_once(self):
1334 # To protect against weird reentrancy bugs and to improve
1335 # efficiency when faced with slow __hash__ methods, the
1336 # LRU cache guarantees that it will only call __hash__
1337 # only once per use as an argument to the cached function.
1338
1339 @self.module.lru_cache(maxsize=1)
1340 def f(x, y):
1341 return x * 3 + y
1342
1343 # Simulate the integer 5
1344 mock_int = unittest.mock.Mock()
1345 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1346 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1347
1348 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001349 self.assertEqual(f(mock_int, 1), 16)
1350 self.assertEqual(mock_int.__hash__.call_count, 1)
1351 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001352
1353 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001354 self.assertEqual(f(mock_int, 1), 16)
1355 self.assertEqual(mock_int.__hash__.call_count, 2)
1356 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001357
Ville Skyttä49b27342017-08-03 09:00:59 +03001358 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001359 self.assertEqual(f(6, 2), 20)
1360 self.assertEqual(mock_int.__hash__.call_count, 2)
1361 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001362
1363 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001364 self.assertEqual(f(mock_int, 1), 16)
1365 self.assertEqual(mock_int.__hash__.call_count, 3)
1366 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001367
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001368 def test_lru_reentrancy_with_len(self):
1369 # Test to make sure the LRU cache code isn't thrown-off by
1370 # caching the built-in len() function. Since len() can be
1371 # cached, we shouldn't use it inside the lru code itself.
1372 old_len = builtins.len
1373 try:
1374 builtins.len = self.module.lru_cache(4)(len)
1375 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1376 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1377 finally:
1378 builtins.len = old_len
1379
Raymond Hettinger605a4472017-01-09 07:50:19 -08001380 def test_lru_star_arg_handling(self):
1381 # Test regression that arose in ea064ff3c10f
1382 @functools.lru_cache()
1383 def f(*args):
1384 return args
1385
1386 self.assertEqual(f(1, 2), (1, 2))
1387 self.assertEqual(f((1, 2)), ((1, 2),))
1388
Yury Selivanov46a02db2016-11-09 18:55:45 -05001389 def test_lru_type_error(self):
1390 # Regression test for issue #28653.
1391 # lru_cache was leaking when one of the arguments
1392 # wasn't cacheable.
1393
1394 @functools.lru_cache(maxsize=None)
1395 def infinite_cache(o):
1396 pass
1397
1398 @functools.lru_cache(maxsize=10)
1399 def limited_cache(o):
1400 pass
1401
1402 with self.assertRaises(TypeError):
1403 infinite_cache([])
1404
1405 with self.assertRaises(TypeError):
1406 limited_cache([])
1407
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001408 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001409 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001410 def fib(n):
1411 if n < 2:
1412 return n
1413 return fib(n-1) + fib(n-2)
1414 self.assertEqual([fib(n) for n in range(16)],
1415 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1416 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001417 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001418 fib.cache_clear()
1419 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001420 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1421
1422 def test_lru_with_maxsize_negative(self):
1423 @self.module.lru_cache(maxsize=-10)
1424 def eq(n):
1425 return n
1426 for i in (0, 1):
1427 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1428 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001429 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001430
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001431 def test_lru_with_exceptions(self):
1432 # Verify that user_function exceptions get passed through without
1433 # creating a hard-to-read chained exception.
1434 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001435 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001437 def func(i):
1438 return 'abc'[i]
1439 self.assertEqual(func(0), 'a')
1440 with self.assertRaises(IndexError) as cm:
1441 func(15)
1442 self.assertIsNone(cm.exception.__context__)
1443 # Verify that the previous exception did not result in a cached entry
1444 with self.assertRaises(IndexError):
1445 func(15)
1446
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001447 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001448 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001449 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001450 def square(x):
1451 return x * x
1452 self.assertEqual(square(3), 9)
1453 self.assertEqual(type(square(3)), type(9))
1454 self.assertEqual(square(3.0), 9.0)
1455 self.assertEqual(type(square(3.0)), type(9.0))
1456 self.assertEqual(square(x=3), 9)
1457 self.assertEqual(type(square(x=3)), type(9))
1458 self.assertEqual(square(x=3.0), 9.0)
1459 self.assertEqual(type(square(x=3.0)), type(9.0))
1460 self.assertEqual(square.cache_info().hits, 4)
1461 self.assertEqual(square.cache_info().misses, 4)
1462
Antoine Pitroub5b37142012-11-13 21:35:40 +01001463 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001464 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001465 def fib(n):
1466 if n < 2:
1467 return n
1468 return fib(n=n-1) + fib(n=n-2)
1469 self.assertEqual(
1470 [fib(n=number) for number in range(16)],
1471 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1472 )
1473 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001474 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001475 fib.cache_clear()
1476 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001477 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001478
1479 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001480 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001481 def fib(n):
1482 if n < 2:
1483 return n
1484 return fib(n=n-1) + fib(n=n-2)
1485 self.assertEqual([fib(n=number) for number in range(16)],
1486 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1487 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001488 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001489 fib.cache_clear()
1490 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001491 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1492
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001493 def test_kwargs_order(self):
1494 # PEP 468: Preserving Keyword Argument Order
1495 @self.module.lru_cache(maxsize=10)
1496 def f(**kwargs):
1497 return list(kwargs.items())
1498 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1499 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1500 self.assertEqual(f.cache_info(),
1501 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1502
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001503 def test_lru_cache_decoration(self):
1504 def f(zomg: 'zomg_annotation'):
1505 """f doc string"""
1506 return 42
1507 g = self.module.lru_cache()(f)
1508 for attr in self.module.WRAPPER_ASSIGNMENTS:
1509 self.assertEqual(getattr(g, attr), getattr(f, attr))
1510
def test_lru_cache_threaded(self):
    # Use one cache from several threads at once: first with threads
    # that only fill it, then with the same fillers plus one thread
    # that keeps clearing it.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    cached = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(cached.cache_info().currsize, 0)

    # All worker threads block on this event so they start together.
    go = threading.Event()

    def fill(key):
        go.wait(10)
        for _ in range(m):
            self.assertEqual(cached(key, 0), orig(key, 0))

    def wipe():
        go.wait(10)
        for _ in range(2*m):
            cached.cache_clear()

    def make_fillers():
        # One filler thread per distinct key.
        return [threading.Thread(target=fill, args=[key])
                for key in range(n)]

    saved_interval = sys.getswitchinterval()
    # A tiny switch interval makes thread switches far more frequent.
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        with threading_helper.start_threads(make_fillers()):
            go.set()

        stats = cached.cache_info()
        if self.module is py_functools:
            # XXX: Why can be not equal?
            self.assertLessEqual(stats.misses, n)
            self.assertLessEqual(stats.hits, m*n - stats.misses)
        else:
            self.assertEqual(stats.misses, n)
            self.assertEqual(stats.hits, m*n - stats.misses)
        self.assertEqual(stats.currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        workers = [threading.Thread(target=wipe)]
        workers.extend(make_fillers())
        go.clear()
        with threading_helper.start_threads(workers):
            go.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001558
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    #
    # n worker threads call f(i) with the same i in every round.  The
    # main thread is the (n+1)-th party of each barrier, which keeps
    # the workers and the assertions in lock step.
    n, m = 5, 7
    start = threading.Barrier(n+1)  # passed when a round begins
    pause = threading.Barrier(n+1)  # waited on inside f() itself
    stop = threading.Barrier(n+1)   # passed when a round is finished
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        # This barrier also needs the main thread, so no call can
        # return (and populate the cache) before all n workers are
        # inside f() with the same argument.
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            start.wait(10)
            # Each barrier is reset while no thread can be waiting on it.
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # Every round adds n misses (all callers ran the function)
            # but only one new cache entry; there are never any hits.
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1585
Serhiy Storchaka67796522017-01-12 18:34:33 +02001586 def test_lru_cache_threaded3(self):
1587 @self.module.lru_cache(maxsize=2)
1588 def f(x):
1589 time.sleep(.01)
1590 return 3 * x
1591 def test(i, x):
1592 with self.subTest(thread=i):
1593 self.assertEqual(f(x), 3 * x, i)
1594 threads = [threading.Thread(target=test, args=(i, v))
1595 for i, v in enumerate([1, 2, 2, 3, 2])]
Hai Shie80697d2020-05-28 06:10:27 +08001596 with threading_helper.start_threads(threads):
Serhiy Storchaka67796522017-01-12 18:34:33 +02001597 pass
1598
Raymond Hettinger03923422013-03-04 02:52:50 -05001599 def test_need_for_rlock(self):
1600 # This will deadlock on an LRU cache that uses a regular lock
1601
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001602 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001603 def test_func(x):
1604 'Used to demonstrate a reentrant lru_cache call within a single thread'
1605 return x
1606
1607 class DoubleEq:
1608 'Demonstrate a reentrant lru_cache call within a single thread'
1609 def __init__(self, x):
1610 self.x = x
1611 def __hash__(self):
1612 return self.x
1613 def __eq__(self, other):
1614 if self.x == 2:
1615 test_func(DoubleEq(1))
1616 return self.x == other.x
1617
1618 test_func(DoubleEq(1)) # Load the cache
1619 test_func(DoubleEq(2)) # Load the cache
1620 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1621 DoubleEq(2)) # Verify the correct return value
1622
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001623 def test_lru_method(self):
1624 class X(int):
1625 f_cnt = 0
1626 @self.module.lru_cache(2)
1627 def f(self, x):
1628 self.f_cnt += 1
1629 return x*10+self
1630 a = X(5)
1631 b = X(5)
1632 c = X(7)
1633 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1634
1635 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1636 self.assertEqual(a.f(x), x*10 + 5)
1637 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1638 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1639
1640 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1641 self.assertEqual(b.f(x), x*10 + 5)
1642 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1643 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1644
1645 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1646 self.assertEqual(c.f(x), x*10 + 7)
1647 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1648 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1649
1650 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1651 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1652 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1653
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001654 def test_pickle(self):
1655 cls = self.__class__
1656 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1657 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1658 with self.subTest(proto=proto, func=f):
1659 f_copy = pickle.loads(pickle.dumps(f, proto))
1660 self.assertIs(f_copy, f)
1661
1662 def test_copy(self):
1663 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001664 def orig(x, y):
1665 return 3 * x + y
1666 part = self.module.partial(orig, 2)
1667 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1668 self.module.lru_cache(2)(part))
1669 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001670 with self.subTest(func=f):
1671 f_copy = copy.copy(f)
1672 self.assertIs(f_copy, f)
1673
1674 def test_deepcopy(self):
1675 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001676 def orig(x, y):
1677 return 3 * x + y
1678 part = self.module.partial(orig, 2)
1679 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1680 self.module.lru_cache(2)(part))
1681 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001682 with self.subTest(func=f):
1683 f_copy = copy.deepcopy(f)
1684 self.assertIs(f_copy, f)
1685
Manjusaka051ff522019-11-12 15:30:18 +08001686 def test_lru_cache_parameters(self):
1687 @self.module.lru_cache(maxsize=2)
1688 def f():
1689 return 1
1690 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1691
1692 @self.module.lru_cache(maxsize=1000, typed=True)
1693 def f():
1694 return 1
1695 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1696
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001697 def test_lru_cache_weakrefable(self):
1698 @self.module.lru_cache
1699 def test_function(x):
1700 return x
1701
1702 class A:
1703 @self.module.lru_cache
1704 def test_method(self, x):
1705 return (self, x)
1706
1707 @staticmethod
1708 @self.module.lru_cache
1709 def test_staticmethod(x):
1710 return (self, x)
1711
1712 refs = [weakref.ref(test_function),
1713 weakref.ref(A.test_method),
1714 weakref.ref(A.test_staticmethod)]
1715
1716 for ref in refs:
1717 self.assertIsNotNone(ref())
1718
1719 del A
1720 del test_function
1721 gc.collect()
1722
1723 for ref in refs:
1724 self.assertIsNone(ref())
1725
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001726
# Module-level function cached with the py_functools copy of lru_cache;
# TestLRUPy refers to it through its cached_func attribute.
@py_functools.lru_cache()
def py_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1730
# Module-level function cached with the c_functools copy of lru_cache;
# TestLRUC refers to it through its cached_func attribute.
@c_functools.lru_cache()
def c_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1734
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001735
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against py_functools (functools
    # imported with the _functools accelerator blocked).
    module = py_functools

    # Stored inside a one-element tuple; the tests read cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1748
1749
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against c_functools (functools
    # imported with a fresh _functools accelerator).
    module = c_functools

    # Stored inside a one-element tuple; the tests read cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001762
Raymond Hettinger03923422013-03-04 02:52:50 -05001763
Łukasz Langa6f692512013-06-05 12:20:24 +02001764class TestSingleDispatch(unittest.TestCase):
1765 def test_simple_overloads(self):
1766 @functools.singledispatch
1767 def g(obj):
1768 return "base"
1769 def g_int(i):
1770 return "integer"
1771 g.register(int, g_int)
1772 self.assertEqual(g("str"), "base")
1773 self.assertEqual(g(1), "integer")
1774 self.assertEqual(g([1,2,3]), "base")
1775
1776 def test_mro(self):
1777 @functools.singledispatch
1778 def g(obj):
1779 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001780 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001781 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001782 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001783 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001784 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001785 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001786 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001787 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001788 def g_A(a):
1789 return "A"
1790 def g_B(b):
1791 return "B"
1792 g.register(A, g_A)
1793 g.register(B, g_B)
1794 self.assertEqual(g(A()), "A")
1795 self.assertEqual(g(B()), "B")
1796 self.assertEqual(g(C()), "A")
1797 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001798
1799 def test_register_decorator(self):
1800 @functools.singledispatch
1801 def g(obj):
1802 return "base"
1803 @g.register(int)
1804 def g_int(i):
1805 return "int %s" % (i,)
1806 self.assertEqual(g(""), "base")
1807 self.assertEqual(g(12), "int 12")
1808 self.assertIs(g.dispatch(int), g_int)
1809 self.assertIs(g.dispatch(object), g.dispatch(str))
1810 # Note: in the assert above this is not g.
1811 # @singledispatch returns the wrapper.
1812
1813 def test_wrapping_attributes(self):
1814 @functools.singledispatch
1815 def g(obj):
1816 "Simple test"
1817 return "Test"
1818 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001819 if sys.flags.optimize < 2:
1820 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001821
1822 @unittest.skipUnless(decimal, 'requires _decimal')
1823 @support.cpython_only
1824 def test_c_classes(self):
1825 @functools.singledispatch
1826 def g(obj):
1827 return "base"
1828 @g.register(decimal.DecimalException)
1829 def _(obj):
1830 return obj.args
1831 subn = decimal.Subnormal("Exponent < Emin")
1832 rnd = decimal.Rounded("Number got rounded")
1833 self.assertEqual(g(subn), ("Exponent < Emin",))
1834 self.assertEqual(g(rnd), ("Number got rounded",))
1835 @g.register(decimal.Subnormal)
1836 def _(obj):
1837 return "Too small to care."
1838 self.assertEqual(g(subn), "Too small to care.")
1839 self.assertEqual(g(rnd), ("Number got rounded",))
1840
1841 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001842 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001843 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001844 mro = functools._compose_mro
1845 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1846 for haystack in permutations(bases):
1847 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001848 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1849 c.Collection, c.Sized, c.Iterable,
1850 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001851 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001852 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001853 m = mro(collections.ChainMap, haystack)
1854 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001855 c.Collection, c.Sized, c.Iterable,
1856 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001857
1858 # If there's a generic function with implementations registered for
1859 # both Sized and Container, passing a defaultdict to it results in an
1860 # ambiguous dispatch which will cause a RuntimeError (see
1861 # test_mro_conflicts).
1862 bases = [c.Container, c.Sized, str]
1863 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001864 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1865 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1866 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001867
1868 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001869 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001870 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001871 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001872 pass
1873 c.MutableSequence.register(D)
1874 bases = [c.MutableSequence, c.MutableMapping]
1875 for haystack in permutations(bases):
1876 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001877 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001878 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001879 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001880 object])
1881
1882 # Container and Callable are registered on different base classes and
1883 # a generic function supporting both should always pick the Callable
1884 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001885 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001886 def __call__(self):
1887 pass
1888 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1889 for haystack in permutations(bases):
1890 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001891 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001892 c.Collection, c.Sized, c.Iterable,
1893 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001894
1895 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001896 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001897 d = {"a": "b"}
1898 l = [1, 2, 3]
1899 s = {object(), None}
1900 f = frozenset(s)
1901 t = (1, 2, 3)
1902 @functools.singledispatch
1903 def g(obj):
1904 return "base"
1905 self.assertEqual(g(d), "base")
1906 self.assertEqual(g(l), "base")
1907 self.assertEqual(g(s), "base")
1908 self.assertEqual(g(f), "base")
1909 self.assertEqual(g(t), "base")
1910 g.register(c.Sized, lambda obj: "sized")
1911 self.assertEqual(g(d), "sized")
1912 self.assertEqual(g(l), "sized")
1913 self.assertEqual(g(s), "sized")
1914 self.assertEqual(g(f), "sized")
1915 self.assertEqual(g(t), "sized")
1916 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1917 self.assertEqual(g(d), "mutablemapping")
1918 self.assertEqual(g(l), "sized")
1919 self.assertEqual(g(s), "sized")
1920 self.assertEqual(g(f), "sized")
1921 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001922 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001923 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1924 self.assertEqual(g(l), "sized")
1925 self.assertEqual(g(s), "sized")
1926 self.assertEqual(g(f), "sized")
1927 self.assertEqual(g(t), "sized")
1928 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1929 self.assertEqual(g(d), "mutablemapping")
1930 self.assertEqual(g(l), "mutablesequence")
1931 self.assertEqual(g(s), "sized")
1932 self.assertEqual(g(f), "sized")
1933 self.assertEqual(g(t), "sized")
1934 g.register(c.MutableSet, lambda obj: "mutableset")
1935 self.assertEqual(g(d), "mutablemapping")
1936 self.assertEqual(g(l), "mutablesequence")
1937 self.assertEqual(g(s), "mutableset")
1938 self.assertEqual(g(f), "sized")
1939 self.assertEqual(g(t), "sized")
1940 g.register(c.Mapping, lambda obj: "mapping")
1941 self.assertEqual(g(d), "mutablemapping") # not specific enough
1942 self.assertEqual(g(l), "mutablesequence")
1943 self.assertEqual(g(s), "mutableset")
1944 self.assertEqual(g(f), "sized")
1945 self.assertEqual(g(t), "sized")
1946 g.register(c.Sequence, lambda obj: "sequence")
1947 self.assertEqual(g(d), "mutablemapping")
1948 self.assertEqual(g(l), "mutablesequence")
1949 self.assertEqual(g(s), "mutableset")
1950 self.assertEqual(g(f), "sized")
1951 self.assertEqual(g(t), "sequence")
1952 g.register(c.Set, lambda obj: "set")
1953 self.assertEqual(g(d), "mutablemapping")
1954 self.assertEqual(g(l), "mutablesequence")
1955 self.assertEqual(g(s), "mutableset")
1956 self.assertEqual(g(f), "set")
1957 self.assertEqual(g(t), "sequence")
1958 g.register(dict, lambda obj: "dict")
1959 self.assertEqual(g(d), "dict")
1960 self.assertEqual(g(l), "mutablesequence")
1961 self.assertEqual(g(s), "mutableset")
1962 self.assertEqual(g(f), "set")
1963 self.assertEqual(g(t), "sequence")
1964 g.register(list, lambda obj: "list")
1965 self.assertEqual(g(d), "dict")
1966 self.assertEqual(g(l), "list")
1967 self.assertEqual(g(s), "mutableset")
1968 self.assertEqual(g(f), "set")
1969 self.assertEqual(g(t), "sequence")
1970 g.register(set, lambda obj: "concrete-set")
1971 self.assertEqual(g(d), "dict")
1972 self.assertEqual(g(l), "list")
1973 self.assertEqual(g(s), "concrete-set")
1974 self.assertEqual(g(f), "set")
1975 self.assertEqual(g(t), "sequence")
1976 g.register(frozenset, lambda obj: "frozen-set")
1977 self.assertEqual(g(d), "dict")
1978 self.assertEqual(g(l), "list")
1979 self.assertEqual(g(s), "concrete-set")
1980 self.assertEqual(g(f), "frozen-set")
1981 self.assertEqual(g(t), "sequence")
1982 g.register(tuple, lambda obj: "tuple")
1983 self.assertEqual(g(d), "dict")
1984 self.assertEqual(g(l), "list")
1985 self.assertEqual(g(s), "concrete-set")
1986 self.assertEqual(g(f), "frozen-set")
1987 self.assertEqual(g(t), "tuple")
1988
Łukasz Langa3720c772013-07-01 16:00:38 +02001989 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001990 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001991 mro = functools._c3_mro
1992 class A(object):
1993 pass
1994 class B(A):
1995 def __len__(self):
1996 return 0 # implies Sized
1997 @c.Container.register
1998 class C(object):
1999 pass
2000 class D(object):
2001 pass # unrelated
2002 class X(D, C, B):
2003 def __call__(self):
2004 pass # implies Callable
2005 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2006 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2007 self.assertEqual(mro(X, abcs=abcs), expected)
2008 # unrelated ABCs don't appear in the resulting MRO
2009 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2010 self.assertEqual(mro(X, abcs=many_abcs), expected)
2011
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002012 def test_false_meta(self):
2013 # see issue23572
2014 class MetaA(type):
2015 def __len__(self):
2016 return 0
2017 class A(metaclass=MetaA):
2018 pass
2019 class AA(A):
2020 pass
2021 @functools.singledispatch
2022 def fun(a):
2023 return 'base A'
2024 @fun.register(A)
2025 def _(a):
2026 return 'fun A'
2027 aa = AA()
2028 self.assertEqual(fun(aa), 'fun A')
2029
def test_mro_conflicts(self):
    # Walk through cases where several registered ABCs apply to the
    # same object and check which implementation wins, or that the
    # dispatch is reported as ambiguous.  ABC registrations are made
    # step by step, so the order of the statements matters.
    c = collections.abc
    @functools.singledispatch
    def g(arg):
        return "base"
    # O explicitly inherits from Sized.
    class O(c.Sized):
        def __len__(self):
            return 0
    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
    c.Set.register(O)
    self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Container
    # P has no explicit ABC bases; everything is added via register().
    class P:
        pass
    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    # The two classes may be named in either order in the message.
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )
    class Q(c.Sized):
        def __len__(self):
            return 0
    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
    c.Set.register(Q)
    self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Iterable
    # h only knows about Sized and Container.
    @functools.singledispatch
    def h(arg):
        return "base"
    @h.register(c.Sized)
    def _(arg):
        return "sized"
    @h.register(c.Container)
    def _(arg):
        return "container"
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(collections.defaultdict(lambda: 0))
    self.assertIn(
        str(re_two.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # R is a defaultdict subclass that is additionally registered as a
    # MutableSequence.
    class R(collections.defaultdict):
        pass
    c.MutableSequence.register(R)
    @functools.singledispatch
    def i(arg):
        return "base"
    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"
    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"
    r = R()
    self.assertEqual(i(r), "sequence")
    class S:
        pass
    class T(S, c.Sized):
        def __len__(self):
            return 0
    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
    class U:
        def __len__(self):
            return 0
    u = U()
    self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                      # from the existence of __len__()
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(
        str(re_three.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # V lists Sized before S among its bases (T above has S first).
    class V(c.Sized, S):
        def __len__(self):
            return 0
    @functools.singledispatch
    def j(arg):
        return "base"
    @j.register(S)
    def _(arg):
        return "s"
    @j.register(c.Container)
    def _(arg):
        return "container"
    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    self.assertEqual(j(v), "container")   # because it ends up right after
                                          # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002157
def test_cache_invalidation(self):
    # Verify when the singledispatch dispatch cache is filled, reused and
    # emptied.  weakref.WeakKeyDictionary is temporarily replaced by a
    # factory returning a TracingDict, so every keyed read (get_ops) and
    # write (set_ops) on the cache built by singledispatch can be checked.
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # Dict that records, in order, the keys passed to __getitem__
        # and __setitem__.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            # Look up first: a missing key raises and is not recorded.
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            # Only the stored data is dropped; the traces are kept.
            self.data.clear()

    td = TracingDict()
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002259
Łukasz Langae5697532017-12-11 13:56:31 -08002260 def test_annotations(self):
2261 @functools.singledispatch
2262 def i(arg):
2263 return "base"
2264 @i.register
2265 def _(arg: collections.abc.Mapping):
2266 return "mapping"
2267 @i.register
2268 def _(arg: "collections.abc.Sequence"):
2269 return "sequence"
2270 self.assertEqual(i(None), "base")
2271 self.assertEqual(i({"a": 1}), "mapping")
2272 self.assertEqual(i([1, 2, 3]), "sequence")
2273 self.assertEqual(i((1, 2, 3)), "sequence")
2274 self.assertEqual(i("str"), "sequence")
2275
2276 # Registering classes as callables doesn't work with annotations,
2277 # you need to pass the type explicitly.
2278 @i.register(str)
2279 class _:
2280 def __init__(self, arg):
2281 self.arg = arg
2282
2283 def __eq__(self, other):
2284 return self.arg == other
2285 self.assertEqual(i("str"), "str")
2286
Ethan Smithc6512752018-05-26 16:38:33 -04002287 def test_method_register(self):
2288 class A:
2289 @functools.singledispatchmethod
2290 def t(self, arg):
2291 self.arg = "base"
2292 @t.register(int)
2293 def _(self, arg):
2294 self.arg = "int"
2295 @t.register(str)
2296 def _(self, arg):
2297 self.arg = "str"
2298 a = A()
2299
2300 a.t(0)
2301 self.assertEqual(a.arg, "int")
2302 aa = A()
2303 self.assertFalse(hasattr(aa, 'arg'))
2304 a.t('')
2305 self.assertEqual(a.arg, "str")
2306 aa = A()
2307 self.assertFalse(hasattr(aa, 'arg'))
2308 a.t(0.0)
2309 self.assertEqual(a.arg, "base")
2310 aa = A()
2311 self.assertFalse(hasattr(aa, 'arg'))
2312
2313 def test_staticmethod_register(self):
2314 class A:
2315 @functools.singledispatchmethod
2316 @staticmethod
2317 def t(arg):
2318 return arg
2319 @t.register(int)
2320 @staticmethod
2321 def _(arg):
2322 return isinstance(arg, int)
2323 @t.register(str)
2324 @staticmethod
2325 def _(arg):
2326 return isinstance(arg, str)
2327 a = A()
2328
2329 self.assertTrue(A.t(0))
2330 self.assertTrue(A.t(''))
2331 self.assertEqual(A.t(0.0), 0.0)
2332
2333 def test_classmethod_register(self):
2334 class A:
2335 def __init__(self, arg):
2336 self.arg = arg
2337
2338 @functools.singledispatchmethod
2339 @classmethod
2340 def t(cls, arg):
2341 return cls("base")
2342 @t.register(int)
2343 @classmethod
2344 def _(cls, arg):
2345 return cls("int")
2346 @t.register(str)
2347 @classmethod
2348 def _(cls, arg):
2349 return cls("str")
2350
2351 self.assertEqual(A.t(0).arg, "int")
2352 self.assertEqual(A.t('').arg, "str")
2353 self.assertEqual(A.t(0.0).arg, "base")
2354
2355 def test_callable_register(self):
2356 class A:
2357 def __init__(self, arg):
2358 self.arg = arg
2359
2360 @functools.singledispatchmethod
2361 @classmethod
2362 def t(cls, arg):
2363 return cls("base")
2364
2365 @A.t.register(int)
2366 @classmethod
2367 def _(cls, arg):
2368 return cls("int")
2369 @A.t.register(str)
2370 @classmethod
2371 def _(cls, arg):
2372 return cls("str")
2373
2374 self.assertEqual(A.t(0).arg, "int")
2375 self.assertEqual(A.t('').arg, "str")
2376 self.assertEqual(A.t(0.0).arg, "base")
2377
2378 def test_abstractmethod_register(self):
2379 class Abstract(abc.ABCMeta):
2380
2381 @functools.singledispatchmethod
2382 @abc.abstractmethod
2383 def add(self, x, y):
2384 pass
2385
2386 self.assertTrue(Abstract.add.__isabstractmethod__)
2387
2388 def test_type_ann_register(self):
2389 class A:
2390 @functools.singledispatchmethod
2391 def t(self, arg):
2392 return "base"
2393 @t.register
2394 def _(self, arg: int):
2395 return "int"
2396 @t.register
2397 def _(self, arg: str):
2398 return "str"
2399 a = A()
2400
2401 self.assertEqual(a.t(0), "int")
2402 self.assertEqual(a.t(''), "str")
2403 self.assertEqual(a.t(0.0), "base")
2404
def test_invalid_registrations(self):
    # register() must raise TypeError, with the messages checked below,
    # for a non-type argument, for an unannotated function, and for an
    # annotation that is not a class.
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )
    @functools.singledispatch
    def i(arg):
        return "base"
    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    # The message embeds the repr of the function, hence its qualified
    # name including the enclosing test class and method.
    self.assertTrue(str(exc.exception).startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    self.assertTrue(str(exc.exception).startswith(
        "Invalid annotation for 'arg'."
    ))
    self.assertTrue(str(exc.exception).endswith(
        'typing.Iterable[str] is not a class.'
    ))
Łukasz Langae5697532017-12-11 13:56:31 -08002443
Dong-hee Na445f1b32018-07-10 16:26:36 +09002444 def test_invalid_positional_argument(self):
2445 @functools.singledispatch
2446 def f(*args):
2447 pass
2448 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002449 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002450 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002451
Carl Meyerd658dea2018-08-28 01:11:56 -06002452
class CachedCostItem:
    """Fixture whose ``cost`` cached property bumps a counter when computed."""

    # Class-level starting value; the first ``self._cost += 1`` creates an
    # instance attribute from it.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        # Increment under the lock and return the new value.
        with self.lock:
            self._cost += 1
        return self._cost
2465
2466
class OptionallyCachedCostItem:
    """Fixture exposing one computation both uncached and cached.

    ``get_cost()`` increments the counter on every call, while
    ``cached_cost`` wraps the same function in a cached property under a
    different attribute name.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost += 1
        return self._cost

    # The attribute name deliberately differs from the function name.
    cached_cost = py_functools.cached_property(get_cost)
2476
2477
class CachedCostItemWait:
    """Fixture whose ``cost`` cached property waits on an event first."""

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        # Object with a wait(timeout) method, waited on by ``cost``.
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Wait (at most one second) for the event before incrementing the
        # counter under the lock.
        self.event.wait(1)
        with self.lock:
            self._cost += 1
        return self._cost
2491
2492
class CachedCostItemWithSlots:
    """Fixture declaring ``__slots__`` and a cached property named ``cost``.

    ``__slots__`` lists only ``_cost``; the getter raises RuntimeError if
    it is ever executed.
    """

    # One-element tuple: ('_cost') without the trailing comma is merely a
    # parenthesised string.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2502
2503
class TestCachedProperty(unittest.TestCase):
    """Tests for py_functools.cached_property."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        # Several threads read item.cost at once; the getter must end up
        # having run exactly once.
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Use a very small switch interval while the threads run and
        # restore the previous value afterwards.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
                TypeError,
                "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # Here the "instance" is the class MyClass, an instance of MyMeta.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError is checked on the __context__ of the RuntimeError.
        self.assertEqual(
            str(ctx.exception.__context__),
            str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        # Assigned after the class body has run, so __set_name__ is not
        # invoked for this attribute.
        Foo.cp = cp

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2616
2617
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()