blob: 3320ab7ec6649d14f1dc8d39199a5469a4b142c8 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shi3ddc6342020-06-30 21:46:06 +080022from test.support import import_helper
Hai Shie80697d2020-05-28 06:10:27 +080023from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000024from test.support.script_helper import assert_python_ok
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026import functools
27
Hai Shi3ddc6342020-06-30 21:46:06 +080028py_functools = import_helper.import_fresh_module('functools',
29 blocked=['_functools'])
Hai Shidd391232020-12-29 20:45:07 +080030c_functools = import_helper.import_fresh_module('functools')
Antoine Pitroub5b37142012-11-13 21:35:40 +010031
Hai Shi3ddc6342020-06-30 21:46:06 +080032decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
Łukasz Langa6f692512013-06-05 12:20:24 +020033
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered under *name* when the ``with`` block is
    entered is put back on exit, whether the block finishes normally or
    raises.  *name* must already be present in ``sys.modules``.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so later code sees the real module again.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020042
def capture(*args, **kw):
    """Echo back the call's arguments.

    Returns the pair ``(args, kw)``: the tuple of positional arguments and
    the dict of keyword arguments, exactly as received.
    """
    received = (args, kw)
    return received
46
Łukasz Langa6f692512013-06-05 12:20:24 +020047
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func = part.func
    args = part.args
    keywords = part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000051
class MyTuple(tuple):
    """Plain tuple subclass used to check that partial normalises to tuple."""
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
58
class MyDict(dict):
    """Plain dict subclass used to check that partial normalises to dict."""
61
Łukasz Langa6f692512013-06-05 12:20:24 +020062
class TestPartial:
    """Behavioural tests shared by every ``partial`` implementation.

    This is a mixin, not a ``TestCase``.  Concrete subclasses combine it with
    ``unittest.TestCase`` and supply two class attributes:

    * ``partial`` -- the partial type (or subclass) under test;
    * ``AllowPickle`` -- a context-manager class that the pickling tests
      enter before pickling partial objects.
    """

    def _partial_name(self):
        # Name that repr() is expected to print for self.partial: the stock
        # implementations report themselves as 'functools.partial', while
        # subclasses use their own class name.
        if self.partial in (c_functools.partial, py_functools.partial):
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last reference only frees the object immediately under
        # reference counting; force a collection for other GC strategies.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned, so accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._partial_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._partial_name()

        # Each case builds a self-referencing partial, checks that repr()
        # elides the cycle as '...', then breaks the cycle again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None for both keywords and namespace is normalised to empty dicts.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial that is its own func cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000386
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    # The class body executes even when the skip decorator applies, so the
    # attribute lookup must be guarded.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager: unlike the pure-Python variant, the C tests
        # do not swap sys.modules['functools'] before pickling.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
436 self.assertIn("['sth']", r)
437
438
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200439class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000440 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000441
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000442 class AllowPickle:
443 def __init__(self):
444 self._cm = replaced_module("functools", py_functools)
445 def __enter__(self):
446 return self._cm.__enter__()
447 def __exit__(self, type, value, tb):
448 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200450if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000451 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200452 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100453
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000454class PyPartialSubclass(py_functools.partial):
455 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200456
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200457@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200458class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200459 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000460 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000461
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300462 # partial subclasses are not optimized for nested calls
463 test_nested_optimization = None
464
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000465class TestPartialPySubclass(TestPartialPy):
466 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200467
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000468class TestPartialMethod(unittest.TestCase):
469
470 class A(object):
471 nothing = functools.partialmethod(capture)
472 positional = functools.partialmethod(capture, 1)
473 keywords = functools.partialmethod(capture, a=2)
474 both = functools.partialmethod(capture, 3, b=4)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300475 spec_keywords = functools.partialmethod(capture, self=1, func=2)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000476
477 nested = functools.partialmethod(positional, 5)
478
479 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
480
481 static = functools.partialmethod(staticmethod(capture), 8)
482 cls = functools.partialmethod(classmethod(capture), d=9)
483
484 a = A()
485
486 def test_arg_combinations(self):
487 self.assertEqual(self.a.nothing(), ((self.a,), {}))
488 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
489 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
490 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
491
492 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
493 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
494 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
495 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
496
497 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
498 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
499 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
500 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
501
502 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
503 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
504 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
505 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
506
507 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
508
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300509 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))
510
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000511 def test_nested(self):
512 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
513 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
514 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
515 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
516
517 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
518
519 def test_over_partial(self):
520 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
521 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
522 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
523 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
524
525 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
526
527 def test_bound_method_introspection(self):
528 obj = self.a
529 self.assertIs(obj.both.__self__, obj)
530 self.assertIs(obj.nested.__self__, obj)
531 self.assertIs(obj.over_partial.__self__, obj)
532 self.assertIs(obj.cls.__self__, self.A)
533 self.assertIs(self.A.cls.__self__, self.A)
534
535 def test_unbound_method_retrieval(self):
536 obj = self.A
537 self.assertFalse(hasattr(obj.both, "__self__"))
538 self.assertFalse(hasattr(obj.nested, "__self__"))
539 self.assertFalse(hasattr(obj.over_partial, "__self__"))
540 self.assertFalse(hasattr(obj.static, "__self__"))
541 self.assertFalse(hasattr(self.a.static, "__self__"))
542
543 def test_descriptors(self):
544 for obj in [self.A, self.a]:
545 with self.subTest(obj=obj):
546 self.assertEqual(obj.static(), ((8,), {}))
547 self.assertEqual(obj.static(5), ((8, 5), {}))
548 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
549 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
550
551 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
552 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
553 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
554 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
555
556 def test_overriding_keywords(self):
557 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
558 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
559
560 def test_invalid_args(self):
561 with self.assertRaises(TypeError):
562 class B(object):
563 method = functools.partialmethod(None, 1)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300564 with self.assertRaises(TypeError):
565 class B:
566 method = functools.partialmethod()
Serhiy Storchaka142566c2019-06-05 18:22:31 +0300567 with self.assertRaises(TypeError):
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300568 class B:
569 method = functools.partialmethod(func=capture, a=1)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000570
571 def test_repr(self):
572 self.assertEqual(repr(vars(self.A)['both']),
573 'functools.partialmethod({}, 3, b=4)'.format(capture))
574
575 def test_abstract(self):
576 class Abstract(abc.ABCMeta):
577
578 @abc.abstractmethod
579 def add(self, x, y):
580 pass
581
582 add5 = functools.partialmethod(add, 5)
583
584 self.assertTrue(Abstract.add.__isabstractmethod__)
585 self.assertTrue(Abstract.add5.__isabstractmethod__)
586
587 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
588 self.assertFalse(getattr(func, '__isabstractmethod__', False))
589
Pablo Galindo8c77b8c2019-04-29 13:36:57 +0100590 def test_positional_only(self):
591 def f(a, b, /):
592 return a + b
593
594 p = functools.partial(f, 1)
595 self.assertEqual(p(2), f(1, 2))
596
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000597
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000598class TestUpdateWrapper(unittest.TestCase):
599
600 def check_wrapper(self, wrapper, wrapped,
601 assigned=functools.WRAPPER_ASSIGNMENTS,
602 updated=functools.WRAPPER_UPDATES):
603 # Check attributes were assigned
604 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000605 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000606 # Check attributes were updated
607 for name in updated:
608 wrapper_attr = getattr(wrapper, name)
609 wrapped_attr = getattr(wrapped, name)
610 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000611 if name == "__dict__" and key == "__wrapped__":
612 # __wrapped__ is overwritten by the update code
613 continue
614 self.assertIs(wrapped_attr[key], wrapper_attr[key])
615 # Check __wrapped__
616 self.assertIs(wrapper.__wrapped__, wrapped)
617
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618
R. David Murray378c0cf2010-02-24 01:46:21 +0000619 def _default_update(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +0100620 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000621 """This is a test"""
622 pass
623 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000624 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000625 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000626 pass
627 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000628 return wrapper, f
629
630 def test_default_update(self):
631 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000633 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000634 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600635 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636 self.assertEqual(wrapper.attr, 'This is also a test')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100637 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000638 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000639
R. David Murray378c0cf2010-02-24 01:46:21 +0000640 @unittest.skipIf(sys.flags.optimize >= 2,
641 "Docstrings are omitted with -O2 and above")
642 def test_default_update_doc(self):
643 wrapper, f = self._default_update()
644 self.assertEqual(wrapper.__doc__, 'This is a test')
645
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000646 def test_no_update(self):
647 def f():
648 """This is a test"""
649 pass
650 f.attr = 'This is also a test'
651 def wrapper():
652 pass
653 functools.update_wrapper(wrapper, f, (), ())
654 self.check_wrapper(wrapper, f, (), ())
655 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600656 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000658 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000659 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000660
661 def test_selective_update(self):
662 def f():
663 pass
664 f.attr = 'This is a different test'
665 f.dict_attr = dict(a=1, b=2, c=3)
666 def wrapper():
667 pass
668 wrapper.dict_attr = {}
669 assign = ('attr',)
670 update = ('dict_attr',)
671 functools.update_wrapper(wrapper, f, assign, update)
672 self.check_wrapper(wrapper, f, assign, update)
673 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600674 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000675 self.assertEqual(wrapper.__doc__, None)
676 self.assertEqual(wrapper.attr, 'This is a different test')
677 self.assertEqual(wrapper.dict_attr, f.dict_attr)
678
Nick Coghlan98876832010-08-17 06:17:18 +0000679 def test_missing_attributes(self):
680 def f():
681 pass
682 def wrapper():
683 pass
684 wrapper.dict_attr = {}
685 assign = ('attr',)
686 update = ('dict_attr',)
687 # Missing attributes on wrapped object are ignored
688 functools.update_wrapper(wrapper, f, assign, update)
689 self.assertNotIn('attr', wrapper.__dict__)
690 self.assertEqual(wrapper.dict_attr, {})
691 # Wrapper must have expected attributes for updating
692 del wrapper.dict_attr
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695 wrapper.dict_attr = 1
696 with self.assertRaises(AttributeError):
697 functools.update_wrapper(wrapper, f, assign, update)
698
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200699 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000700 @unittest.skipIf(sys.flags.optimize >= 2,
701 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000702 def test_builtin_update(self):
703 # Test for bug #1576241
704 def wrapper():
705 pass
706 functools.update_wrapper(wrapper, max)
707 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000708 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000709 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000710
Łukasz Langa6f692512013-06-05 12:20:24 +0200711
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps() decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Decoy: wraps() copies f.__dict__ onto the wrapper, but must still
        # set wrapper.__wrapped__ to f itself rather than to this string.
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # State this explicitly, as TestUpdateWrapper.test_default_update
        # does, so the decoy __wrapped__ planted above is verifiably ignored.
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples the wrapper keeps its own
        # metadata and gains none of f's attributes.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # The wrapper must already own the dict that gets updated.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
772
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773
madman-bobe25d5fc2018-10-25 15:02:10 +0100774class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000775 def test_reduce(self):
776 class Squares:
777 def __init__(self, max):
778 self.max = max
779 self.sofar = []
780
781 def __len__(self):
782 return len(self.sofar)
783
784 def __getitem__(self, i):
785 if not 0 <= i < self.max: raise IndexError
786 n = len(self.sofar)
787 while n <= i:
788 self.sofar.append(n*n)
789 n += 1
790 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000791 def add(x, y):
792 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 ['a','c','d','w']
797 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000801 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000802 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100803 self.assertEqual(self.reduce(add, Squares(10)), 285)
804 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
805 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
806 self.assertRaises(TypeError, self.reduce)
807 self.assertRaises(TypeError, self.reduce, 42, 42)
808 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
809 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
810 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
811 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
812 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
813 self.assertRaises(TypeError, self.reduce, add, "")
814 self.assertRaises(TypeError, self.reduce, add, ())
815 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000816
817 class TestFailingIter:
818 def __iter__(self):
819 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000821
madman-bobe25d5fc2018-10-25 15:02:10 +0100822 self.assertEqual(self.reduce(add, [], None), None)
823 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000824
825 class BadSeq:
826 def __getitem__(self, index):
827 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100828 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000829
830 # Test reduce()'s use of iterators.
831 def test_iterator_usage(self):
832 class SequenceClass:
833 def __init__(self, n):
834 self.n = n
835 def __getitem__(self, i):
836 if 0 <= i < self.n:
837 return i
838 else:
839 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000840
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000841 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100842 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
843 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
844 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
845 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
846 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
847 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000848
849 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100850 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
851
852
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Shared reduce() tests run against the C accelerator."""

    # Bind only when the C module imported; otherwise the class is skipped.
    # A builtin does not bind ``self``, so no staticmethod() is needed.
    if c_functools:
        reduce = c_functools.reduce
857
858
class TestReducePy(TestReduce, unittest.TestCase):
    """Shared reduce() tests run against the pure-Python implementation."""

    # staticmethod() keeps the plain Python function from binding ``self``.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000861
Łukasz Langa6f692512013-06-05 12:20:24 +0200862
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200863class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000865 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 def cmp1(x, y):
867 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 self.assertEqual(key(3), key(3))
870 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100871 self.assertGreaterEqual(key(3), key(3))
872
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 def cmp2(x, y):
874 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 self.assertEqual(key(4.0), key('4'))
877 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100878 self.assertLessEqual(key(2), key('35'))
879 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700880
881 def test_cmp_to_key_arguments(self):
882 def cmp1(x, y):
883 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 self.assertEqual(key(obj=3), key(obj=3))
886 self.assertGreater(key(obj=3), key(obj=1))
887 with self.assertRaises((TypeError, AttributeError)):
888 key(3) > 1 # rhs is not a K object
889 with self.assertRaises((TypeError, AttributeError)):
890 1 < key(3) # lhs is not a K object
891 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200894 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 with self.assertRaises(TypeError):
897 key() # too few args
898 with self.assertRaises(TypeError):
899 key(None, None) # too many args
900
901 def test_bad_cmp(self):
902 def cmp1(x, y):
903 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100904 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700905 with self.assertRaises(ZeroDivisionError):
906 key(3) > key(1)
907
908 class BadCmp:
909 def __lt__(self, other):
910 raise ZeroDivisionError
911 def cmp1(x, y):
912 return BadCmp()
913 with self.assertRaises(ZeroDivisionError):
914 key(3) > key(1)
915
916 def test_obj_field(self):
917 def cmp1(x, y):
918 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700920 self.assertEqual(key(50).obj, 50)
921
922 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 def mycmp(x, y):
924 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100925 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000927
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700928 def test_sort_int_str(self):
929 def mycmp(x, y):
930 x, y = int(x), int(y)
931 return (x > y) - (x < y)
932 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100933 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700934 self.assertEqual([int(value) for value in values],
935 [0, 1, 1, 2, 3, 4, 5, 7, 10])
936
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000937 def test_hash(self):
938 def mycmp(x, y):
939 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100940 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700942 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300943 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Shared cmp_to_key() tests run against the C accelerator."""

    # A builtin does not bind ``self``, so no staticmethod() is needed.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # The key type produced by the C cmp_to_key() cannot be
        # instantiated directly (bpo-43916).
        key_type = type(c_functools.cmp_to_key(None))
        with self.assertRaises(TypeError):
            key_type()
956
Łukasz Langa6f692512013-06-05 12:20:24 +0200957
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Shared cmp_to_key() tests run against the pure-Python implementation."""

    # staticmethod() keeps the plain Python function from binding ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
960
Łukasz Langa6f692512013-06-05 12:20:24 +0200961
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000962class TestTotalOrdering(unittest.TestCase):
963
964 def test_total_ordering_lt(self):
965 @functools.total_ordering
966 class A:
967 def __init__(self, value):
968 self.value = value
969 def __lt__(self, other):
970 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000971 def __eq__(self, other):
972 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000973 self.assertTrue(A(1) < A(2))
974 self.assertTrue(A(2) > A(1))
975 self.assertTrue(A(1) <= A(2))
976 self.assertTrue(A(2) >= A(1))
977 self.assertTrue(A(2) <= A(2))
978 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000979 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000980
981 def test_total_ordering_le(self):
982 @functools.total_ordering
983 class A:
984 def __init__(self, value):
985 self.value = value
986 def __le__(self, other):
987 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000988 def __eq__(self, other):
989 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000990 self.assertTrue(A(1) < A(2))
991 self.assertTrue(A(2) > A(1))
992 self.assertTrue(A(1) <= A(2))
993 self.assertTrue(A(2) >= A(1))
994 self.assertTrue(A(2) <= A(2))
995 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000996 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000997
998 def test_total_ordering_gt(self):
999 @functools.total_ordering
1000 class A:
1001 def __init__(self, value):
1002 self.value = value
1003 def __gt__(self, other):
1004 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001005 def __eq__(self, other):
1006 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001007 self.assertTrue(A(1) < A(2))
1008 self.assertTrue(A(2) > A(1))
1009 self.assertTrue(A(1) <= A(2))
1010 self.assertTrue(A(2) >= A(1))
1011 self.assertTrue(A(2) <= A(2))
1012 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001013 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001014
1015 def test_total_ordering_ge(self):
1016 @functools.total_ordering
1017 class A:
1018 def __init__(self, value):
1019 self.value = value
1020 def __ge__(self, other):
1021 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001022 def __eq__(self, other):
1023 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001024 self.assertTrue(A(1) < A(2))
1025 self.assertTrue(A(2) > A(1))
1026 self.assertTrue(A(1) <= A(2))
1027 self.assertTrue(A(2) >= A(1))
1028 self.assertTrue(A(2) <= A(2))
1029 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001030 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001031
1032 def test_total_ordering_no_overwrite(self):
1033 # new methods should not overwrite existing
1034 @functools.total_ordering
1035 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001036 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001037 self.assertTrue(A(1) < A(2))
1038 self.assertTrue(A(2) > A(1))
1039 self.assertTrue(A(1) <= A(2))
1040 self.assertTrue(A(2) >= A(1))
1041 self.assertTrue(A(2) <= A(2))
1042 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001043
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001044 def test_no_operations_defined(self):
1045 with self.assertRaises(ValueError):
1046 @functools.total_ordering
1047 class A:
1048 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001049
Nick Coghlanf05d9812013-10-02 00:02:03 +10001050 def test_type_error_when_not_implemented(self):
1051 # bug 10042; ensure stack overflow does not occur
1052 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001053 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001054 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001055 def __init__(self, value):
1056 self.value = value
1057 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001058 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001059 return self.value == other.value
1060 return False
1061 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001062 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001063 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001064 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001065
Nick Coghlanf05d9812013-10-02 00:02:03 +10001066 @functools.total_ordering
1067 class ImplementsGreaterThan:
1068 def __init__(self, value):
1069 self.value = value
1070 def __eq__(self, other):
1071 if isinstance(other, ImplementsGreaterThan):
1072 return self.value == other.value
1073 return False
1074 def __gt__(self, other):
1075 if isinstance(other, ImplementsGreaterThan):
1076 return self.value > other.value
1077 return NotImplemented
1078
1079 @functools.total_ordering
1080 class ImplementsLessThanEqualTo:
1081 def __init__(self, value):
1082 self.value = value
1083 def __eq__(self, other):
1084 if isinstance(other, ImplementsLessThanEqualTo):
1085 return self.value == other.value
1086 return False
1087 def __le__(self, other):
1088 if isinstance(other, ImplementsLessThanEqualTo):
1089 return self.value <= other.value
1090 return NotImplemented
1091
1092 @functools.total_ordering
1093 class ImplementsGreaterThanEqualTo:
1094 def __init__(self, value):
1095 self.value = value
1096 def __eq__(self, other):
1097 if isinstance(other, ImplementsGreaterThanEqualTo):
1098 return self.value == other.value
1099 return False
1100 def __ge__(self, other):
1101 if isinstance(other, ImplementsGreaterThanEqualTo):
1102 return self.value >= other.value
1103 return NotImplemented
1104
1105 @functools.total_ordering
1106 class ComparatorNotImplemented:
1107 def __init__(self, value):
1108 self.value = value
1109 def __eq__(self, other):
1110 if isinstance(other, ComparatorNotImplemented):
1111 return self.value == other.value
1112 return False
1113 def __lt__(self, other):
1114 return NotImplemented
1115
1116 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1117 ImplementsLessThan(-1) < 1
1118
1119 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1120 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1121
1122 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1123 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1124
1125 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1126 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1127
1128 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1129 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1130
1131 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1132 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1133
1134 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1135 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1136
1137 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1138 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1139
1140 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1141 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1142
1143 with self.subTest("GE when equal"):
1144 a = ComparatorNotImplemented(8)
1145 b = ComparatorNotImplemented(8)
1146 self.assertEqual(a, b)
1147 with self.assertRaises(TypeError):
1148 a >= b
1149
1150 with self.subTest("LE when equal"):
1151 a = ComparatorNotImplemented(9)
1152 b = ComparatorNotImplemented(9)
1153 self.assertEqual(a, b)
1154 with self.assertRaises(TypeError):
1155 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001156
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001157 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001158 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001159 for name in '__lt__', '__gt__', '__le__', '__ge__':
1160 with self.subTest(method=name, proto=proto):
1161 method = getattr(Orderable_LT, name)
1162 method_copy = pickle.loads(pickle.dumps(method, proto))
1163 self.assertIs(method_copy, method)
1164
@functools.total_ordering
class Orderable_LT:
    """Orderable value type used by TestTotalOrdering.test_pickle.

    Only ``__eq__`` and ``__lt__`` are written out; total_ordering()
    supplies the remaining rich comparisons.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1173
1174
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001175class TestCache:
1176 # This tests that the pass-through is working as designed.
1177 # The underlying functionality is tested in TestLRU.
1178
1179 def test_cache(self):
1180 @self.module.cache
1181 def fib(n):
1182 if n < 2:
1183 return n
1184 return fib(n-1) + fib(n-2)
1185 self.assertEqual([fib(n) for n in range(16)],
1186 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1187 self.assertEqual(fib.cache_info(),
1188 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1189 fib.cache_clear()
1190 self.assertEqual(fib.cache_info(),
1191 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1192
1193
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001194class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001195
1196 def test_lru(self):
1197 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001198 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001199 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001200 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001201 self.assertEqual(maxsize, 20)
1202 self.assertEqual(currsize, 0)
1203 self.assertEqual(hits, 0)
1204 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001205
1206 domain = range(5)
1207 for i in range(1000):
1208 x, y = choice(domain), choice(domain)
1209 actual = f(x, y)
1210 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001211 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001212 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001213 self.assertTrue(hits > misses)
1214 self.assertEqual(hits + misses, 1000)
1215 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001216
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001217 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001218 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001219 self.assertEqual(hits, 0)
1220 self.assertEqual(misses, 0)
1221 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001222 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001223 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001224 self.assertEqual(hits, 0)
1225 self.assertEqual(misses, 1)
1226 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001227
Nick Coghlan98876832010-08-17 06:17:18 +00001228 # Test bypassing the cache
1229 self.assertIs(f.__wrapped__, orig)
1230 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001231 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001232 self.assertEqual(hits, 0)
1233 self.assertEqual(misses, 1)
1234 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001235
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001236 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001237 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001238 def f():
1239 nonlocal f_cnt
1240 f_cnt += 1
1241 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001242 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001243 f_cnt = 0
1244 for i in range(5):
1245 self.assertEqual(f(), 20)
1246 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001247 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001248 self.assertEqual(hits, 0)
1249 self.assertEqual(misses, 5)
1250 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001251
1252 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001253 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001254 def f():
1255 nonlocal f_cnt
1256 f_cnt += 1
1257 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001258 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001259 f_cnt = 0
1260 for i in range(5):
1261 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001262 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001263 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001264 self.assertEqual(hits, 4)
1265 self.assertEqual(misses, 1)
1266 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001267
Raymond Hettingerf3098282010-08-15 03:30:45 +00001268 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001269 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001270 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001271 nonlocal f_cnt
1272 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001273 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001274 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001275 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001276 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1277 # * * * *
1278 self.assertEqual(f(x), x*10)
1279 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001280 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001281 self.assertEqual(hits, 12)
1282 self.assertEqual(misses, 4)
1283 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001284
Raymond Hettingerb8218682019-05-26 11:27:35 -07001285 def test_lru_no_args(self):
1286 @self.module.lru_cache
1287 def square(x):
1288 return x ** 2
1289
1290 self.assertEqual(list(map(square, [10, 20, 10])),
1291 [100, 400, 100])
1292 self.assertEqual(square.cache_info().hits, 1)
1293 self.assertEqual(square.cache_info().misses, 2)
1294 self.assertEqual(square.cache_info().maxsize, 128)
1295 self.assertEqual(square.cache_info().currsize, 2)
1296
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001297 def test_lru_bug_35780(self):
1298 # C version of the lru_cache was not checking to see if
1299 # the user function call has already modified the cache
1300 # (this arises in recursive calls and in multi-threading).
1301 # This cause the cache to have orphan links not referenced
1302 # by the cache dictionary.
1303
1304 once = True # Modified by f(x) below
1305
1306 @self.module.lru_cache(maxsize=10)
1307 def f(x):
1308 nonlocal once
1309 rv = f'.{x}.'
1310 if x == 20 and once:
1311 once = False
1312 rv = f(x)
1313 return rv
1314
1315 # Fill the cache
1316 for x in range(15):
1317 self.assertEqual(f(x), f'.{x}.')
1318 self.assertEqual(f.cache_info().currsize, 10)
1319
1320 # Make a recursive call and make sure the cache remains full
1321 self.assertEqual(f(20), '.20.')
1322 self.assertEqual(f.cache_info().currsize, 10)
1323
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001324 def test_lru_bug_36650(self):
1325 # C version of lru_cache was treating a call with an empty **kwargs
1326 # dictionary as being distinct from a call with no keywords at all.
1327 # This did not result in an incorrect answer, but it did trigger
1328 # an unexpected cache miss.
1329
1330 @self.module.lru_cache()
1331 def f(x):
1332 pass
1333
1334 f(0)
1335 f(0, **{})
1336 self.assertEqual(f.cache_info().hits, 1)
1337
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001338 def test_lru_hash_only_once(self):
1339 # To protect against weird reentrancy bugs and to improve
1340 # efficiency when faced with slow __hash__ methods, the
1341 # LRU cache guarantees that it will only call __hash__
1342 # only once per use as an argument to the cached function.
1343
1344 @self.module.lru_cache(maxsize=1)
1345 def f(x, y):
1346 return x * 3 + y
1347
1348 # Simulate the integer 5
1349 mock_int = unittest.mock.Mock()
1350 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1351 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1352
1353 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001354 self.assertEqual(f(mock_int, 1), 16)
1355 self.assertEqual(mock_int.__hash__.call_count, 1)
1356 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001357
1358 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001359 self.assertEqual(f(mock_int, 1), 16)
1360 self.assertEqual(mock_int.__hash__.call_count, 2)
1361 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001362
Ville Skyttä49b27342017-08-03 09:00:59 +03001363 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001364 self.assertEqual(f(6, 2), 20)
1365 self.assertEqual(mock_int.__hash__.call_count, 2)
1366 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001367
1368 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001369 self.assertEqual(f(mock_int, 1), 16)
1370 self.assertEqual(mock_int.__hash__.call_count, 3)
1371 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001372
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001373 def test_lru_reentrancy_with_len(self):
1374 # Test to make sure the LRU cache code isn't thrown-off by
1375 # caching the built-in len() function. Since len() can be
1376 # cached, we shouldn't use it inside the lru code itself.
1377 old_len = builtins.len
1378 try:
1379 builtins.len = self.module.lru_cache(4)(len)
1380 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1381 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1382 finally:
1383 builtins.len = old_len
1384
Raymond Hettinger605a4472017-01-09 07:50:19 -08001385 def test_lru_star_arg_handling(self):
1386 # Test regression that arose in ea064ff3c10f
1387 @functools.lru_cache()
1388 def f(*args):
1389 return args
1390
1391 self.assertEqual(f(1, 2), (1, 2))
1392 self.assertEqual(f((1, 2)), ((1, 2),))
1393
Yury Selivanov46a02db2016-11-09 18:55:45 -05001394 def test_lru_type_error(self):
1395 # Regression test for issue #28653.
1396 # lru_cache was leaking when one of the arguments
1397 # wasn't cacheable.
1398
1399 @functools.lru_cache(maxsize=None)
1400 def infinite_cache(o):
1401 pass
1402
1403 @functools.lru_cache(maxsize=10)
1404 def limited_cache(o):
1405 pass
1406
1407 with self.assertRaises(TypeError):
1408 infinite_cache([])
1409
1410 with self.assertRaises(TypeError):
1411 limited_cache([])
1412
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001413 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001414 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001415 def fib(n):
1416 if n < 2:
1417 return n
1418 return fib(n-1) + fib(n-2)
1419 self.assertEqual([fib(n) for n in range(16)],
1420 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1421 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001422 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001423 fib.cache_clear()
1424 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001425 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1426
1427 def test_lru_with_maxsize_negative(self):
1428 @self.module.lru_cache(maxsize=-10)
1429 def eq(n):
1430 return n
1431 for i in (0, 1):
1432 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1433 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001434 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001435
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001436 def test_lru_with_exceptions(self):
1437 # Verify that user_function exceptions get passed through without
1438 # creating a hard-to-read chained exception.
1439 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001440 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001441 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001442 def func(i):
1443 return 'abc'[i]
1444 self.assertEqual(func(0), 'a')
1445 with self.assertRaises(IndexError) as cm:
1446 func(15)
1447 self.assertIsNone(cm.exception.__context__)
1448 # Verify that the previous exception did not result in a cached entry
1449 with self.assertRaises(IndexError):
1450 func(15)
1451
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001452 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001453 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001454 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001455 def square(x):
1456 return x * x
1457 self.assertEqual(square(3), 9)
1458 self.assertEqual(type(square(3)), type(9))
1459 self.assertEqual(square(3.0), 9.0)
1460 self.assertEqual(type(square(3.0)), type(9.0))
1461 self.assertEqual(square(x=3), 9)
1462 self.assertEqual(type(square(x=3)), type(9))
1463 self.assertEqual(square(x=3.0), 9.0)
1464 self.assertEqual(type(square(x=3.0)), type(9.0))
1465 self.assertEqual(square.cache_info().hits, 4)
1466 self.assertEqual(square.cache_info().misses, 4)
1467
Antoine Pitroub5b37142012-11-13 21:35:40 +01001468 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001469 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001470 def fib(n):
1471 if n < 2:
1472 return n
1473 return fib(n=n-1) + fib(n=n-2)
1474 self.assertEqual(
1475 [fib(n=number) for number in range(16)],
1476 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1477 )
1478 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001479 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001480 fib.cache_clear()
1481 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001482 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001483
1484 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001485 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001486 def fib(n):
1487 if n < 2:
1488 return n
1489 return fib(n=n-1) + fib(n=n-2)
1490 self.assertEqual([fib(n=number) for number in range(16)],
1491 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1492 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001493 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001494 fib.cache_clear()
1495 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001496 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1497
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001498 def test_kwargs_order(self):
1499 # PEP 468: Preserving Keyword Argument Order
1500 @self.module.lru_cache(maxsize=10)
1501 def f(**kwargs):
1502 return list(kwargs.items())
1503 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1504 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1505 self.assertEqual(f.cache_info(),
1506 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1507
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001508 def test_lru_cache_decoration(self):
1509 def f(zomg: 'zomg_annotation'):
1510 """f doc string"""
1511 return 42
1512 g = self.module.lru_cache()(f)
1513 for attr in self.module.WRAPPER_ASSIGNMENTS:
1514 self.assertEqual(getattr(g, attr), getattr(f, attr))
1515
def test_lru_cache_threaded(self):
    """Exercise a bounded lru_cache from several threads at once.

    Phase 1: n threads each call f(k, 0) m times for their own key k,
    filling the cache; the hit/miss/size counters are then checked.
    Phase 2: the same workers run again while one extra thread repeatedly
    calls cache_clear() concurrently with them.
    """
    n, m = 5, 11
    def orig(x, y):
        return 3 * x + y
    f = self.module.lru_cache(maxsize=n*m)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # Every worker blocks on this event so they all start together once
    # it is set.
    start = threading.Event()
    def full(k):
        # Worker: call the cached function m times with its own key k.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        # Worker: repeatedly empty the cache while others are using it.
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    # A tiny switch interval makes thread switches very frequent, so that
    # races inside the cache are more likely to show up.  The original
    # value is restored in the finally clause.
    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with threading_helper.start_threads(threads):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: Why can be not equal?
            # NOTE(review): only upper bounds are asserted for the
            # pure-Python implementation; the reason exact counts may
            # differ is not evident from this test -- confirm before
            # tightening these checks.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            # One miss per distinct key; every other call is a hit.
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with threading_helper.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001563
def test_lru_cache_threaded2(self):
    """Simultaneous calls with the same arguments from several threads.

    Each round i runs in lock-step via three barriers shared by the main
    thread and n workers:
      * start -- all workers then call f(i) together;
      * pause -- reached inside f, so all n workers are inside the
        function at the same time and none of them can see another's
        stored result: every call in the round is a miss;
      * stop  -- the round is over and the counters can be checked.
    Hence after round i there are (i+1)*n misses but only i+1 entries.
    """
    # Simultaneous call with the same arguments
    n, m = 5, 7
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        # Block until every worker (and the main thread) is inside f.
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            # Each barrier is reset only at a point where, by the
            # choreography above, no thread can be waiting on it.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1590
def test_lru_cache_threaded3(self):
    """Concurrent calls, some sharing a key, on a two-entry cache.

    Every thread must still get the correct result.  The sleep inside the
    cached function widens the window in which threads overlap.
    """
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def worker(idx, arg):
        with self.subTest(thread=idx):
            self.assertEqual(f(arg), 3 * arg, idx)

    call_args = [1, 2, 2, 3, 2]
    threads = [threading.Thread(target=worker, args=(idx, arg))
               for idx, arg in enumerate(call_args)]
    # The context manager starts the threads; presumably it joins them on
    # exit (see test.support.threading_helper).
    with threading_helper.start_threads(threads):
        pass
1603
Raymond Hettinger03923422013-03-04 02:52:50 -05001604 def test_need_for_rlock(self):
1605 # This will deadlock on an LRU cache that uses a regular lock
1606
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001607 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001608 def test_func(x):
1609 'Used to demonstrate a reentrant lru_cache call within a single thread'
1610 return x
1611
1612 class DoubleEq:
1613 'Demonstrate a reentrant lru_cache call within a single thread'
1614 def __init__(self, x):
1615 self.x = x
1616 def __hash__(self):
1617 return self.x
1618 def __eq__(self, other):
1619 if self.x == 2:
1620 test_func(DoubleEq(1))
1621 return self.x == other.x
1622
1623 test_func(DoubleEq(1)) # Load the cache
1624 test_func(DoubleEq(2)) # Load the cache
1625 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1626 DoubleEq(2)) # Verify the correct return value
1627
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001628 def test_lru_method(self):
1629 class X(int):
1630 f_cnt = 0
1631 @self.module.lru_cache(2)
1632 def f(self, x):
1633 self.f_cnt += 1
1634 return x*10+self
1635 a = X(5)
1636 b = X(5)
1637 c = X(7)
1638 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1639
1640 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1641 self.assertEqual(a.f(x), x*10 + 5)
1642 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1643 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1644
1645 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1646 self.assertEqual(b.f(x), x*10 + 5)
1647 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1648 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1649
1650 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1651 self.assertEqual(c.f(x), x*10 + 7)
1652 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1653 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1654
1655 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1656 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1657 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1658
def test_pickle(self):
    """Cached functions pickle by reference and round-trip to themselves.

    This is checked for a plain function, a method and a staticmethod,
    under every pickle protocol.
    """
    owner = type(self)
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    for func in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=func):
                restored = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(restored, func)
1666
1667 def test_copy(self):
1668 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001669 def orig(x, y):
1670 return 3 * x + y
1671 part = self.module.partial(orig, 2)
1672 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1673 self.module.lru_cache(2)(part))
1674 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001675 with self.subTest(func=f):
1676 f_copy = copy.copy(f)
1677 self.assertIs(f_copy, f)
1678
1679 def test_deepcopy(self):
1680 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001681 def orig(x, y):
1682 return 3 * x + y
1683 part = self.module.partial(orig, 2)
1684 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1685 self.module.lru_cache(2)(part))
1686 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001687 with self.subTest(func=f):
1688 f_copy = copy.deepcopy(f)
1689 self.assertIs(f_copy, f)
1690
Manjusaka051ff522019-11-12 15:30:18 +08001691 def test_lru_cache_parameters(self):
1692 @self.module.lru_cache(maxsize=2)
1693 def f():
1694 return 1
1695 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1696
1697 @self.module.lru_cache(maxsize=1000, typed=True)
1698 def f():
1699 return 1
1700 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1701
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001702 def test_lru_cache_weakrefable(self):
1703 @self.module.lru_cache
1704 def test_function(x):
1705 return x
1706
1707 class A:
1708 @self.module.lru_cache
1709 def test_method(self, x):
1710 return (self, x)
1711
1712 @staticmethod
1713 @self.module.lru_cache
1714 def test_staticmethod(x):
1715 return (self, x)
1716
1717 refs = [weakref.ref(test_function),
1718 weakref.ref(A.test_method),
1719 weakref.ref(A.test_staticmethod)]
1720
1721 for ref in refs:
1722 self.assertIsNotNone(ref())
1723
1724 del A
1725 del test_function
1726 gc.collect()
1727
1728 for ref in refs:
1729 self.assertIsNone(ref())
1730
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001731
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y, cached by the pure-Python lru_cache.

    Defined at module level so it can be pickled by reference; the pickle
    test expects a round trip to return this very object.
    """
    result = 3 * x + y
    return result
1735
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y, cached by the C-accelerated lru_cache.

    Defined at module level so it can be pickled by reference; the pickle
    test expects a round trip to return this very object.
    """
    result = 3 * x + y
    return result
1739
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001740
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # Stored in a 1-tuple and read back as ``cached_func[0]``; presumably
    # this keeps the cached function from being bound as a method when it
    # is looked up through an instance.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1753
1754
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # Stored in a 1-tuple and read back as ``cached_func[0]``; presumably
    # this keeps the cached function from being bound as a method when it
    # is looked up through an instance.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001767
Raymond Hettinger03923422013-03-04 02:52:50 -05001768
Łukasz Langa6f692512013-06-05 12:20:24 +02001769class TestSingleDispatch(unittest.TestCase):
1770 def test_simple_overloads(self):
1771 @functools.singledispatch
1772 def g(obj):
1773 return "base"
1774 def g_int(i):
1775 return "integer"
1776 g.register(int, g_int)
1777 self.assertEqual(g("str"), "base")
1778 self.assertEqual(g(1), "integer")
1779 self.assertEqual(g([1,2,3]), "base")
1780
1781 def test_mro(self):
1782 @functools.singledispatch
1783 def g(obj):
1784 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001785 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001786 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001787 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001788 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001789 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001790 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001791 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001792 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001793 def g_A(a):
1794 return "A"
1795 def g_B(b):
1796 return "B"
1797 g.register(A, g_A)
1798 g.register(B, g_B)
1799 self.assertEqual(g(A()), "A")
1800 self.assertEqual(g(B()), "B")
1801 self.assertEqual(g(C()), "A")
1802 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001803
1804 def test_register_decorator(self):
1805 @functools.singledispatch
1806 def g(obj):
1807 return "base"
1808 @g.register(int)
1809 def g_int(i):
1810 return "int %s" % (i,)
1811 self.assertEqual(g(""), "base")
1812 self.assertEqual(g(12), "int 12")
1813 self.assertIs(g.dispatch(int), g_int)
1814 self.assertIs(g.dispatch(object), g.dispatch(str))
1815 # Note: in the assert above this is not g.
1816 # @singledispatch returns the wrapper.
1817
1818 def test_wrapping_attributes(self):
1819 @functools.singledispatch
1820 def g(obj):
1821 "Simple test"
1822 return "Test"
1823 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001824 if sys.flags.optimize < 2:
1825 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001826
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    """singledispatch works with exception classes from the decimal module.

    ``decimal`` here is the fresh import that requires ``_decimal``, so
    these classes are presumably C-implemented.
    """
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    # A more specific registration takes over for Subnormal only.
    g.register(decimal.Subnormal, lambda obj: "Too small to care.")
    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1845
1846 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001847 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001848 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001849 mro = functools._compose_mro
1850 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1851 for haystack in permutations(bases):
1852 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001853 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1854 c.Collection, c.Sized, c.Iterable,
1855 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001856 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001857 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001858 m = mro(collections.ChainMap, haystack)
1859 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001860 c.Collection, c.Sized, c.Iterable,
1861 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001862
1863 # If there's a generic function with implementations registered for
1864 # both Sized and Container, passing a defaultdict to it results in an
1865 # ambiguous dispatch which will cause a RuntimeError (see
1866 # test_mro_conflicts).
1867 bases = [c.Container, c.Sized, str]
1868 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001869 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1870 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1871 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001872
1873 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001874 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001875 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001876 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001877 pass
1878 c.MutableSequence.register(D)
1879 bases = [c.MutableSequence, c.MutableMapping]
1880 for haystack in permutations(bases):
1881 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001882 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001883 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001884 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001885 object])
1886
1887 # Container and Callable are registered on different base classes and
1888 # a generic function supporting both should always pick the Callable
1889 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001890 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001891 def __call__(self):
1892 pass
1893 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1894 for haystack in permutations(bases):
1895 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001896 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001897 c.Collection, c.Sized, c.Iterable,
1898 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001899
1900 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001901 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001902 d = {"a": "b"}
1903 l = [1, 2, 3]
1904 s = {object(), None}
1905 f = frozenset(s)
1906 t = (1, 2, 3)
1907 @functools.singledispatch
1908 def g(obj):
1909 return "base"
1910 self.assertEqual(g(d), "base")
1911 self.assertEqual(g(l), "base")
1912 self.assertEqual(g(s), "base")
1913 self.assertEqual(g(f), "base")
1914 self.assertEqual(g(t), "base")
1915 g.register(c.Sized, lambda obj: "sized")
1916 self.assertEqual(g(d), "sized")
1917 self.assertEqual(g(l), "sized")
1918 self.assertEqual(g(s), "sized")
1919 self.assertEqual(g(f), "sized")
1920 self.assertEqual(g(t), "sized")
1921 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1922 self.assertEqual(g(d), "mutablemapping")
1923 self.assertEqual(g(l), "sized")
1924 self.assertEqual(g(s), "sized")
1925 self.assertEqual(g(f), "sized")
1926 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001927 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001928 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1929 self.assertEqual(g(l), "sized")
1930 self.assertEqual(g(s), "sized")
1931 self.assertEqual(g(f), "sized")
1932 self.assertEqual(g(t), "sized")
1933 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1934 self.assertEqual(g(d), "mutablemapping")
1935 self.assertEqual(g(l), "mutablesequence")
1936 self.assertEqual(g(s), "sized")
1937 self.assertEqual(g(f), "sized")
1938 self.assertEqual(g(t), "sized")
1939 g.register(c.MutableSet, lambda obj: "mutableset")
1940 self.assertEqual(g(d), "mutablemapping")
1941 self.assertEqual(g(l), "mutablesequence")
1942 self.assertEqual(g(s), "mutableset")
1943 self.assertEqual(g(f), "sized")
1944 self.assertEqual(g(t), "sized")
1945 g.register(c.Mapping, lambda obj: "mapping")
1946 self.assertEqual(g(d), "mutablemapping") # not specific enough
1947 self.assertEqual(g(l), "mutablesequence")
1948 self.assertEqual(g(s), "mutableset")
1949 self.assertEqual(g(f), "sized")
1950 self.assertEqual(g(t), "sized")
1951 g.register(c.Sequence, lambda obj: "sequence")
1952 self.assertEqual(g(d), "mutablemapping")
1953 self.assertEqual(g(l), "mutablesequence")
1954 self.assertEqual(g(s), "mutableset")
1955 self.assertEqual(g(f), "sized")
1956 self.assertEqual(g(t), "sequence")
1957 g.register(c.Set, lambda obj: "set")
1958 self.assertEqual(g(d), "mutablemapping")
1959 self.assertEqual(g(l), "mutablesequence")
1960 self.assertEqual(g(s), "mutableset")
1961 self.assertEqual(g(f), "set")
1962 self.assertEqual(g(t), "sequence")
1963 g.register(dict, lambda obj: "dict")
1964 self.assertEqual(g(d), "dict")
1965 self.assertEqual(g(l), "mutablesequence")
1966 self.assertEqual(g(s), "mutableset")
1967 self.assertEqual(g(f), "set")
1968 self.assertEqual(g(t), "sequence")
1969 g.register(list, lambda obj: "list")
1970 self.assertEqual(g(d), "dict")
1971 self.assertEqual(g(l), "list")
1972 self.assertEqual(g(s), "mutableset")
1973 self.assertEqual(g(f), "set")
1974 self.assertEqual(g(t), "sequence")
1975 g.register(set, lambda obj: "concrete-set")
1976 self.assertEqual(g(d), "dict")
1977 self.assertEqual(g(l), "list")
1978 self.assertEqual(g(s), "concrete-set")
1979 self.assertEqual(g(f), "set")
1980 self.assertEqual(g(t), "sequence")
1981 g.register(frozenset, lambda obj: "frozen-set")
1982 self.assertEqual(g(d), "dict")
1983 self.assertEqual(g(l), "list")
1984 self.assertEqual(g(s), "concrete-set")
1985 self.assertEqual(g(f), "frozen-set")
1986 self.assertEqual(g(t), "sequence")
1987 g.register(tuple, lambda obj: "tuple")
1988 self.assertEqual(g(d), "dict")
1989 self.assertEqual(g(l), "list")
1990 self.assertEqual(g(s), "concrete-set")
1991 self.assertEqual(g(f), "frozen-set")
1992 self.assertEqual(g(t), "tuple")
1993
Łukasz Langa3720c772013-07-01 16:00:38 +02001994 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001995 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001996 mro = functools._c3_mro
1997 class A(object):
1998 pass
1999 class B(A):
2000 def __len__(self):
2001 return 0 # implies Sized
2002 @c.Container.register
2003 class C(object):
2004 pass
2005 class D(object):
2006 pass # unrelated
2007 class X(D, C, B):
2008 def __call__(self):
2009 pass # implies Callable
2010 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2011 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2012 self.assertEqual(mro(X, abcs=abcs), expected)
2013 # unrelated ABCs don't appear in the resulting MRO
2014 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2015 self.assertEqual(mro(X, abcs=many_abcs), expected)
2016
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002017 def test_false_meta(self):
2018 # see issue23572
2019 class MetaA(type):
2020 def __len__(self):
2021 return 0
2022 class A(metaclass=MetaA):
2023 pass
2024 class AA(A):
2025 pass
2026 @functools.singledispatch
2027 def fun(a):
2028 return 'base A'
2029 @fun.register(A)
2030 def _(a):
2031 return 'fun A'
2032 aa = AA()
2033 self.assertEqual(fun(aa), 'fun A')
2034
def test_mro_conflicts(self):
    """Check how singledispatch resolves, or refuses to resolve, ABC conflicts.

    The test registers classes defined here on collections.abc ABCs.  These
    registrations mutate the ABCs' shared registries, so the order of
    statements below matters.  Covered cases:
      * ABCs explicitly in a class's __mro__ beat ones added by register();
      * a more derived registered ABC (e.g. Set) beats its bases;
      * two unrelated candidate ABCs make dispatch raise RuntimeError
        ("Ambiguous dispatch"), with either ordering in the message.
    """
    c = collections.abc
    @functools.singledispatch
    def g(arg):
        return "base"
    # --- O: Sized is an explicit base -----------------------------------
    class O(c.Sized):
        def __len__(self):
            return 0
    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
    c.Set.register(O)
    self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Container
    # --- P: only virtual (registered) ABCs ------------------------------
    class P:
        pass
    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    # Iterable and Container are now both registered on P and neither is
    # more specific, so dispatch is ambiguous.
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )
    # --- Q: explicit Sized plus registered Iterable, then Set -----------
    class Q(c.Sized):
        def __len__(self):
            return 0
    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
    c.Set.register(Q)
    self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Iterable
    # --- h: implementations for Sized and Container only ----------------
    @functools.singledispatch
    def h(arg):
        return "base"
    @h.register(c.Sized)
    def _(arg):
        return "sized"
    @h.register(c.Container)
    def _(arg):
        return "container"
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(collections.defaultdict(lambda: 0))
    self.assertIn(
        str(re_two.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # --- R: MutableSequence registered directly on a defaultdict subclass
    class R(collections.defaultdict):
        pass
    c.MutableSequence.register(R)
    @functools.singledispatch
    def i(arg):
        return "base"
    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"
    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"
    r = R()
    self.assertEqual(i(r), "sequence")
    # --- T: explicit Sized base beats a registered Container ------------
    class S:
        pass
    class T(S, c.Sized):
        def __len__(self):
            return 0
    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
    # --- U: inferred Sized vs registered Container ----------------------
    class U:
        def __len__(self):
            return 0
    u = U()
    self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                      # from the existence of __len__()
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(
        str(re_three.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # --- V: a registered ABC is placed right after Sized in the MRO -----
    class V(c.Sized, S):
        def __len__(self):
            return 0
    @functools.singledispatch
    def j(arg):
        return "base"
    @j.register(S)
    def _(arg):
        return "s"
    @j.register(c.Container)
    def _(arg):
        return "container"
    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    self.assertEqual(j(v), "container")   # because it ends up right after
                                          # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002162
2163 def test_cache_invalidation(self):
2164 from collections import UserDict
INADA Naoki9811e802017-09-30 16:13:02 +09002165 import weakref
2166
Łukasz Langa6f692512013-06-05 12:20:24 +02002167 class TracingDict(UserDict):
2168 def __init__(self, *args, **kwargs):
2169 super(TracingDict, self).__init__(*args, **kwargs)
2170 self.set_ops = []
2171 self.get_ops = []
2172 def __getitem__(self, key):
2173 result = self.data[key]
2174 self.get_ops.append(key)
2175 return result
2176 def __setitem__(self, key, value):
2177 self.set_ops.append(key)
2178 self.data[key] = value
2179 def clear(self):
2180 self.data.clear()
INADA Naoki9811e802017-09-30 16:13:02 +09002181
Łukasz Langa6f692512013-06-05 12:20:24 +02002182 td = TracingDict()
INADA Naoki9811e802017-09-30 16:13:02 +09002183 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
2184 c = collections.abc
2185 @functools.singledispatch
2186 def g(arg):
2187 return "base"
2188 d = {}
2189 l = []
2190 self.assertEqual(len(td), 0)
2191 self.assertEqual(g(d), "base")
2192 self.assertEqual(len(td), 1)
2193 self.assertEqual(td.get_ops, [])
2194 self.assertEqual(td.set_ops, [dict])
2195 self.assertEqual(td.data[dict], g.registry[object])
2196 self.assertEqual(g(l), "base")
2197 self.assertEqual(len(td), 2)
2198 self.assertEqual(td.get_ops, [])
2199 self.assertEqual(td.set_ops, [dict, list])
2200 self.assertEqual(td.data[dict], g.registry[object])
2201 self.assertEqual(td.data[list], g.registry[object])
2202 self.assertEqual(td.data[dict], td.data[list])
2203 self.assertEqual(g(l), "base")
2204 self.assertEqual(g(d), "base")
2205 self.assertEqual(td.get_ops, [list, dict])
2206 self.assertEqual(td.set_ops, [dict, list])
2207 g.register(list, lambda arg: "list")
2208 self.assertEqual(td.get_ops, [list, dict])
2209 self.assertEqual(len(td), 0)
2210 self.assertEqual(g(d), "base")
2211 self.assertEqual(len(td), 1)
2212 self.assertEqual(td.get_ops, [list, dict])
2213 self.assertEqual(td.set_ops, [dict, list, dict])
2214 self.assertEqual(td.data[dict],
2215 functools._find_impl(dict, g.registry))
2216 self.assertEqual(g(l), "list")
2217 self.assertEqual(len(td), 2)
2218 self.assertEqual(td.get_ops, [list, dict])
2219 self.assertEqual(td.set_ops, [dict, list, dict, list])
2220 self.assertEqual(td.data[list],
2221 functools._find_impl(list, g.registry))
2222 class X:
2223 pass
2224 c.MutableMapping.register(X) # Will not invalidate the cache,
2225 # not using ABCs yet.
2226 self.assertEqual(g(d), "base")
2227 self.assertEqual(g(l), "list")
2228 self.assertEqual(td.get_ops, [list, dict, dict, list])
2229 self.assertEqual(td.set_ops, [dict, list, dict, list])
2230 g.register(c.Sized, lambda arg: "sized")
2231 self.assertEqual(len(td), 0)
2232 self.assertEqual(g(d), "sized")
2233 self.assertEqual(len(td), 1)
2234 self.assertEqual(td.get_ops, [list, dict, dict, list])
2235 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2236 self.assertEqual(g(l), "list")
2237 self.assertEqual(len(td), 2)
2238 self.assertEqual(td.get_ops, [list, dict, dict, list])
2239 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2240 self.assertEqual(g(l), "list")
2241 self.assertEqual(g(d), "sized")
2242 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2243 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2244 g.dispatch(list)
2245 g.dispatch(dict)
2246 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2247 list, dict])
2248 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2249 c.MutableSet.register(X) # Will invalidate the cache.
2250 self.assertEqual(len(td), 2) # Stale cache.
2251 self.assertEqual(g(l), "list")
2252 self.assertEqual(len(td), 1)
2253 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2254 self.assertEqual(len(td), 0)
2255 self.assertEqual(g(d), "mutablemapping")
2256 self.assertEqual(len(td), 1)
2257 self.assertEqual(g(l), "list")
2258 self.assertEqual(len(td), 2)
2259 g.register(dict, lambda arg: "dict")
2260 self.assertEqual(g(d), "dict")
2261 self.assertEqual(g(l), "list")
2262 g._clear_cache()
2263 self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002264
def test_annotations(self):
    """Register implementations through parameter annotations."""
    @functools.singledispatch
    def i(arg):
        return "base"

    @i.register
    def _(arg: collections.abc.Mapping):
        return "mapping"

    # String (forward-reference) annotations are accepted as well.
    @i.register
    def _(arg: "collections.abc.Sequence"):
        return "sequence"

    cases = [
        (None, "base"),
        ({"a": 1}, "mapping"),
        ([1, 2, 3], "sequence"),
        ((1, 2, 3), "sequence"),
        ("str", "sequence"),
    ]
    for value, expected in cases:
        self.assertEqual(i(value), expected)

    # A class registered as the implementation cannot rely on
    # annotations, so the dispatch type is passed to register() directly.
    @i.register(str)
    class _Echo:
        def __init__(self, arg):
            self.arg = arg

        def __eq__(self, other):
            return self.arg == other

    self.assertEqual(i("str"), "str")
2291
def test_method_register(self):
    """singledispatchmethod dispatches on the argument after self."""
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            self.arg = "base"

        @t.register(int)
        def _(self, arg):
            self.arg = "int"

        @t.register(str)
        def _(self, arg):
            self.arg = "str"

    a = A()

    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        a.t(value)
        self.assertEqual(a.arg, expected)
        # State set on one instance must not leak to a new one.
        fresh = A()
        self.assertFalse(hasattr(fresh, 'arg'))
2317
def test_staticmethod_register(self):
    """singledispatchmethod over staticmethod implementations.

    Dispatch is checked both through the class and through an instance.
    The staticmethod implementations take no implicit first argument.
    """
    class A:
        @functools.singledispatchmethod
        @staticmethod
        def t(arg):
            return arg

        @t.register(int)
        @staticmethod
        def _(arg):
            return isinstance(arg, int)

        @t.register(str)
        @staticmethod
        def _(arg):
            return isinstance(arg, str)

    a = A()

    # Access through the class.
    self.assertTrue(A.t(0))
    self.assertTrue(A.t(''))
    self.assertEqual(A.t(0.0), 0.0)

    # Access through an instance must dispatch the same way.
    self.assertTrue(a.t(0))
    self.assertTrue(a.t(''))
    self.assertEqual(a.t(0.0), 0.0)
2337
def test_classmethod_register(self):
    """singledispatchmethod over classmethod implementations."""
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

        @t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")

        @t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

    # Each implementation builds an A whose ``arg`` names the branch taken.
    for value, label in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(A.t(value).arg, label)
2359
def test_callable_register(self):
    """Implementations can be registered after the class body has run."""
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

    # Registration goes through the attribute fetched from the class.
    @A.t.register(int)
    @classmethod
    def _(cls, arg):
        return cls("int")

    @A.t.register(str)
    @classmethod
    def _(cls, arg):
        return cls("str")

    for value, label in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(A.t(value).arg, label)
2382
def test_abstractmethod_register(self):
    """An abstract singledispatchmethod makes its owning class abstract.

    ``Abstract`` must use ABCMeta as its *metaclass*. Subclassing ABCMeta
    would only define another metaclass, so the abstractness of ``add``
    would never be enforced or tested.
    """
    class Abstract(metaclass=abc.ABCMeta):

        @functools.singledispatchmethod
        @abc.abstractmethod
        def add(self, x, y):
            pass

    # The flag is visible through the bound accessor and on the raw
    # singledispatchmethod object stored in the class namespace.
    self.assertTrue(Abstract.add.__isabstractmethod__)
    self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)

    # ABCMeta must refuse to instantiate a class with an abstract method.
    with self.assertRaises(TypeError):
        Abstract()
2392
def test_type_ann_register(self):
    """singledispatchmethod.register can read the type from annotations."""
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            return "base"

        @t.register
        def _(self, arg: int):
            return "int"

        @t.register
        def _(self, arg: str):
            return "str"

    instance = A()

    # A list of pairs, not a dict: 0 and 0.0 would collide as keys.
    for value, expected in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(instance.t(value), expected)
2409
def test_invalid_registrations(self):
    """register() rejects non-classes, unannotated functions and generics."""
    prefix = "Invalid first argument to `register()`: "
    suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    @functools.singledispatch
    def i(arg):
        return "base"

    def assert_message(cm, head, tail):
        # Only the ends are compared; the middle may hold a repr.
        text = str(cm.exception)
        self.assertTrue(text.startswith(head))
        self.assertTrue(text.endswith(tail))

    with self.assertRaises(TypeError) as cm:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    assert_message(cm, prefix + "42", suffix)

    with self.assertRaises(TypeError) as cm:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    assert_message(
        cm,
        prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._",
        suffix,
    )

    with self.assertRaises(TypeError) as cm:
        @i.register
        def _(arg: typing.Iterable[str]):
            # Generic aliases from `typing` cannot be dispatched on at
            # runtime; annotate with regular classes or ABCs instead.
            return "I annotated with a generic collection"
    assert_message(
        cm,
        "Invalid annotation for 'arg'.",
        'typing.Iterable[str] is not a class.',
    )
Łukasz Langae5697532017-12-11 13:56:31 -08002448
def test_invalid_positional_argument(self):
    """Calling a singledispatch function with no arguments is a TypeError."""
    @functools.singledispatch
    def f(*args):
        pass

    with self.assertRaisesRegex(
            TypeError, 'f requires at least 1 positional argument'):
        f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002456
Carl Meyerd658dea2018-08-28 01:11:56 -06002457
class CachedCostItem:
    """Item whose ``cost`` is computed on first access, then cached."""

    _cost = 1

    def __init__(self):
        # Use the public threading.RLock rather than the RLock name that
        # functools happens to import for its own internal use.
        self.lock = threading.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost += 1
        return self._cost
2470
2471
class OptionallyCachedCostItem:
    """Exposes one getter both uncached and as a differently named
    cached_property."""

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        # Every direct call bumps the stored counter.
        new_cost = self._cost + 1
        self._cost = new_cost
        return new_cost

    # Cached under a name different from the wrapped function's name.
    cached_cost = py_functools.cached_property(get_cost)
2481
2482
class CachedCostItemWait:
    """Cached-cost item whose computation first waits on *event*.

    Used by the threaded cached_property test to hold the first ``cost``
    computation until all threads have been started.
    """

    def __init__(self, event):
        self._cost = 1
        # Public threading.RLock instead of functools' internal import,
        # matching CachedCostItem.
        self.lock = threading.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Block for up to one second so concurrent callers can overlap.
        self.event.wait(1)
        with self.lock:
            self._cost += 1
        return self._cost
2496
2497
class CachedCostItemWithSlots:
    """Class without an instance ``__dict__``, so cached_property cannot
    store its result."""

    # A one-element tuple: ``('_cost')`` without the comma is just a str.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2507
2508
class TestCachedProperty(unittest.TestCase):
    """Tests for cached_property taken from the py_functools module."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2)  # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        # get_cost() recomputes on every call; cached_cost keeps the value
        # produced by its first access.
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # A tiny switch interval makes the threads interleave as much as
        # possible while they race on the first access to item.cost.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # The value must have been computed exactly once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
            TypeError,
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # The property lives on the metaclass, so the cache would have to
        # be written into MyClass.__dict__, which rejects item assignment.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The error from __set_name__ is chained as the context of the
        # RuntimeError; check its type as well as its text.
        context = ctx.exception.__context__
        self.assertIsInstance(context, TypeError)
        self.assertEqual(
            str(context),
            "Cannot assign the same cached_property to two different names ('a' and 'b')."
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        # Each instance caches its own value independently.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Assigning after class creation skips __set_name__.
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        Foo.cp = cp

        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2621
2622
if __name__ == '__main__':
    # Run every TestCase in this module when it is executed as a script.
    unittest.main()