blob: caeeb2712a1a48982d410f4bee738485506eaccc [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shi3ddc6342020-06-30 21:46:06 +080022from test.support import import_helper
Hai Shie80697d2020-05-28 06:10:27 +080023from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000024from test.support.script_helper import assert_python_ok
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026import functools
27
Hai Shi3ddc6342020-06-30 21:46:06 +080028py_functools = import_helper.import_fresh_module('functools',
29 blocked=['_functools'])
Hai Shidd391232020-12-29 20:45:07 +080030c_functools = import_helper.import_fresh_module('functools')
Antoine Pitroub5b37142012-11-13 21:35:40 +010031
Hai Shi3ddc6342020-06-30 21:46:06 +080032decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
Łukasz Langa6f692512013-06-05 12:20:24 +020033
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module registered under *name* before entry is reinstated when the
    ``with`` block exits, whether it exits normally or by raising.  *name*
    must already be a key of ``sys.modules``; otherwise KeyError is raised
    before anything is changed.
    """
    saved_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = saved_module
Łukasz Langa6f692512013-06-05 12:20:24 +020042
def capture(*args, **kw):
    """Echo the call back as an ``(args, kwargs)`` pair.

    Serves as the target callable in the partial tests, so the exact
    positional and keyword arguments that reach it can be inspected.
    """
    received = (args, kw)
    return received
46
Łukasz Langa6f692512013-06-05 12:20:24 +020047
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``.

    Two partials with equal signatures carry equivalent state, which is
    what the copy/pickle/nesting tests compare.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000051
class MyTuple(tuple):
    """Behaviour-free tuple subclass.

    Used to check that partial state built from a tuple subclass is stored
    as a plain tuple.
    """
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by producing a list.

    Lets the tests verify that partial does not rely on a subclass's
    ``__add__`` when combining stored and call-time arguments.
    """

    def __add__(self, other):
        left = list(self)
        right = list(other)
        return left + right
58
class MyDict(dict):
    """Behaviour-free dict subclass.

    Used to check that partial keyword state built from a dict subclass is
    stored as a plain dict.
    """
61
Łukasz Langa6f692512013-06-05 12:20:24 +020062
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin.  Concrete subclasses also derive from
    ``unittest.TestCase`` and define two class attributes:

    * ``partial`` -- the partial implementation (or subclass) under test;
    * ``AllowPickle`` -- a context manager entered around tests that pickle
      partial objects.
    """

    def _expected_repr_name(self):
        """Return the type name that ``repr()`` of ``self.partial`` should show.

        The stock implementations report ``functools.partial``; subclasses
        report their own ``__name__``.  ``c_functools`` may be None when the
        C accelerator is unavailable, so it is consulted only if present.
        """
        stock_types = [py_functools.partial]
        if c_functools:
            stock_types.append(c_functools.partial)
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected (at construction
        # or at call time).
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last reference frees the object immediately only on
        # refcounting implementations; force a collection for the others.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes f reference itself via __setstate__, then restores
        # a non-recursive state in ``finally`` to break the cycle.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check was disabled in the original;
        # confirm whether it now holds for all implementations before
        # re-enabling it.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in args survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in keywords survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000386
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C implementation, plus C-only checks."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: the pickle tests need no module patching here."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        read_only = (('func', map),
                     ('args', (1, 2)),
                     ('keywords', dict(a=1, b=2)))
        for attr, new_value in read_only:
            self.assertRaises(AttributeError, setattr, p, attr, new_value)

        # The instance __dict__ must not be deletable either.
        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            return
        self.fail('partial object allowed __dict__ to be deleted')

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # p.keywords can be mutated after construction, so a non-string key
        # can be injected; repr must cope with it...
        p.keywords[1234] = 'value'
        text = repr(p)
        for fragment in ('1234', "'value'"):
            self.assertIn(fragment, text)
        # ...while actually calling the partial must fail.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            # Formatting this key as a string rebinds its own entry in
            # p.keywords, i.e. mid-repr.
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        text = repr(p)
        self.assertIn('astr', text)
        self.assertIn("['sth']", text)
437
438
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` be ``py_functools`` while active.

        Presumably needed because pickle resolves the Python ``partial`` class
        through the ``functools`` module name; confirm if this changes.
        """

        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_value, tb):
            return self._cm.__exit__(exc_type, exc_value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type; only defined when the C
        implementation could be imported."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100453
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial class."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200456
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a trivial subclass of the C type."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the inherited
    # flattening test does not apply; binding the name to None (a
    # non-callable) keeps unittest from collecting it.
    test_nested_optimization = None
464
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200467
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be usable as ordinary keyword arguments.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be a class *using* ABCMeta, not a subclass of the
        # metaclass itself, so that ABCMeta actually collects its abstract
        # methods.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # The partialmethod wrapping an abstract method is itself abstract.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
596
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000597
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000598class TestUpdateWrapper(unittest.TestCase):
599
600 def check_wrapper(self, wrapper, wrapped,
601 assigned=functools.WRAPPER_ASSIGNMENTS,
602 updated=functools.WRAPPER_UPDATES):
603 # Check attributes were assigned
604 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000605 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000606 # Check attributes were updated
607 for name in updated:
608 wrapper_attr = getattr(wrapper, name)
609 wrapped_attr = getattr(wrapped, name)
610 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000611 if name == "__dict__" and key == "__wrapped__":
612 # __wrapped__ is overwritten by the update code
613 continue
614 self.assertIs(wrapped_attr[key], wrapper_attr[key])
615 # Check __wrapped__
616 self.assertIs(wrapper.__wrapped__, wrapped)
617
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618
R. David Murray378c0cf2010-02-24 01:46:21 +0000619 def _default_update(self):
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300620 def f(a: int):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000621 """This is a test"""
622 pass
623 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000624 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000625 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000626 pass
627 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000628 return wrapper, f
629
630 def test_default_update(self):
631 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000633 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000634 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600635 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636 self.assertEqual(wrapper.attr, 'This is also a test')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300637 self.assertEqual(wrapper.__annotations__['a'], 'int')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000638 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000639
R. David Murray378c0cf2010-02-24 01:46:21 +0000640 @unittest.skipIf(sys.flags.optimize >= 2,
641 "Docstrings are omitted with -O2 and above")
642 def test_default_update_doc(self):
643 wrapper, f = self._default_update()
644 self.assertEqual(wrapper.__doc__, 'This is a test')
645
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000646 def test_no_update(self):
647 def f():
648 """This is a test"""
649 pass
650 f.attr = 'This is also a test'
651 def wrapper():
652 pass
653 functools.update_wrapper(wrapper, f, (), ())
654 self.check_wrapper(wrapper, f, (), ())
655 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600656 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000658 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000659 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000660
661 def test_selective_update(self):
662 def f():
663 pass
664 f.attr = 'This is a different test'
665 f.dict_attr = dict(a=1, b=2, c=3)
666 def wrapper():
667 pass
668 wrapper.dict_attr = {}
669 assign = ('attr',)
670 update = ('dict_attr',)
671 functools.update_wrapper(wrapper, f, assign, update)
672 self.check_wrapper(wrapper, f, assign, update)
673 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600674 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000675 self.assertEqual(wrapper.__doc__, None)
676 self.assertEqual(wrapper.attr, 'This is a different test')
677 self.assertEqual(wrapper.dict_attr, f.dict_attr)
678
Nick Coghlan98876832010-08-17 06:17:18 +0000679 def test_missing_attributes(self):
680 def f():
681 pass
682 def wrapper():
683 pass
684 wrapper.dict_attr = {}
685 assign = ('attr',)
686 update = ('dict_attr',)
687 # Missing attributes on wrapped object are ignored
688 functools.update_wrapper(wrapper, f, assign, update)
689 self.assertNotIn('attr', wrapper.__dict__)
690 self.assertEqual(wrapper.dict_attr, {})
691 # Wrapper must have expected attributes for updating
692 del wrapper.dict_attr
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695 wrapper.dict_attr = 1
696 with self.assertRaises(AttributeError):
697 functools.update_wrapper(wrapper, f, assign, update)
698
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200699 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000700 @unittest.skipIf(sys.flags.optimize >= 2,
701 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000702 def test_builtin_update(self):
703 # Test for bug #1576241
704 def wrapper():
705 pass
706 functools.update_wrapper(wrapper, max)
707 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000708 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000709 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000710
Łukasz Langa6f692512013-06-05 12:20:24 +0200711
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps decorator.

    Tests not overridden here are inherited from TestUpdateWrapper.
    """

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # NOTE(review): presumably the inherited check_wrapper() verifies
        # that wrapper.__wrapped__ is f itself rather than this pre-existing
        # value -- confirm against TestUpdateWrapper.check_wrapper.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        expected = {
            '__name__': 'f',
            '__qualname__': wrapped.__qualname__,
            'attr': 'This is also a test',
        }
        for attr_name, value in expected.items():
            self.assertEqual(getattr(wrapper, attr_name), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        no_attrs = ()

        # With nothing to assign or update, the wrapper keeps its own
        # identity and gains none of f's attributes.
        @functools.wraps(f, no_attrs, no_attrs)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, no_attrs, no_attrs)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # The wrapper must already own a dict_attr for it to be updated.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
772
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773
madman-bobe25d5fc2018-10-25 15:02:10 +0100774class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000775 def test_reduce(self):
776 class Squares:
777 def __init__(self, max):
778 self.max = max
779 self.sofar = []
780
781 def __len__(self):
782 return len(self.sofar)
783
784 def __getitem__(self, i):
785 if not 0 <= i < self.max: raise IndexError
786 n = len(self.sofar)
787 while n <= i:
788 self.sofar.append(n*n)
789 n += 1
790 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000791 def add(x, y):
792 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 ['a','c','d','w']
797 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000801 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000802 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100803 self.assertEqual(self.reduce(add, Squares(10)), 285)
804 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
805 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
806 self.assertRaises(TypeError, self.reduce)
807 self.assertRaises(TypeError, self.reduce, 42, 42)
808 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
809 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
810 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
811 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
812 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
813 self.assertRaises(TypeError, self.reduce, add, "")
814 self.assertRaises(TypeError, self.reduce, add, ())
815 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000816
817 class TestFailingIter:
818 def __iter__(self):
819 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000821
madman-bobe25d5fc2018-10-25 15:02:10 +0100822 self.assertEqual(self.reduce(add, [], None), None)
823 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000824
825 class BadSeq:
826 def __getitem__(self, index):
827 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100828 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000829
830 # Test reduce()'s use of iterators.
831 def test_iterator_usage(self):
832 class SequenceClass:
833 def __init__(self, n):
834 self.n = n
835 def __getitem__(self, i):
836 if 0 <= i < self.n:
837 return i
838 else:
839 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000840
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000841 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100842 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
843 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
844 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
845 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
846 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
847 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000848
849 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100850 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
851
852
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator module."""
    if c_functools:
        # staticmethod() mirrors the pure-Python variant; for a builtin it
        # does not change what ``self.reduce`` returns.
        reduce = staticmethod(c_functools.reduce)
857
858
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python implementation."""
    # staticmethod() stops the plain Python function from being bound as a
    # method, so ``self.reduce`` is py_functools.reduce itself.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000861
Łukasz Langa6f692512013-06-05 12:20:24 +0200862
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200863class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000865 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 def cmp1(x, y):
867 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 self.assertEqual(key(3), key(3))
870 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100871 self.assertGreaterEqual(key(3), key(3))
872
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 def cmp2(x, y):
874 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 self.assertEqual(key(4.0), key('4'))
877 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100878 self.assertLessEqual(key(2), key('35'))
879 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700880
881 def test_cmp_to_key_arguments(self):
882 def cmp1(x, y):
883 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 self.assertEqual(key(obj=3), key(obj=3))
886 self.assertGreater(key(obj=3), key(obj=1))
887 with self.assertRaises((TypeError, AttributeError)):
888 key(3) > 1 # rhs is not a K object
889 with self.assertRaises((TypeError, AttributeError)):
890 1 < key(3) # lhs is not a K object
891 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200894 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 with self.assertRaises(TypeError):
897 key() # too few args
898 with self.assertRaises(TypeError):
899 key(None, None) # too many args
900
901 def test_bad_cmp(self):
902 def cmp1(x, y):
903 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100904 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700905 with self.assertRaises(ZeroDivisionError):
906 key(3) > key(1)
907
908 class BadCmp:
909 def __lt__(self, other):
910 raise ZeroDivisionError
911 def cmp1(x, y):
912 return BadCmp()
913 with self.assertRaises(ZeroDivisionError):
914 key(3) > key(1)
915
916 def test_obj_field(self):
917 def cmp1(x, y):
918 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700920 self.assertEqual(key(50).obj, 50)
921
922 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 def mycmp(x, y):
924 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100925 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000927
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700928 def test_sort_int_str(self):
929 def mycmp(x, y):
930 x, y = int(x), int(y)
931 return (x > y) - (x < y)
932 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100933 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700934 self.assertEqual([int(value) for value in values],
935 [0, 1, 1, 2, 3, 4, 5, 7, 10])
936
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000937 def test_hash(self):
938 def mycmp(x, y):
939 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100940 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700942 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300943 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator module."""
    if c_functools:
        # staticmethod() mirrors the pure-Python variant; for a builtin it
        # does not change what ``self.cmp_to_key`` returns.
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100950
Łukasz Langa6f692512013-06-05 12:20:24 +0200951
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""
    # staticmethod() stops the plain Python function from being bound as a
    # method, so ``self.cmp_to_key`` is py_functools.cmp_to_key itself.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
954
Łukasz Langa6f692512013-06-05 12:20:24 +0200955
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000956class TestTotalOrdering(unittest.TestCase):
957
958 def test_total_ordering_lt(self):
959 @functools.total_ordering
960 class A:
961 def __init__(self, value):
962 self.value = value
963 def __lt__(self, other):
964 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000965 def __eq__(self, other):
966 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000967 self.assertTrue(A(1) < A(2))
968 self.assertTrue(A(2) > A(1))
969 self.assertTrue(A(1) <= A(2))
970 self.assertTrue(A(2) >= A(1))
971 self.assertTrue(A(2) <= A(2))
972 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000973 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000974
975 def test_total_ordering_le(self):
976 @functools.total_ordering
977 class A:
978 def __init__(self, value):
979 self.value = value
980 def __le__(self, other):
981 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000982 def __eq__(self, other):
983 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000984 self.assertTrue(A(1) < A(2))
985 self.assertTrue(A(2) > A(1))
986 self.assertTrue(A(1) <= A(2))
987 self.assertTrue(A(2) >= A(1))
988 self.assertTrue(A(2) <= A(2))
989 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000990 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000991
992 def test_total_ordering_gt(self):
993 @functools.total_ordering
994 class A:
995 def __init__(self, value):
996 self.value = value
997 def __gt__(self, other):
998 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000999 def __eq__(self, other):
1000 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001001 self.assertTrue(A(1) < A(2))
1002 self.assertTrue(A(2) > A(1))
1003 self.assertTrue(A(1) <= A(2))
1004 self.assertTrue(A(2) >= A(1))
1005 self.assertTrue(A(2) <= A(2))
1006 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001007 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001008
1009 def test_total_ordering_ge(self):
1010 @functools.total_ordering
1011 class A:
1012 def __init__(self, value):
1013 self.value = value
1014 def __ge__(self, other):
1015 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001016 def __eq__(self, other):
1017 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001018 self.assertTrue(A(1) < A(2))
1019 self.assertTrue(A(2) > A(1))
1020 self.assertTrue(A(1) <= A(2))
1021 self.assertTrue(A(2) >= A(1))
1022 self.assertTrue(A(2) <= A(2))
1023 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001024 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001025
1026 def test_total_ordering_no_overwrite(self):
1027 # new methods should not overwrite existing
1028 @functools.total_ordering
1029 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001030 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001031 self.assertTrue(A(1) < A(2))
1032 self.assertTrue(A(2) > A(1))
1033 self.assertTrue(A(1) <= A(2))
1034 self.assertTrue(A(2) >= A(1))
1035 self.assertTrue(A(2) <= A(2))
1036 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001037
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001038 def test_no_operations_defined(self):
1039 with self.assertRaises(ValueError):
1040 @functools.total_ordering
1041 class A:
1042 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001043
Nick Coghlanf05d9812013-10-02 00:02:03 +10001044 def test_type_error_when_not_implemented(self):
1045 # bug 10042; ensure stack overflow does not occur
1046 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001047 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001048 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001049 def __init__(self, value):
1050 self.value = value
1051 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001052 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001053 return self.value == other.value
1054 return False
1055 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001056 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001057 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001058 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001059
Nick Coghlanf05d9812013-10-02 00:02:03 +10001060 @functools.total_ordering
1061 class ImplementsGreaterThan:
1062 def __init__(self, value):
1063 self.value = value
1064 def __eq__(self, other):
1065 if isinstance(other, ImplementsGreaterThan):
1066 return self.value == other.value
1067 return False
1068 def __gt__(self, other):
1069 if isinstance(other, ImplementsGreaterThan):
1070 return self.value > other.value
1071 return NotImplemented
1072
1073 @functools.total_ordering
1074 class ImplementsLessThanEqualTo:
1075 def __init__(self, value):
1076 self.value = value
1077 def __eq__(self, other):
1078 if isinstance(other, ImplementsLessThanEqualTo):
1079 return self.value == other.value
1080 return False
1081 def __le__(self, other):
1082 if isinstance(other, ImplementsLessThanEqualTo):
1083 return self.value <= other.value
1084 return NotImplemented
1085
1086 @functools.total_ordering
1087 class ImplementsGreaterThanEqualTo:
1088 def __init__(self, value):
1089 self.value = value
1090 def __eq__(self, other):
1091 if isinstance(other, ImplementsGreaterThanEqualTo):
1092 return self.value == other.value
1093 return False
1094 def __ge__(self, other):
1095 if isinstance(other, ImplementsGreaterThanEqualTo):
1096 return self.value >= other.value
1097 return NotImplemented
1098
1099 @functools.total_ordering
1100 class ComparatorNotImplemented:
1101 def __init__(self, value):
1102 self.value = value
1103 def __eq__(self, other):
1104 if isinstance(other, ComparatorNotImplemented):
1105 return self.value == other.value
1106 return False
1107 def __lt__(self, other):
1108 return NotImplemented
1109
1110 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1111 ImplementsLessThan(-1) < 1
1112
1113 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1114 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1115
1116 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1117 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1118
1119 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1120 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1121
1122 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1123 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1124
1125 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1126 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1127
1128 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1129 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1130
1131 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1132 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1133
1134 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1135 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1136
1137 with self.subTest("GE when equal"):
1138 a = ComparatorNotImplemented(8)
1139 b = ComparatorNotImplemented(8)
1140 self.assertEqual(a, b)
1141 with self.assertRaises(TypeError):
1142 a >= b
1143
1144 with self.subTest("LE when equal"):
1145 a = ComparatorNotImplemented(9)
1146 b = ComparatorNotImplemented(9)
1147 self.assertEqual(a, b)
1148 with self.assertRaises(TypeError):
1149 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001150
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001151 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001152 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001153 for name in '__lt__', '__gt__', '__le__', '__ge__':
1154 with self.subTest(method=name, proto=proto):
1155 method = getattr(Orderable_LT, name)
1156 method_copy = pickle.loads(pickle.dumps(method, proto))
1157 self.assertIs(method_copy, method)
1158
@functools.total_ordering
class Orderable_LT:
    """Minimal total_ordering class defining only __eq__ and __lt__.

    Presumably kept at module level so that TestTotalOrdering.test_pickle
    can pickle its generated comparison methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1167
1168
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001169class TestCache:
1170 # This tests that the pass-through is working as designed.
1171 # The underlying functionality is tested in TestLRU.
1172
1173 def test_cache(self):
1174 @self.module.cache
1175 def fib(n):
1176 if n < 2:
1177 return n
1178 return fib(n-1) + fib(n-2)
1179 self.assertEqual([fib(n) for n in range(16)],
1180 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1181 self.assertEqual(fib.cache_info(),
1182 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1183 fib.cache_clear()
1184 self.assertEqual(fib.cache_info(),
1185 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1186
1187
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001188class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001189
1190 def test_lru(self):
1191 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001192 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001193 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001194 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001195 self.assertEqual(maxsize, 20)
1196 self.assertEqual(currsize, 0)
1197 self.assertEqual(hits, 0)
1198 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001199
1200 domain = range(5)
1201 for i in range(1000):
1202 x, y = choice(domain), choice(domain)
1203 actual = f(x, y)
1204 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001205 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001206 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001207 self.assertTrue(hits > misses)
1208 self.assertEqual(hits + misses, 1000)
1209 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001210
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001211 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001212 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001213 self.assertEqual(hits, 0)
1214 self.assertEqual(misses, 0)
1215 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001216 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001217 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001218 self.assertEqual(hits, 0)
1219 self.assertEqual(misses, 1)
1220 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001221
Nick Coghlan98876832010-08-17 06:17:18 +00001222 # Test bypassing the cache
1223 self.assertIs(f.__wrapped__, orig)
1224 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001225 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001226 self.assertEqual(hits, 0)
1227 self.assertEqual(misses, 1)
1228 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001229
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001230 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001231 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001232 def f():
1233 nonlocal f_cnt
1234 f_cnt += 1
1235 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001236 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001237 f_cnt = 0
1238 for i in range(5):
1239 self.assertEqual(f(), 20)
1240 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001241 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001242 self.assertEqual(hits, 0)
1243 self.assertEqual(misses, 5)
1244 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001245
1246 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001247 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001248 def f():
1249 nonlocal f_cnt
1250 f_cnt += 1
1251 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001252 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001253 f_cnt = 0
1254 for i in range(5):
1255 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001256 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001257 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001258 self.assertEqual(hits, 4)
1259 self.assertEqual(misses, 1)
1260 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001261
Raymond Hettingerf3098282010-08-15 03:30:45 +00001262 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001263 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001264 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001265 nonlocal f_cnt
1266 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001267 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001268 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001269 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001270 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1271 # * * * *
1272 self.assertEqual(f(x), x*10)
1273 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001274 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001275 self.assertEqual(hits, 12)
1276 self.assertEqual(misses, 4)
1277 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001278
Raymond Hettingerb8218682019-05-26 11:27:35 -07001279 def test_lru_no_args(self):
1280 @self.module.lru_cache
1281 def square(x):
1282 return x ** 2
1283
1284 self.assertEqual(list(map(square, [10, 20, 10])),
1285 [100, 400, 100])
1286 self.assertEqual(square.cache_info().hits, 1)
1287 self.assertEqual(square.cache_info().misses, 2)
1288 self.assertEqual(square.cache_info().maxsize, 128)
1289 self.assertEqual(square.cache_info().currsize, 2)
1290
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001291 def test_lru_bug_35780(self):
1292 # C version of the lru_cache was not checking to see if
1293 # the user function call has already modified the cache
1294 # (this arises in recursive calls and in multi-threading).
1295 # This cause the cache to have orphan links not referenced
1296 # by the cache dictionary.
1297
1298 once = True # Modified by f(x) below
1299
1300 @self.module.lru_cache(maxsize=10)
1301 def f(x):
1302 nonlocal once
1303 rv = f'.{x}.'
1304 if x == 20 and once:
1305 once = False
1306 rv = f(x)
1307 return rv
1308
1309 # Fill the cache
1310 for x in range(15):
1311 self.assertEqual(f(x), f'.{x}.')
1312 self.assertEqual(f.cache_info().currsize, 10)
1313
1314 # Make a recursive call and make sure the cache remains full
1315 self.assertEqual(f(20), '.20.')
1316 self.assertEqual(f.cache_info().currsize, 10)
1317
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001318 def test_lru_bug_36650(self):
1319 # C version of lru_cache was treating a call with an empty **kwargs
1320 # dictionary as being distinct from a call with no keywords at all.
1321 # This did not result in an incorrect answer, but it did trigger
1322 # an unexpected cache miss.
1323
1324 @self.module.lru_cache()
1325 def f(x):
1326 pass
1327
1328 f(0)
1329 f(0, **{})
1330 self.assertEqual(f.cache_info().hits, 1)
1331
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001332 def test_lru_hash_only_once(self):
1333 # To protect against weird reentrancy bugs and to improve
1334 # efficiency when faced with slow __hash__ methods, the
1335 # LRU cache guarantees that it will only call __hash__
1336 # only once per use as an argument to the cached function.
1337
1338 @self.module.lru_cache(maxsize=1)
1339 def f(x, y):
1340 return x * 3 + y
1341
1342 # Simulate the integer 5
1343 mock_int = unittest.mock.Mock()
1344 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1345 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1346
1347 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001348 self.assertEqual(f(mock_int, 1), 16)
1349 self.assertEqual(mock_int.__hash__.call_count, 1)
1350 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001351
1352 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001353 self.assertEqual(f(mock_int, 1), 16)
1354 self.assertEqual(mock_int.__hash__.call_count, 2)
1355 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001356
Ville Skyttä49b27342017-08-03 09:00:59 +03001357 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001358 self.assertEqual(f(6, 2), 20)
1359 self.assertEqual(mock_int.__hash__.call_count, 2)
1360 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001361
1362 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001363 self.assertEqual(f(mock_int, 1), 16)
1364 self.assertEqual(mock_int.__hash__.call_count, 3)
1365 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001366
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001367 def test_lru_reentrancy_with_len(self):
1368 # Test to make sure the LRU cache code isn't thrown-off by
1369 # caching the built-in len() function. Since len() can be
1370 # cached, we shouldn't use it inside the lru code itself.
1371 old_len = builtins.len
1372 try:
1373 builtins.len = self.module.lru_cache(4)(len)
1374 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1375 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1376 finally:
1377 builtins.len = old_len
1378
Raymond Hettinger605a4472017-01-09 07:50:19 -08001379 def test_lru_star_arg_handling(self):
1380 # Test regression that arose in ea064ff3c10f
1381 @functools.lru_cache()
1382 def f(*args):
1383 return args
1384
1385 self.assertEqual(f(1, 2), (1, 2))
1386 self.assertEqual(f((1, 2)), ((1, 2),))
1387
Yury Selivanov46a02db2016-11-09 18:55:45 -05001388 def test_lru_type_error(self):
1389 # Regression test for issue #28653.
1390 # lru_cache was leaking when one of the arguments
1391 # wasn't cacheable.
1392
1393 @functools.lru_cache(maxsize=None)
1394 def infinite_cache(o):
1395 pass
1396
1397 @functools.lru_cache(maxsize=10)
1398 def limited_cache(o):
1399 pass
1400
1401 with self.assertRaises(TypeError):
1402 infinite_cache([])
1403
1404 with self.assertRaises(TypeError):
1405 limited_cache([])
1406
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001407 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001408 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001409 def fib(n):
1410 if n < 2:
1411 return n
1412 return fib(n-1) + fib(n-2)
1413 self.assertEqual([fib(n) for n in range(16)],
1414 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1415 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001416 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001417 fib.cache_clear()
1418 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001419 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1420
1421 def test_lru_with_maxsize_negative(self):
1422 @self.module.lru_cache(maxsize=-10)
1423 def eq(n):
1424 return n
1425 for i in (0, 1):
1426 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1427 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001428 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001429
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001430 def test_lru_with_exceptions(self):
1431 # Verify that user_function exceptions get passed through without
1432 # creating a hard-to-read chained exception.
1433 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001434 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001435 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001436 def func(i):
1437 return 'abc'[i]
1438 self.assertEqual(func(0), 'a')
1439 with self.assertRaises(IndexError) as cm:
1440 func(15)
1441 self.assertIsNone(cm.exception.__context__)
1442 # Verify that the previous exception did not result in a cached entry
1443 with self.assertRaises(IndexError):
1444 func(15)
1445
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001446 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001447 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001448 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001449 def square(x):
1450 return x * x
1451 self.assertEqual(square(3), 9)
1452 self.assertEqual(type(square(3)), type(9))
1453 self.assertEqual(square(3.0), 9.0)
1454 self.assertEqual(type(square(3.0)), type(9.0))
1455 self.assertEqual(square(x=3), 9)
1456 self.assertEqual(type(square(x=3)), type(9))
1457 self.assertEqual(square(x=3.0), 9.0)
1458 self.assertEqual(type(square(x=3.0)), type(9.0))
1459 self.assertEqual(square.cache_info().hits, 4)
1460 self.assertEqual(square.cache_info().misses, 4)
1461
Antoine Pitroub5b37142012-11-13 21:35:40 +01001462 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001463 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001464 def fib(n):
1465 if n < 2:
1466 return n
1467 return fib(n=n-1) + fib(n=n-2)
1468 self.assertEqual(
1469 [fib(n=number) for number in range(16)],
1470 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1471 )
1472 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001473 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001474 fib.cache_clear()
1475 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001476 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001477
1478 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001479 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001480 def fib(n):
1481 if n < 2:
1482 return n
1483 return fib(n=n-1) + fib(n=n-2)
1484 self.assertEqual([fib(n=number) for number in range(16)],
1485 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1486 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001487 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001488 fib.cache_clear()
1489 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001490 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1491
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001492 def test_kwargs_order(self):
1493 # PEP 468: Preserving Keyword Argument Order
1494 @self.module.lru_cache(maxsize=10)
1495 def f(**kwargs):
1496 return list(kwargs.items())
1497 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1498 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1499 self.assertEqual(f.cache_info(),
1500 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1501
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001502 def test_lru_cache_decoration(self):
1503 def f(zomg: 'zomg_annotation'):
1504 """f doc string"""
1505 return 42
1506 g = self.module.lru_cache()(f)
1507 for attr in self.module.WRAPPER_ASSIGNMENTS:
1508 self.assertEqual(getattr(g, attr), getattr(f, attr))
1509
def test_lru_cache_threaded(self):
    # Phase 1: n threads each call f(k, 0) m times for their own key k.
    # Phase 2: the same, racing against a thread that keeps clearing the
    # cache.  A tiny switch interval maximizes thread interleaving.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(f.cache_info().currsize, 0)

    start = threading.Event()

    def full(k):
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    saved_interval = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # Phase 1: n filling threads, one key each.
        fillers = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with threading_helper.start_threads(fillers):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: it is not understood why the pure-Python version can
            # report fewer misses/hits than expected here.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: one clearing thread racing n filling threads.
        racers = [threading.Thread(target=clear)]
        racers.extend(threading.Thread(target=full, args=[k])
                      for k in range(n))
        start.clear()
        with threading_helper.start_threads(racers):
            start.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001557
def test_lru_cache_threaded2(self):
    # Simultaneous calls with the same arguments: in each round, n threads
    # call f(i) while f is blocked on the `pause` barrier, and every one of
    # those calls is counted as a miss.
    n, m = 5, 7
    start, pause, stop = (threading.Barrier(n+1) for _ in range(3))

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1584
def test_lru_cache_threaded3(self):
    # Overlapping slow calls against a tiny (maxsize=2) cache, with some
    # arguments repeated, must still return correct results.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(idx, value):
        with self.subTest(thread=idx):
            self.assertEqual(f(value), 3 * value, idx)

    arguments = [1, 2, 2, 3, 2]
    threads = [threading.Thread(target=check, args=pair)
               for pair in enumerate(arguments)]
    with threading_helper.start_threads(threads):
        pass
1597
Raymond Hettinger03923422013-03-04 02:52:50 -05001598 def test_need_for_rlock(self):
1599 # This will deadlock on an LRU cache that uses a regular lock
1600
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001601 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001602 def test_func(x):
1603 'Used to demonstrate a reentrant lru_cache call within a single thread'
1604 return x
1605
1606 class DoubleEq:
1607 'Demonstrate a reentrant lru_cache call within a single thread'
1608 def __init__(self, x):
1609 self.x = x
1610 def __hash__(self):
1611 return self.x
1612 def __eq__(self, other):
1613 if self.x == 2:
1614 test_func(DoubleEq(1))
1615 return self.x == other.x
1616
1617 test_func(DoubleEq(1)) # Load the cache
1618 test_func(DoubleEq(2)) # Load the cache
1619 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1620 DoubleEq(2)) # Verify the correct return value
1621
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001622 def test_lru_method(self):
1623 class X(int):
1624 f_cnt = 0
1625 @self.module.lru_cache(2)
1626 def f(self, x):
1627 self.f_cnt += 1
1628 return x*10+self
1629 a = X(5)
1630 b = X(5)
1631 c = X(7)
1632 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1633
1634 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1635 self.assertEqual(a.f(x), x*10 + 5)
1636 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1637 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1638
1639 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1640 self.assertEqual(b.f(x), x*10 + 5)
1641 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1642 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1643
1644 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1645 self.assertEqual(c.f(x), x*10 + 7)
1646 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1647 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1648
1649 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1650 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1651 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1652
def test_pickle(self):
    # Cached functions, methods and staticmethods round-trip through every
    # pickle protocol back to the very same object.
    owner = type(self)
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for func in candidates:
        for proto in protocols:
            with self.subTest(proto=proto, func=func):
                clone = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(clone, func)
1660
1661 def test_copy(self):
1662 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001663 def orig(x, y):
1664 return 3 * x + y
1665 part = self.module.partial(orig, 2)
1666 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1667 self.module.lru_cache(2)(part))
1668 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001669 with self.subTest(func=f):
1670 f_copy = copy.copy(f)
1671 self.assertIs(f_copy, f)
1672
1673 def test_deepcopy(self):
1674 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001675 def orig(x, y):
1676 return 3 * x + y
1677 part = self.module.partial(orig, 2)
1678 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1679 self.module.lru_cache(2)(part))
1680 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001681 with self.subTest(func=f):
1682 f_copy = copy.deepcopy(f)
1683 self.assertIs(f_copy, f)
1684
Manjusaka051ff522019-11-12 15:30:18 +08001685 def test_lru_cache_parameters(self):
1686 @self.module.lru_cache(maxsize=2)
1687 def f():
1688 return 1
1689 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1690
1691 @self.module.lru_cache(maxsize=1000, typed=True)
1692 def f():
1693 return 1
1694 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1695
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001696 def test_lru_cache_weakrefable(self):
1697 @self.module.lru_cache
1698 def test_function(x):
1699 return x
1700
1701 class A:
1702 @self.module.lru_cache
1703 def test_method(self, x):
1704 return (self, x)
1705
1706 @staticmethod
1707 @self.module.lru_cache
1708 def test_staticmethod(x):
1709 return (self, x)
1710
1711 refs = [weakref.ref(test_function),
1712 weakref.ref(A.test_method),
1713 weakref.ref(A.test_staticmethod)]
1714
1715 for ref in refs:
1716 self.assertIsNotNone(ref())
1717
1718 del A
1719 del test_function
1720 gc.collect()
1721
1722 for ref in refs:
1723 self.assertIsNone(ref())
1724
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001725
# Module-level function cached with the pure-Python lru_cache; the tests
# reach it through TestLRUPy.cached_func.
@py_functools.lru_cache()
def py_cached_func(x, y):
    result = 3 * x + y
    return result
1729
# Module-level function cached with the C-accelerated lru_cache; the tests
# reach it through TestLRUC.cached_func.
@c_functools.lru_cache()
def c_cached_func(x, y):
    result = 3 * x + y
    return result
1733
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001734
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against the pure-Python functools.
    module = py_functools
    # Kept in a 1-tuple so the function is not bound as a method when it is
    # looked up through the class or an instance.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1747
1748
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against the C-accelerated functools.
    module = c_functools
    # Kept in a 1-tuple so the function is not bound as a method when it is
    # looked up through the class or an instance.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001761
Raymond Hettinger03923422013-03-04 02:52:50 -05001762
Łukasz Langa6f692512013-06-05 12:20:24 +02001763class TestSingleDispatch(unittest.TestCase):
1764 def test_simple_overloads(self):
1765 @functools.singledispatch
1766 def g(obj):
1767 return "base"
1768 def g_int(i):
1769 return "integer"
1770 g.register(int, g_int)
1771 self.assertEqual(g("str"), "base")
1772 self.assertEqual(g(1), "integer")
1773 self.assertEqual(g([1,2,3]), "base")
1774
1775 def test_mro(self):
1776 @functools.singledispatch
1777 def g(obj):
1778 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001779 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001780 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001781 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001782 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001783 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001784 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001785 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001786 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001787 def g_A(a):
1788 return "A"
1789 def g_B(b):
1790 return "B"
1791 g.register(A, g_A)
1792 g.register(B, g_B)
1793 self.assertEqual(g(A()), "A")
1794 self.assertEqual(g(B()), "B")
1795 self.assertEqual(g(C()), "A")
1796 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001797
1798 def test_register_decorator(self):
1799 @functools.singledispatch
1800 def g(obj):
1801 return "base"
1802 @g.register(int)
1803 def g_int(i):
1804 return "int %s" % (i,)
1805 self.assertEqual(g(""), "base")
1806 self.assertEqual(g(12), "int 12")
1807 self.assertIs(g.dispatch(int), g_int)
1808 self.assertIs(g.dispatch(object), g.dispatch(str))
1809 # Note: in the assert above this is not g.
1810 # @singledispatch returns the wrapper.
1811
1812 def test_wrapping_attributes(self):
1813 @functools.singledispatch
1814 def g(obj):
1815 "Simple test"
1816 return "Test"
1817 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001818 if sys.flags.optimize < 2:
1819 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001820
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes from the C _decimal module; a later,
    # more specific registration (Subnormal) overrides the generic one
    # (DecimalException) only for that subclass.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def _(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1839
1840 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001841 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001842 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001843 mro = functools._compose_mro
1844 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1845 for haystack in permutations(bases):
1846 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001847 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1848 c.Collection, c.Sized, c.Iterable,
1849 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001850 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001851 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001852 m = mro(collections.ChainMap, haystack)
1853 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001854 c.Collection, c.Sized, c.Iterable,
1855 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001856
1857 # If there's a generic function with implementations registered for
1858 # both Sized and Container, passing a defaultdict to it results in an
1859 # ambiguous dispatch which will cause a RuntimeError (see
1860 # test_mro_conflicts).
1861 bases = [c.Container, c.Sized, str]
1862 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001863 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1864 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1865 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001866
1867 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001868 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001869 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001870 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001871 pass
1872 c.MutableSequence.register(D)
1873 bases = [c.MutableSequence, c.MutableMapping]
1874 for haystack in permutations(bases):
1875 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001876 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001877 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001878 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001879 object])
1880
1881 # Container and Callable are registered on different base classes and
1882 # a generic function supporting both should always pick the Callable
1883 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001884 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001885 def __call__(self):
1886 pass
1887 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1888 for haystack in permutations(bases):
1889 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001890 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001891 c.Collection, c.Sized, c.Iterable,
1892 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001893
1894 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001895 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001896 d = {"a": "b"}
1897 l = [1, 2, 3]
1898 s = {object(), None}
1899 f = frozenset(s)
1900 t = (1, 2, 3)
1901 @functools.singledispatch
1902 def g(obj):
1903 return "base"
1904 self.assertEqual(g(d), "base")
1905 self.assertEqual(g(l), "base")
1906 self.assertEqual(g(s), "base")
1907 self.assertEqual(g(f), "base")
1908 self.assertEqual(g(t), "base")
1909 g.register(c.Sized, lambda obj: "sized")
1910 self.assertEqual(g(d), "sized")
1911 self.assertEqual(g(l), "sized")
1912 self.assertEqual(g(s), "sized")
1913 self.assertEqual(g(f), "sized")
1914 self.assertEqual(g(t), "sized")
1915 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1916 self.assertEqual(g(d), "mutablemapping")
1917 self.assertEqual(g(l), "sized")
1918 self.assertEqual(g(s), "sized")
1919 self.assertEqual(g(f), "sized")
1920 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001921 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001922 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1923 self.assertEqual(g(l), "sized")
1924 self.assertEqual(g(s), "sized")
1925 self.assertEqual(g(f), "sized")
1926 self.assertEqual(g(t), "sized")
1927 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1928 self.assertEqual(g(d), "mutablemapping")
1929 self.assertEqual(g(l), "mutablesequence")
1930 self.assertEqual(g(s), "sized")
1931 self.assertEqual(g(f), "sized")
1932 self.assertEqual(g(t), "sized")
1933 g.register(c.MutableSet, lambda obj: "mutableset")
1934 self.assertEqual(g(d), "mutablemapping")
1935 self.assertEqual(g(l), "mutablesequence")
1936 self.assertEqual(g(s), "mutableset")
1937 self.assertEqual(g(f), "sized")
1938 self.assertEqual(g(t), "sized")
1939 g.register(c.Mapping, lambda obj: "mapping")
1940 self.assertEqual(g(d), "mutablemapping") # not specific enough
1941 self.assertEqual(g(l), "mutablesequence")
1942 self.assertEqual(g(s), "mutableset")
1943 self.assertEqual(g(f), "sized")
1944 self.assertEqual(g(t), "sized")
1945 g.register(c.Sequence, lambda obj: "sequence")
1946 self.assertEqual(g(d), "mutablemapping")
1947 self.assertEqual(g(l), "mutablesequence")
1948 self.assertEqual(g(s), "mutableset")
1949 self.assertEqual(g(f), "sized")
1950 self.assertEqual(g(t), "sequence")
1951 g.register(c.Set, lambda obj: "set")
1952 self.assertEqual(g(d), "mutablemapping")
1953 self.assertEqual(g(l), "mutablesequence")
1954 self.assertEqual(g(s), "mutableset")
1955 self.assertEqual(g(f), "set")
1956 self.assertEqual(g(t), "sequence")
1957 g.register(dict, lambda obj: "dict")
1958 self.assertEqual(g(d), "dict")
1959 self.assertEqual(g(l), "mutablesequence")
1960 self.assertEqual(g(s), "mutableset")
1961 self.assertEqual(g(f), "set")
1962 self.assertEqual(g(t), "sequence")
1963 g.register(list, lambda obj: "list")
1964 self.assertEqual(g(d), "dict")
1965 self.assertEqual(g(l), "list")
1966 self.assertEqual(g(s), "mutableset")
1967 self.assertEqual(g(f), "set")
1968 self.assertEqual(g(t), "sequence")
1969 g.register(set, lambda obj: "concrete-set")
1970 self.assertEqual(g(d), "dict")
1971 self.assertEqual(g(l), "list")
1972 self.assertEqual(g(s), "concrete-set")
1973 self.assertEqual(g(f), "set")
1974 self.assertEqual(g(t), "sequence")
1975 g.register(frozenset, lambda obj: "frozen-set")
1976 self.assertEqual(g(d), "dict")
1977 self.assertEqual(g(l), "list")
1978 self.assertEqual(g(s), "concrete-set")
1979 self.assertEqual(g(f), "frozen-set")
1980 self.assertEqual(g(t), "sequence")
1981 g.register(tuple, lambda obj: "tuple")
1982 self.assertEqual(g(d), "dict")
1983 self.assertEqual(g(l), "list")
1984 self.assertEqual(g(s), "concrete-set")
1985 self.assertEqual(g(f), "frozen-set")
1986 self.assertEqual(g(t), "tuple")
1987
Łukasz Langa3720c772013-07-01 16:00:38 +02001988 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001989 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001990 mro = functools._c3_mro
1991 class A(object):
1992 pass
1993 class B(A):
1994 def __len__(self):
1995 return 0 # implies Sized
1996 @c.Container.register
1997 class C(object):
1998 pass
1999 class D(object):
2000 pass # unrelated
2001 class X(D, C, B):
2002 def __call__(self):
2003 pass # implies Callable
2004 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2005 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2006 self.assertEqual(mro(X, abcs=abcs), expected)
2007 # unrelated ABCs don't appear in the resulting MRO
2008 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2009 self.assertEqual(mro(X, abcs=many_abcs), expected)
2010
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002011 def test_false_meta(self):
2012 # see issue23572
2013 class MetaA(type):
2014 def __len__(self):
2015 return 0
2016 class A(metaclass=MetaA):
2017 pass
2018 class AA(A):
2019 pass
2020 @functools.singledispatch
2021 def fun(a):
2022 return 'base A'
2023 @fun.register(A)
2024 def _(a):
2025 return 'fun A'
2026 aa = AA()
2027 self.assertEqual(fun(aa), 'fun A')
2028
Łukasz Langa6f692512013-06-05 12:20:24 +02002029 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002030 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002031 @functools.singledispatch
2032 def g(arg):
2033 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02002034 class O(c.Sized):
2035 def __len__(self):
2036 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002037 o = O()
2038 self.assertEqual(g(o), "base")
2039 g.register(c.Iterable, lambda arg: "iterable")
2040 g.register(c.Container, lambda arg: "container")
2041 g.register(c.Sized, lambda arg: "sized")
2042 g.register(c.Set, lambda arg: "set")
2043 self.assertEqual(g(o), "sized")
2044 c.Iterable.register(O)
2045 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
2046 c.Container.register(O)
2047 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02002048 c.Set.register(O)
2049 self.assertEqual(g(o), "set") # because c.Set is a subclass of
2050 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02002051 class P:
2052 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02002053 p = P()
2054 self.assertEqual(g(p), "base")
2055 c.Iterable.register(P)
2056 self.assertEqual(g(p), "iterable")
2057 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02002058 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02002059 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002060 self.assertIn(
2061 str(re_one.exception),
2062 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2063 "or <class 'collections.abc.Iterable'>"),
2064 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
2065 "or <class 'collections.abc.Container'>")),
2066 )
Łukasz Langa6f692512013-06-05 12:20:24 +02002067 class Q(c.Sized):
2068 def __len__(self):
2069 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002070 q = Q()
2071 self.assertEqual(g(q), "sized")
2072 c.Iterable.register(Q)
2073 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2074 c.Set.register(Q)
2075 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002076 # c.Sized and c.Iterable
2077 @functools.singledispatch
2078 def h(arg):
2079 return "base"
2080 @h.register(c.Sized)
2081 def _(arg):
2082 return "sized"
2083 @h.register(c.Container)
2084 def _(arg):
2085 return "container"
2086 # Even though Sized and Container are explicit bases of MutableMapping,
2087 # this ABC is implicitly registered on defaultdict which makes all of
2088 # MutableMapping's bases implicit as well from defaultdict's
2089 # perspective.
2090 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002091 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002092 self.assertIn(
2093 str(re_two.exception),
2094 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2095 "or <class 'collections.abc.Sized'>"),
2096 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2097 "or <class 'collections.abc.Container'>")),
2098 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002099 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002100 pass
2101 c.MutableSequence.register(R)
2102 @functools.singledispatch
2103 def i(arg):
2104 return "base"
2105 @i.register(c.MutableMapping)
2106 def _(arg):
2107 return "mapping"
2108 @i.register(c.MutableSequence)
2109 def _(arg):
2110 return "sequence"
2111 r = R()
2112 self.assertEqual(i(r), "sequence")
2113 class S:
2114 pass
2115 class T(S, c.Sized):
2116 def __len__(self):
2117 return 0
2118 t = T()
2119 self.assertEqual(h(t), "sized")
2120 c.Container.register(T)
2121 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2122 class U:
2123 def __len__(self):
2124 return 0
2125 u = U()
2126 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2127 # from the existence of __len__()
2128 c.Container.register(U)
2129 # There is no preference for registered versus inferred ABCs.
2130 with self.assertRaises(RuntimeError) as re_three:
2131 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002132 self.assertIn(
2133 str(re_three.exception),
2134 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2135 "or <class 'collections.abc.Sized'>"),
2136 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2137 "or <class 'collections.abc.Container'>")),
2138 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002139 class V(c.Sized, S):
2140 def __len__(self):
2141 return 0
2142 @functools.singledispatch
2143 def j(arg):
2144 return "base"
2145 @j.register(S)
2146 def _(arg):
2147 return "s"
2148 @j.register(c.Container)
2149 def _(arg):
2150 return "container"
2151 v = V()
2152 self.assertEqual(j(v), "s")
2153 c.Container.register(V)
2154 self.assertEqual(j(v), "container") # because it ends up right after
2155 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002156
2157 def test_cache_invalidation(self):
2158 from collections import UserDict
INADA Naoki9811e802017-09-30 16:13:02 +09002159 import weakref
2160
Łukasz Langa6f692512013-06-05 12:20:24 +02002161 class TracingDict(UserDict):
2162 def __init__(self, *args, **kwargs):
2163 super(TracingDict, self).__init__(*args, **kwargs)
2164 self.set_ops = []
2165 self.get_ops = []
2166 def __getitem__(self, key):
2167 result = self.data[key]
2168 self.get_ops.append(key)
2169 return result
2170 def __setitem__(self, key, value):
2171 self.set_ops.append(key)
2172 self.data[key] = value
2173 def clear(self):
2174 self.data.clear()
INADA Naoki9811e802017-09-30 16:13:02 +09002175
Łukasz Langa6f692512013-06-05 12:20:24 +02002176 td = TracingDict()
INADA Naoki9811e802017-09-30 16:13:02 +09002177 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
2178 c = collections.abc
2179 @functools.singledispatch
2180 def g(arg):
2181 return "base"
2182 d = {}
2183 l = []
2184 self.assertEqual(len(td), 0)
2185 self.assertEqual(g(d), "base")
2186 self.assertEqual(len(td), 1)
2187 self.assertEqual(td.get_ops, [])
2188 self.assertEqual(td.set_ops, [dict])
2189 self.assertEqual(td.data[dict], g.registry[object])
2190 self.assertEqual(g(l), "base")
2191 self.assertEqual(len(td), 2)
2192 self.assertEqual(td.get_ops, [])
2193 self.assertEqual(td.set_ops, [dict, list])
2194 self.assertEqual(td.data[dict], g.registry[object])
2195 self.assertEqual(td.data[list], g.registry[object])
2196 self.assertEqual(td.data[dict], td.data[list])
2197 self.assertEqual(g(l), "base")
2198 self.assertEqual(g(d), "base")
2199 self.assertEqual(td.get_ops, [list, dict])
2200 self.assertEqual(td.set_ops, [dict, list])
2201 g.register(list, lambda arg: "list")
2202 self.assertEqual(td.get_ops, [list, dict])
2203 self.assertEqual(len(td), 0)
2204 self.assertEqual(g(d), "base")
2205 self.assertEqual(len(td), 1)
2206 self.assertEqual(td.get_ops, [list, dict])
2207 self.assertEqual(td.set_ops, [dict, list, dict])
2208 self.assertEqual(td.data[dict],
2209 functools._find_impl(dict, g.registry))
2210 self.assertEqual(g(l), "list")
2211 self.assertEqual(len(td), 2)
2212 self.assertEqual(td.get_ops, [list, dict])
2213 self.assertEqual(td.set_ops, [dict, list, dict, list])
2214 self.assertEqual(td.data[list],
2215 functools._find_impl(list, g.registry))
2216 class X:
2217 pass
2218 c.MutableMapping.register(X) # Will not invalidate the cache,
2219 # not using ABCs yet.
2220 self.assertEqual(g(d), "base")
2221 self.assertEqual(g(l), "list")
2222 self.assertEqual(td.get_ops, [list, dict, dict, list])
2223 self.assertEqual(td.set_ops, [dict, list, dict, list])
2224 g.register(c.Sized, lambda arg: "sized")
2225 self.assertEqual(len(td), 0)
2226 self.assertEqual(g(d), "sized")
2227 self.assertEqual(len(td), 1)
2228 self.assertEqual(td.get_ops, [list, dict, dict, list])
2229 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2230 self.assertEqual(g(l), "list")
2231 self.assertEqual(len(td), 2)
2232 self.assertEqual(td.get_ops, [list, dict, dict, list])
2233 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2234 self.assertEqual(g(l), "list")
2235 self.assertEqual(g(d), "sized")
2236 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2237 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2238 g.dispatch(list)
2239 g.dispatch(dict)
2240 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2241 list, dict])
2242 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2243 c.MutableSet.register(X) # Will invalidate the cache.
2244 self.assertEqual(len(td), 2) # Stale cache.
2245 self.assertEqual(g(l), "list")
2246 self.assertEqual(len(td), 1)
2247 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2248 self.assertEqual(len(td), 0)
2249 self.assertEqual(g(d), "mutablemapping")
2250 self.assertEqual(len(td), 1)
2251 self.assertEqual(g(l), "list")
2252 self.assertEqual(len(td), 2)
2253 g.register(dict, lambda arg: "dict")
2254 self.assertEqual(g(d), "dict")
2255 self.assertEqual(g(l), "list")
2256 g._clear_cache()
2257 self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002258
Łukasz Langae5697532017-12-11 13:56:31 -08002259 def test_annotations(self):
2260 @functools.singledispatch
2261 def i(arg):
2262 return "base"
2263 @i.register
2264 def _(arg: collections.abc.Mapping):
2265 return "mapping"
2266 @i.register
2267 def _(arg: "collections.abc.Sequence"):
2268 return "sequence"
2269 self.assertEqual(i(None), "base")
2270 self.assertEqual(i({"a": 1}), "mapping")
2271 self.assertEqual(i([1, 2, 3]), "sequence")
2272 self.assertEqual(i((1, 2, 3)), "sequence")
2273 self.assertEqual(i("str"), "sequence")
2274
2275 # Registering classes as callables doesn't work with annotations,
2276 # you need to pass the type explicitly.
2277 @i.register(str)
2278 class _:
2279 def __init__(self, arg):
2280 self.arg = arg
2281
2282 def __eq__(self, other):
2283 return self.arg == other
2284 self.assertEqual(i("str"), "str")
2285
Ethan Smithc6512752018-05-26 16:38:33 -04002286 def test_method_register(self):
2287 class A:
2288 @functools.singledispatchmethod
2289 def t(self, arg):
2290 self.arg = "base"
2291 @t.register(int)
2292 def _(self, arg):
2293 self.arg = "int"
2294 @t.register(str)
2295 def _(self, arg):
2296 self.arg = "str"
2297 a = A()
2298
2299 a.t(0)
2300 self.assertEqual(a.arg, "int")
2301 aa = A()
2302 self.assertFalse(hasattr(aa, 'arg'))
2303 a.t('')
2304 self.assertEqual(a.arg, "str")
2305 aa = A()
2306 self.assertFalse(hasattr(aa, 'arg'))
2307 a.t(0.0)
2308 self.assertEqual(a.arg, "base")
2309 aa = A()
2310 self.assertFalse(hasattr(aa, 'arg'))
2311
2312 def test_staticmethod_register(self):
2313 class A:
2314 @functools.singledispatchmethod
2315 @staticmethod
2316 def t(arg):
2317 return arg
2318 @t.register(int)
2319 @staticmethod
2320 def _(arg):
2321 return isinstance(arg, int)
2322 @t.register(str)
2323 @staticmethod
2324 def _(arg):
2325 return isinstance(arg, str)
2326 a = A()
2327
2328 self.assertTrue(A.t(0))
2329 self.assertTrue(A.t(''))
2330 self.assertEqual(A.t(0.0), 0.0)
2331
2332 def test_classmethod_register(self):
2333 class A:
2334 def __init__(self, arg):
2335 self.arg = arg
2336
2337 @functools.singledispatchmethod
2338 @classmethod
2339 def t(cls, arg):
2340 return cls("base")
2341 @t.register(int)
2342 @classmethod
2343 def _(cls, arg):
2344 return cls("int")
2345 @t.register(str)
2346 @classmethod
2347 def _(cls, arg):
2348 return cls("str")
2349
2350 self.assertEqual(A.t(0).arg, "int")
2351 self.assertEqual(A.t('').arg, "str")
2352 self.assertEqual(A.t(0.0).arg, "base")
2353
2354 def test_callable_register(self):
2355 class A:
2356 def __init__(self, arg):
2357 self.arg = arg
2358
2359 @functools.singledispatchmethod
2360 @classmethod
2361 def t(cls, arg):
2362 return cls("base")
2363
2364 @A.t.register(int)
2365 @classmethod
2366 def _(cls, arg):
2367 return cls("int")
2368 @A.t.register(str)
2369 @classmethod
2370 def _(cls, arg):
2371 return cls("str")
2372
2373 self.assertEqual(A.t(0).arg, "int")
2374 self.assertEqual(A.t('').arg, "str")
2375 self.assertEqual(A.t(0.0).arg, "base")
2376
2377 def test_abstractmethod_register(self):
2378 class Abstract(abc.ABCMeta):
2379
2380 @functools.singledispatchmethod
2381 @abc.abstractmethod
2382 def add(self, x, y):
2383 pass
2384
2385 self.assertTrue(Abstract.add.__isabstractmethod__)
2386
2387 def test_type_ann_register(self):
2388 class A:
2389 @functools.singledispatchmethod
2390 def t(self, arg):
2391 return "base"
2392 @t.register
2393 def _(self, arg: int):
2394 return "int"
2395 @t.register
2396 def _(self, arg: str):
2397 return "str"
2398 a = A()
2399
2400 self.assertEqual(a.t(0), "int")
2401 self.assertEqual(a.t(''), "str")
2402 self.assertEqual(a.t(0.0), "base")
2403
Łukasz Langae5697532017-12-11 13:56:31 -08002404 def test_invalid_registrations(self):
2405 msg_prefix = "Invalid first argument to `register()`: "
2406 msg_suffix = (
2407 ". Use either `@register(some_class)` or plain `@register` on an "
2408 "annotated function."
2409 )
2410 @functools.singledispatch
2411 def i(arg):
2412 return "base"
2413 with self.assertRaises(TypeError) as exc:
2414 @i.register(42)
2415 def _(arg):
2416 return "I annotated with a non-type"
2417 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2418 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2419 with self.assertRaises(TypeError) as exc:
2420 @i.register
2421 def _(arg):
2422 return "I forgot to annotate"
2423 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2424 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2425 ))
2426 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2427
Łukasz Langae5697532017-12-11 13:56:31 -08002428 with self.assertRaises(TypeError) as exc:
2429 @i.register
2430 def _(arg: typing.Iterable[str]):
2431 # At runtime, dispatching on generics is impossible.
2432 # When registering implementations with singledispatch, avoid
2433 # types from `typing`. Instead, annotate with regular types
2434 # or ABCs.
2435 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002436 self.assertTrue(str(exc.exception).startswith(
2437 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002438 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002439 self.assertTrue(str(exc.exception).endswith(
2440 'typing.Iterable[str] is not a class.'
2441 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002442
Dong-hee Na445f1b32018-07-10 16:26:36 +09002443 def test_invalid_positional_argument(self):
2444 @functools.singledispatch
2445 def f(*args):
2446 pass
2447 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002448 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002449 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002450
Carl Meyerd658dea2018-08-28 01:11:56 -06002451
class CachedCostItem:
    """Item whose ``cost`` is computed once, under its own lock, then cached."""

    # Class-level starting value; the first computation stores 2 on the
    # instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2464
2465
class OptionallyCachedCostItem:
    """Item exposing one getter both uncached and cached under another name."""

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The same function, cached under an attribute name that differs from
    # the function's own name.
    cached_cost = py_functools.cached_property(get_cost)
2475
2476
class CachedCostItemWait:
    """Item whose cached ``cost`` waits on an event before computing.

    Used to make several threads reach the getter at the same time.
    """

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Block until the event is set, or at most one second.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2490
2491
class CachedCostItemWithSlots:
    """Slotted class without ``__dict__``, where cached_property cannot cache.

    ``__slots__`` is spelled as a one-element tuple. The previous
    ``('_cost')`` was just a parenthesized string, which only worked because
    a string is accepted as a single slot name.
    """
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2501
2502
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python ``cached_property`` (from ``py_functools``)."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # A tiny switch interval makes the threads interleave aggressively;
        # the original interval is always restored.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # _cost starts at 1 and each getter run adds 1, so 2 means the
        # getter ran exactly once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        # Compare the message exactly rather than using it as an
        # (unescaped) regular expression.
        with self.assertRaises(TypeError) as cm:
            item.cost
        self.assertEqual(
            str(cm.exception),
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        )

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # A class's __dict__ is read-only, so the value cannot be stored.
        with self.assertRaises(TypeError) as cm:
            MyClass.prop
        self.assertEqual(
            str(cm.exception),
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        )

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The __set_name__ error is chained to the RuntimeError; check both
        # its type and its message.
        error = ctx.exception.__context__
        self.assertIsInstance(error, TypeError)
        self.assertEqual(
            str(error),
            "Cannot assign the same cached_property to two different names ('a' and 'b').",
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        Foo.cp = cp

        with self.assertRaises(TypeError) as cm:
            Foo().cp
        self.assertEqual(
            str(cm.exception),
            "Cannot use cached_property instance without calling __set_name__ on it.",
        )

    def test_access_from_class(self):
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2615
2616
if __name__ == "__main__":
    # Run this module's test cases when executed as a script.
    unittest.main()