blob: 78a8a5fcc0feaa00e41bc1bc666973674bb6b8ad [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shi3ddc6342020-06-30 21:46:06 +080022from test.support import import_helper
Hai Shie80697d2020-05-28 06:10:27 +080023from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000024from test.support.script_helper import assert_python_ok
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026import functools
27
# Two independently imported copies of functools so every test can run
# against both implementations:
# * py_functools -- imported with the _functools accelerator blocked, so it
#   exposes the pure-Python implementations (e.g. partial);
# * c_functools -- a fresh import with _functools available, i.e. the C
#   implementations where they exist.
py_functools = import_helper.import_fresh_module('functools',
                                                 blocked=['_functools'])
c_functools = import_helper.import_fresh_module('functools')

# NOTE(review): not referenced in the visible part of this file; presumably
# used by later tests that need the C-accelerated decimal -- confirm.
decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
Łukasz Langa6f692512013-06-05 12:20:24 +020033
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The entry present on entry is put back when the ``with`` block ends,
    whether it finishes normally or by raising.  *name* must already be a
    key of ``sys.modules``; otherwise the lookup raises KeyError before
    anything is changed.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020042
def capture(*args, **kw):
    """Return ``(args, kw)`` exactly as received.

    Serves as the target callable for partial objects so a test can see
    precisely which positional and keyword arguments were forwarded, e.g.
    ``capture(1, a=2) == ((1,), {'a': 2})``.
    """
    received = (args, kw)
    return received
46
Łukasz Langa6f692512013-06-05 12:20:24 +020047
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``.

    Two partials with equal signatures call the same function with the same
    frozen arguments and carry the same extra instance attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000051
# Trivial tuple subclass; test_setstate_subclasses checks that partial
# stores its args as an exact tuple even when given one of these.
class MyTuple(tuple):
    ...

# Tuple subclass whose "+" returns a list instead of a tuple, so a test can
# detect an implementation that concatenates via the subclass's __add__.
class BadTuple(tuple):
    def __add__(self, other):
        return [*self, *other]

# Trivial dict subclass; partial is expected to store keywords as an exact dict.
class MyDict(dict):
    ...
61
Łukasz Langa6f692512013-06-05 12:20:24 +020062
class TestPartial:
    """Behaviour shared by every partial() implementation under test.

    This is a mixin, not a TestCase.  Concrete subclasses combine it with
    unittest.TestCase and supply two class attributes:

    * ``partial`` -- the partial type being tested;
    * ``AllowPickle`` -- a context-manager class; instances of ``partial``
      are pickled only inside ``with self.AllowPickle():``.
    """

    def _expected_repr_name(self):
        """Return the type name expected at the start of repr(partial obj).

        The stdlib C and pure-Python types both present themselves as
        ``functools.partial``; any other type (e.g. a subclass) is expected
        to use its own ``__name__``.
        """
        if self.partial in (c_functools.partial, py_functools.partial):
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected -- either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Don't rely on reference counting to free the referent: collect
        # explicitly so the proxy is dead on any implementation.
        gc.collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not asserted; accept either order
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case builds a self-referencing partial; the finally clause
        # resets its state so the reference cycle is broken again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None namespace is stored as an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None keywords and None namespace both become empty dicts.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple's __add__ returns a list; the stored args and the
        # call-time args must still be real tuples.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000386
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial() tests plus C-specific checks for the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing needs patching to pickle this type."""
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
437
438
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that points ``sys.modules['functools']`` at
        py_functools for its duration, then restores the previous module.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc, tb):
            return self._swap.__exit__(exc_type, exc, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
# Trivial subclasses of each partial implementation, used as the ``partial``
# type of the *Subclass test cases below.  They are defined at module level,
# presumably so pickle can locate them by qualified name (test_pickle runs
# against them).
if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass

class PyPartialSubclass(py_functools.partial):
    pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200456
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-runs every TestPartialC test (shared and C-specific) with a
    # subclass of the C partial type.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    # (overriding the inherited test with None, a non-callable, disables it)
    test_nested_optimization = None
464
class TestPartialPySubclass(TestPartialPy):
    # Re-runs every TestPartialPy test with a subclass of the pure-Python
    # partial type (AllowPickle is inherited from TestPartialPy).
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200467
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an ABC (ABCMeta as its *metaclass*), not a
        # subclass of ABCMeta, or the abstract-method machinery never runs.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must treat the partialmethod as abstract as well, so both
        # names are recorded and the class cannot be instantiated.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))

        # partialmethod must also forward positional-only parameters.
        class C:
            def add(self, a, b, /):
                return a + b
            add1 = functools.partialmethod(add, 1)

        self.assertEqual(C().add1(2), 3)
596
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000597
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000598class TestUpdateWrapper(unittest.TestCase):
599
600 def check_wrapper(self, wrapper, wrapped,
601 assigned=functools.WRAPPER_ASSIGNMENTS,
602 updated=functools.WRAPPER_UPDATES):
603 # Check attributes were assigned
604 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000605 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000606 # Check attributes were updated
607 for name in updated:
608 wrapper_attr = getattr(wrapper, name)
609 wrapped_attr = getattr(wrapped, name)
610 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000611 if name == "__dict__" and key == "__wrapped__":
612 # __wrapped__ is overwritten by the update code
613 continue
614 self.assertIs(wrapped_attr[key], wrapper_attr[key])
615 # Check __wrapped__
616 self.assertIs(wrapper.__wrapped__, wrapped)
617
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618
    def _default_update(self):
        """Return ``(wrapper, f)`` after ``functools.update_wrapper(wrapper, f)``
        with the default assigned/updated attribute lists.

        ``f`` has a docstring, an annotation on ``a``, a custom ``attr`` and a
        pre-set ``__wrapped__`` string; ``wrapper`` starts with its own
        annotation on ``b``.  Callers assert on all of these exact values.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # f.__dict__ (including this key) is merged into wrapper, after which
        # wrapper.__wrapped__ is set to f itself; this decoy lets callers
        # verify the final __wrapped__ is f, not this string.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f
629
630 def test_default_update(self):
631 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000633 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000634 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600635 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636 self.assertEqual(wrapper.attr, 'This is also a test')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100637 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000638 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000639
R. David Murray378c0cf2010-02-24 01:46:21 +0000640 @unittest.skipIf(sys.flags.optimize >= 2,
641 "Docstrings are omitted with -O2 and above")
642 def test_default_update_doc(self):
643 wrapper, f = self._default_update()
644 self.assertEqual(wrapper.__doc__, 'This is a test')
645
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000646 def test_no_update(self):
647 def f():
648 """This is a test"""
649 pass
650 f.attr = 'This is also a test'
651 def wrapper():
652 pass
653 functools.update_wrapper(wrapper, f, (), ())
654 self.check_wrapper(wrapper, f, (), ())
655 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600656 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000658 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000659 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000660
661 def test_selective_update(self):
662 def f():
663 pass
664 f.attr = 'This is a different test'
665 f.dict_attr = dict(a=1, b=2, c=3)
666 def wrapper():
667 pass
668 wrapper.dict_attr = {}
669 assign = ('attr',)
670 update = ('dict_attr',)
671 functools.update_wrapper(wrapper, f, assign, update)
672 self.check_wrapper(wrapper, f, assign, update)
673 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600674 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000675 self.assertEqual(wrapper.__doc__, None)
676 self.assertEqual(wrapper.attr, 'This is a different test')
677 self.assertEqual(wrapper.dict_attr, f.dict_attr)
678
Nick Coghlan98876832010-08-17 06:17:18 +0000679 def test_missing_attributes(self):
680 def f():
681 pass
682 def wrapper():
683 pass
684 wrapper.dict_attr = {}
685 assign = ('attr',)
686 update = ('dict_attr',)
687 # Missing attributes on wrapped object are ignored
688 functools.update_wrapper(wrapper, f, assign, update)
689 self.assertNotIn('attr', wrapper.__dict__)
690 self.assertEqual(wrapper.dict_attr, {})
691 # Wrapper must have expected attributes for updating
692 del wrapper.dict_attr
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695 wrapper.dict_attr = 1
696 with self.assertRaises(AttributeError):
697 functools.update_wrapper(wrapper, f, assign, update)
698
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200699 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000700 @unittest.skipIf(sys.flags.optimize >= 2,
701 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000702 def test_builtin_update(self):
703 # Test for bug #1576241
704 def wrapper():
705 pass
706 functools.update_wrapper(wrapper, max)
707 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000708 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000709 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000710
Łukasz Langa6f692512013-06-05 12:20:24 +0200711
class TestWraps(TestUpdateWrapper):
    """Repeat the wrapper tests, building wrappers with @functools.wraps."""

    def _default_update(self):
        """Return ``(wrapper, f)`` where wrapper is decorated with wraps(f)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # f's __dict__ carries a bogus __wrapped__; wraps() must still leave
        # wrapper.__wrapped__ pointing at f (checked in test_default_update).
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # Same explicit check as TestUpdateWrapper.test_default_update: the
        # planted __wrapped__ value must not survive on the wrapper.
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(f):
            # wraps() merges into the wrapper's existing dict_attr, so the
            # wrapper must have one before the wraps decorator runs.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
772
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773
madman-bobe25d5fc2018-10-25 15:02:10 +0100774class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000775 def test_reduce(self):
776 class Squares:
777 def __init__(self, max):
778 self.max = max
779 self.sofar = []
780
781 def __len__(self):
782 return len(self.sofar)
783
784 def __getitem__(self, i):
785 if not 0 <= i < self.max: raise IndexError
786 n = len(self.sofar)
787 while n <= i:
788 self.sofar.append(n*n)
789 n += 1
790 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000791 def add(x, y):
792 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 ['a','c','d','w']
797 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000801 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000802 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100803 self.assertEqual(self.reduce(add, Squares(10)), 285)
804 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
805 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
806 self.assertRaises(TypeError, self.reduce)
807 self.assertRaises(TypeError, self.reduce, 42, 42)
808 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
809 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
810 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
811 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
812 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
813 self.assertRaises(TypeError, self.reduce, add, "")
814 self.assertRaises(TypeError, self.reduce, add, ())
815 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000816
817 class TestFailingIter:
818 def __iter__(self):
819 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000821
madman-bobe25d5fc2018-10-25 15:02:10 +0100822 self.assertEqual(self.reduce(add, [], None), None)
823 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000824
825 class BadSeq:
826 def __getitem__(self, index):
827 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100828 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000829
830 # Test reduce()'s use of iterators.
831 def test_iterator_usage(self):
832 class SequenceClass:
833 def __init__(self, n):
834 self.n = n
835 def __getitem__(self, i):
836 if 0 <= i < self.n:
837 return i
838 else:
839 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000840
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000841 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100842 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
843 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
844 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
845 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
846 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
847 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000848
849 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100850 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
851
852
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""

    # c_functools may be None, in which case the whole class is skipped;
    # the attribute is only bound when the C module is available.
    if c_functools:
        reduce = c_functools.reduce
857
858
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python version."""

    # staticmethod stops the plain Python function from binding to the test
    # instance when it is looked up as ``self.reduce``.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000861
Łukasz Langa6f692512013-06-05 12:20:24 +0200862
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200863class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000865 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 def cmp1(x, y):
867 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 self.assertEqual(key(3), key(3))
870 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100871 self.assertGreaterEqual(key(3), key(3))
872
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 def cmp2(x, y):
874 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 self.assertEqual(key(4.0), key('4'))
877 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100878 self.assertLessEqual(key(2), key('35'))
879 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700880
881 def test_cmp_to_key_arguments(self):
882 def cmp1(x, y):
883 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 self.assertEqual(key(obj=3), key(obj=3))
886 self.assertGreater(key(obj=3), key(obj=1))
887 with self.assertRaises((TypeError, AttributeError)):
888 key(3) > 1 # rhs is not a K object
889 with self.assertRaises((TypeError, AttributeError)):
890 1 < key(3) # lhs is not a K object
891 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200894 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 with self.assertRaises(TypeError):
897 key() # too few args
898 with self.assertRaises(TypeError):
899 key(None, None) # too many args
900
901 def test_bad_cmp(self):
902 def cmp1(x, y):
903 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100904 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700905 with self.assertRaises(ZeroDivisionError):
906 key(3) > key(1)
907
908 class BadCmp:
909 def __lt__(self, other):
910 raise ZeroDivisionError
911 def cmp1(x, y):
912 return BadCmp()
913 with self.assertRaises(ZeroDivisionError):
914 key(3) > key(1)
915
916 def test_obj_field(self):
917 def cmp1(x, y):
918 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700920 self.assertEqual(key(50).obj, 50)
921
922 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 def mycmp(x, y):
924 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100925 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000927
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700928 def test_sort_int_str(self):
929 def mycmp(x, y):
930 x, y = int(x), int(y)
931 return (x > y) - (x < y)
932 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100933 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700934 self.assertEqual([int(value) for value in values],
935 [0, 1, 1, 2, 3, 4, 5, 7, 10])
936
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000937 def test_hash(self):
938 def mycmp(x, y):
939 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100940 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700942 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300943 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    # Bound only when the C module imported; otherwise the class is skipped.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # The key type returned by the C cmp_to_key must not be directly
        # instantiable (bpo-43916).
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
Erlend Egeberg Aasland9746cda2021-04-30 16:04:57 +0200957
Łukasz Langa6f692512013-06-05 12:20:24 +0200958
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""

    # staticmethod keeps the Python function unbound when accessed via self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
961
Łukasz Langa6f692512013-06-05 12:20:24 +0200962
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _check_ordered(self, cls):
        """Assert the comparisons every value-ordered ``cls`` must satisfy.

        ``cls(v)`` is expected to order its instances by ``v``.
        """
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordered(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordered(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordered(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordered(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # New methods should not overwrite existing ones. int already
        # defines all four rich comparisons, so each must still be int's own.
        # Checking only the comparison results could not detect an
        # overwrite, because derived methods would give the same answers.
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordered(A)
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            with self.subTest(method=name):
                self.assertIs(getattr(A, name), getattr(int, name))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # The generated comparison methods must pickle by reference.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1165
@functools.total_ordering
class Orderable_LT:
    """Value wrapper that defines only ``__lt__`` and ``__eq__``.

    It lives at module level so that TestTotalOrdering.test_pickle can
    pickle the comparison methods total_ordering derives for it.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        # total_ordering builds __gt__, __le__ and __ge__ from this.
        return self.value < other.value

    def __eq__(self, other):
        return self.value == other.value
1174
1175
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001176class TestCache:
1177 # This tests that the pass-through is working as designed.
1178 # The underlying functionality is tested in TestLRU.
1179
1180 def test_cache(self):
1181 @self.module.cache
1182 def fib(n):
1183 if n < 2:
1184 return n
1185 return fib(n-1) + fib(n-2)
1186 self.assertEqual([fib(n) for n in range(16)],
1187 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1188 self.assertEqual(fib.cache_info(),
1189 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1190 fib.cache_clear()
1191 self.assertEqual(fib.cache_info(),
1192 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1193
1194
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001195class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001196
1197 def test_lru(self):
1198 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001199 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001200 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001201 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001202 self.assertEqual(maxsize, 20)
1203 self.assertEqual(currsize, 0)
1204 self.assertEqual(hits, 0)
1205 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001206
1207 domain = range(5)
1208 for i in range(1000):
1209 x, y = choice(domain), choice(domain)
1210 actual = f(x, y)
1211 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001212 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001213 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001214 self.assertTrue(hits > misses)
1215 self.assertEqual(hits + misses, 1000)
1216 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001217
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001218 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001219 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001220 self.assertEqual(hits, 0)
1221 self.assertEqual(misses, 0)
1222 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001223 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001224 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(hits, 0)
1226 self.assertEqual(misses, 1)
1227 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001228
Nick Coghlan98876832010-08-17 06:17:18 +00001229 # Test bypassing the cache
1230 self.assertIs(f.__wrapped__, orig)
1231 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001232 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001233 self.assertEqual(hits, 0)
1234 self.assertEqual(misses, 1)
1235 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001236
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001237 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001238 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001239 def f():
1240 nonlocal f_cnt
1241 f_cnt += 1
1242 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001243 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001244 f_cnt = 0
1245 for i in range(5):
1246 self.assertEqual(f(), 20)
1247 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001248 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001249 self.assertEqual(hits, 0)
1250 self.assertEqual(misses, 5)
1251 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001252
1253 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001254 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001255 def f():
1256 nonlocal f_cnt
1257 f_cnt += 1
1258 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001259 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001260 f_cnt = 0
1261 for i in range(5):
1262 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001263 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001264 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001265 self.assertEqual(hits, 4)
1266 self.assertEqual(misses, 1)
1267 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001268
Raymond Hettingerf3098282010-08-15 03:30:45 +00001269 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001270 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001271 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001272 nonlocal f_cnt
1273 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001274 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001275 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001276 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001277 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1278 # * * * *
1279 self.assertEqual(f(x), x*10)
1280 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001281 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001282 self.assertEqual(hits, 12)
1283 self.assertEqual(misses, 4)
1284 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001285
Raymond Hettingerb8218682019-05-26 11:27:35 -07001286 def test_lru_no_args(self):
1287 @self.module.lru_cache
1288 def square(x):
1289 return x ** 2
1290
1291 self.assertEqual(list(map(square, [10, 20, 10])),
1292 [100, 400, 100])
1293 self.assertEqual(square.cache_info().hits, 1)
1294 self.assertEqual(square.cache_info().misses, 2)
1295 self.assertEqual(square.cache_info().maxsize, 128)
1296 self.assertEqual(square.cache_info().currsize, 2)
1297
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001298 def test_lru_bug_35780(self):
1299 # C version of the lru_cache was not checking to see if
1300 # the user function call has already modified the cache
1301 # (this arises in recursive calls and in multi-threading).
1302 # This cause the cache to have orphan links not referenced
1303 # by the cache dictionary.
1304
1305 once = True # Modified by f(x) below
1306
1307 @self.module.lru_cache(maxsize=10)
1308 def f(x):
1309 nonlocal once
1310 rv = f'.{x}.'
1311 if x == 20 and once:
1312 once = False
1313 rv = f(x)
1314 return rv
1315
1316 # Fill the cache
1317 for x in range(15):
1318 self.assertEqual(f(x), f'.{x}.')
1319 self.assertEqual(f.cache_info().currsize, 10)
1320
1321 # Make a recursive call and make sure the cache remains full
1322 self.assertEqual(f(20), '.20.')
1323 self.assertEqual(f.cache_info().currsize, 10)
1324
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001325 def test_lru_bug_36650(self):
1326 # C version of lru_cache was treating a call with an empty **kwargs
1327 # dictionary as being distinct from a call with no keywords at all.
1328 # This did not result in an incorrect answer, but it did trigger
1329 # an unexpected cache miss.
1330
1331 @self.module.lru_cache()
1332 def f(x):
1333 pass
1334
1335 f(0)
1336 f(0, **{})
1337 self.assertEqual(f.cache_info().hits, 1)
1338
def test_lru_hash_only_once(self):
    # The LRU cache calls __hash__ on an argument exactly once per use of
    # that argument in a call to the cached function.  This guards against
    # odd reentrancy bugs and keeps slow __hash__ methods cheap to cache.

    @self.module.lru_cache(maxsize=1)
    def f(x, y):
        return x * 3 + y

    # A stand-in for the integer 5 whose __hash__ calls are counted.
    five = unittest.mock.Mock()
    five.__mul__ = unittest.mock.Mock(return_value=15)
    five.__hash__ = unittest.mock.Mock(return_value=999)

    # Each step: (call arguments, expected result, expected running total
    # of __hash__ calls on `five`, expected cache_info()).
    steps = [
        # Add to cache: one use as an argument gives one call.
        ((five, 1), 16, 1, (0, 1, 1, 1)),
        # Cache hit: one use as an argument gives one additional call.
        ((five, 1), 16, 2, (1, 1, 1, 1)),
        # Cache eviction: no use as an argument gives no additional call.
        ((6, 2), 20, 2, (1, 2, 1, 1)),
        # Cache miss: one use as an argument gives one additional call.
        ((five, 1), 16, 3, (1, 3, 1, 1)),
    ]
    for call_args, result, hash_calls, info in steps:
        self.assertEqual(f(*call_args), result)
        self.assertEqual(five.__hash__.call_count, hash_calls)
        self.assertEqual(f.cache_info(), info)
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001373
def test_lru_reentrancy_with_len(self):
    # Replace the builtin len() with an LRU-cached version of itself.
    # Because len() can be cached this way, the cache's own bookkeeping
    # must not rely on calling the builtin len().
    saved_len = builtins.len
    alphabet = 'abcdefghijklmn'
    try:
        builtins.len = self.module.lru_cache(4)(len)
        for size in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
            self.assertEqual(len(alphabet[:size]), size)
    finally:
        # Always restore the real builtin, even if an assertion failed.
        builtins.len = saved_len
1385
def test_lru_star_arg_handling(self):
    # Regression test for a bug that arose in ea064ff3c10f: a call with
    # several positional arguments and a call with one tuple argument
    # holding the same values must map to different cache keys.
    # Decorate via self.module (not the module-level functools) so both
    # the C and the pure-Python implementations are exercised.
    @self.module.lru_cache()
    def f(*args):
        return args

    self.assertEqual(f(1, 2), (1, 2))
    self.assertEqual(f((1, 2)), ((1, 2),))
    # Two distinct keys: neither call may be answered from the other's
    # cache entry.
    info = f.cache_info()
    self.assertEqual(info.hits, 0)
    self.assertEqual(info.misses, 2)
1394
def test_lru_type_error(self):
    # Regression test for issue #28653.
    # lru_cache was leaking when one of the arguments wasn't cacheable
    # (unhashable).  Both an unbounded and a bounded cache are checked.
    # Decorate via self.module (not the module-level functools) so the C
    # and the pure-Python implementations are both covered.

    @self.module.lru_cache(maxsize=None)
    def infinite_cache(o):
        pass

    @self.module.lru_cache(maxsize=10)
    def limited_cache(o):
        pass

    with self.assertRaises(TypeError):
        infinite_cache([])

    with self.assertRaises(TypeError):
        limited_cache([])
1413
def test_lru_with_maxsize_none(self):
    # An unbounded cache (maxsize=None) memoizing a recursive Fibonacci:
    # each of the 16 distinct arguments is computed once, and all other
    # calls are hits.  cache_clear() resets both the entries and the
    # statistics.
    CacheInfo = self.module._CacheInfo

    @self.module.lru_cache(maxsize=None)
    def fib(n):
        return n if n < 2 else fib(n-1) + fib(n-2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
    self.assertEqual([fib(k) for k in range(16)], expected)
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1427
def test_lru_with_maxsize_negative(self):
    # A negative maxsize is reported as 0: nothing is ever stored, so
    # repeating the same 150 calls twice yields 300 misses and no hits.
    @self.module.lru_cache(maxsize=-10)
    def eq(n):
        return n

    arguments = range(150)
    for _ in range(2):
        self.assertEqual(list(map(eq, arguments)), list(arguments))
    self.assertEqual(eq.cache_info(),
                     self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001436
def test_lru_with_exceptions(self):
    # Exceptions raised by the user function must pass through without
    # being chained to an internal exception (which would be hard to
    # read), and a failing call must not leave a cache entry behind.
    # http://bugs.python.org/issue13177
    for size in (None, 128):
        @self.module.lru_cache(size)
        def pick(i):
            return 'abc'[i]

        self.assertEqual(pick(0), 'a')
        with self.assertRaises(IndexError) as caught:
            pick(15)
        self.assertIsNone(caught.exception.__context__)
        # The repeated failing call must raise again rather than be
        # answered from the cache.
        with self.assertRaises(IndexError):
            pick(15)
1452
def test_lru_with_types(self):
    # With typed=True, 3 and 3.0 (positional or keyword) are cached under
    # separate keys, so each result keeps the type of its own argument.
    for size in (None, 128):
        @self.module.lru_cache(maxsize=size, typed=True)
        def square(x):
            return x * x

        # Positional calls first, then keyword calls, as two call styles.
        call_styles = (lambda v: square(v), lambda v: square(x=v))
        for call in call_styles:
            for arg, want in ((3, 9), (3.0, 9.0)):
                self.assertEqual(call(arg), want)
                self.assertEqual(type(call(arg)), type(want))
        # Four distinct keys: each misses once, then hits once.
        info = square.cache_info()
        self.assertEqual(info.hits, 4)
        self.assertEqual(info.misses, 4)
1468
def test_lru_with_keyword_args(self):
    # A Fibonacci function that is only ever called with a keyword
    # argument, memoized by the default-sized cache (maxsize=128).
    CacheInfo = self.module._CacheInfo

    @self.module.lru_cache()
    def fib(n):
        if n >= 2:
            return fib(n=n-1) + fib(n=n-2)
        return n

    self.assertEqual(
        [fib(n=k) for k in range(16)],
        [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610],
    )
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001484
def test_lru_with_keyword_args_maxsize_none(self):
    # A keyword-called Fibonacci memoized by an unbounded cache
    # (maxsize=None); statistics are checked before and after clearing.
    @self.module.lru_cache(maxsize=None)
    def fib(n):
        if n < 2:
            return n
        left = fib(n=n-1)
        right = fib(n=n-2)
        return left + right

    fib_values = [fib(n=i) for i in range(16)]
    self.assertEqual(fib_values,
                     [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
    make_info = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     make_info(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     make_info(hits=0, misses=0, maxsize=None, currsize=0))
1498
def test_kwargs_order(self):
    # PEP 468: Preserving Keyword Argument Order.  Calls that differ only
    # in keyword order are separate cache entries, and each sees its own
    # order.
    @self.module.lru_cache(maxsize=10)
    def f(**kwargs):
        return list(kwargs.items())

    cases = [
        ({'a': 1, 'b': 2}, [('a', 1), ('b', 2)]),
        ({'b': 2, 'a': 1}, [('b', 2), ('a', 1)]),
    ]
    for call_kwargs, expected in cases:
        self.assertEqual(f(**call_kwargs), expected)
    self.assertEqual(f.cache_info(),
                     self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1508
def test_lru_cache_decoration(self):
    # The cache wrapper must carry over every attribute listed in
    # WRAPPER_ASSIGNMENTS (such as the name, doc string and annotations)
    # from the wrapped function.
    def f(zomg: 'zomg_annotation'):
        """f doc string"""
        return 42

    wrapped = self.module.lru_cache()(f)
    for name in self.module.WRAPPER_ASSIGNMENTS:
        self.assertEqual(getattr(wrapped, name), getattr(f, name))
1516
def test_lru_cache_threaded(self):
    # Exercise the cache from several threads while the interpreter is told
    # to switch threads as often as possible.  Phase 1: n threads each call
    # f(k, 0) m times for their own k, then the statistics are checked.
    # Phase 2: the same workers run again while one extra thread repeatedly
    # clears the cache.
    n, m = 5, 11
    def orig(x, y):
        return 3 * x + y
    # maxsize is n*m, at least the n distinct keys used, so nothing is evicted.
    f = self.module.lru_cache(maxsize=n*m)(orig)
    # Only currsize is checked here; all four names are reassigned below.
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # Set once all threads have been started, so they begin together.
    start = threading.Event()
    def full(k):
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    orig_si = sys.getswitchinterval()
    # A tiny switch interval forces frequent thread switches to widen
    # race windows; the original value is restored in the finally clause.
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with threading_helper.start_threads(threads):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: Why can be not equal?
            # NOTE(review): the pure-Python implementation is only checked
            # against upper bounds here; the reason its counters may fall
            # short of the exact values is not established by this test
            # (presumably a race in its bookkeeping under forced thread
            # switches) -- confirm.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            # One miss per distinct key, every other call a hit.
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        # Re-arm the event so the new threads wait for each other again.
        start.clear()
        with threading_helper.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001564
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    # In each of m rounds, n worker threads call f(i) with the same i at the
    # same time.  f blocks on the `pause` barrier until the main thread
    # joins it as well, so all n calls are in flight before any finishes.
    # As the cache_info() checks below show, every one of them is counted
    # as a miss (no hits), yet only one entry per distinct argument is
    # stored.
    n, m = 5, 7
    # Each barrier synchronizes the n workers plus the main thread.
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with threading_helper.start_threads(threads):
        for i in range(m):
            # The main thread steps through the same barriers as the
            # workers.  Each reset() is placed at a point where, by this
            # choreography, no thread should be waiting on that barrier.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # After round i: no hits, n misses per completed round, and
            # one cached entry per distinct argument seen so far.
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1591
def test_lru_cache_threaded3(self):
    # Several threads call a slow cached function with overlapping
    # arguments through a tiny (maxsize=2) cache; every call must still
    # return the correct value.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(index, value):
        with self.subTest(thread=index):
            self.assertEqual(f(value), 3 * value, index)

    workers = []
    for index, value in enumerate([1, 2, 2, 3, 2]):
        workers.append(threading.Thread(target=check, args=(index, value)))
    # start_threads starts the workers on entry and joins them on exit.
    with threading_helper.start_threads(workers):
        pass
1604
def test_need_for_rlock(self):
    # This will deadlock on an LRU cache that uses a regular lock
    # Looking up a key can call the key's __eq__, and here __eq__ calls the
    # cached function again from the same thread.  The cache must therefore
    # tolerate re-entrant calls.

    @self.module.lru_cache(maxsize=10)
    def test_func(x):
        'Used to demonstrate a reentrant lru_cache call within a single thread'
        return x

    class DoubleEq:
        'Demonstrate a reentrant lru_cache call within a single thread'
        def __init__(self, x):
            self.x = x
        def __hash__(self):
            # Equal x values hash equally, so a second DoubleEq(2) lookup
            # has to fall back to __eq__ against the cached key.
            return self.x
        def __eq__(self, other):
            # Comparing a DoubleEq(2) re-enters the cached function.
            if self.x == 2:
                test_func(DoubleEq(1))
            return self.x == other.x

    test_func(DoubleEq(1))                      # Load the cache
    test_func(DoubleEq(2))                      # Load the cache
    self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                     DoubleEq(2))               # Verify the correct return value
1628
def test_lru_method(self):
    # lru_cache applied to a method.  One cache of size 2 is shared by all
    # instances and is reachable through the class and through each
    # instance.  The instance is part of the key, and f_cnt counts how
    # often each instance actually ran the body.
    class X(int):
        f_cnt = 0
        @self.module.lru_cache(2)
        def f(self, x):
            self.f_cnt += 1
            return x*10+self
    a = X(5)
    b = X(5)
    c = X(7)
    self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

    # Each round: (instance, its int value, arguments to call f with,
    # expected (a.f_cnt, b.f_cnt, c.f_cnt), expected cache_info()).
    rounds = [
        (a, 5, (1, 2, 2, 3, 1, 1, 1, 2, 3, 3), (6, 0, 0), (4, 6, 2, 2)),
        (b, 5, (1, 2, 1, 1, 1, 1, 3, 2, 2, 2), (6, 4, 0), (10, 10, 2, 2)),
        (c, 7, (2, 1, 1, 1, 1, 2, 1, 3, 2, 1), (6, 4, 5), (15, 15, 2, 2)),
    ]
    for obj, value, calls, counts, info in rounds:
        for x in calls:
            self.assertEqual(obj.f(x), x*10 + value)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), counts)
        self.assertEqual(X.f.cache_info(), info)

    for obj in (a, b, c):
        self.assertEqual(obj.f.cache_info(), X.f.cache_info())
1659
def test_pickle(self):
    # Cached functions, methods and static methods pickle by reference,
    # so a round trip through every protocol yields the very same object.
    owner = type(self)
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    for f in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=f):
                restored = pickle.loads(pickle.dumps(f, proto))
                self.assertIs(restored, f)
1667
def test_copy(self):
    # copy.copy() of a cached callable returns the object itself.  This
    # also holds for a cache wrapped around a functools.partial object.
    owner = type(self)

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for f in candidates:
        with self.subTest(func=f):
            self.assertIs(copy.copy(f), f)
1679
def test_deepcopy(self):
    # copy.deepcopy() of a cached callable returns the object itself.  This
    # also holds for a cache wrapped around a functools.partial object.
    owner = type(self)

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth, cached_partial)
    for f in candidates:
        with self.subTest(func=f):
            self.assertIs(copy.deepcopy(f), f)
1691
def test_lru_cache_parameters(self):
    # cache_parameters() reports the maxsize and typed settings the cache
    # was created with.
    configurations = (
        ({'maxsize': 2}, {'maxsize': 2, "typed": False}),
        ({'maxsize': 1000, 'typed': True}, {'maxsize': 1000, "typed": True}),
    )
    for options, expected in configurations:
        @self.module.lru_cache(**options)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), expected)
1702
def test_lru_cache_weakrefable(self):
    # Cache wrappers (used here without parentheses as plain decorators)
    # must support weak references.  The references must stay alive while
    # the wrappers are reachable and die once the last strong references
    # are dropped.
    @self.module.lru_cache
    def test_function(x):
        return x

    class A:
        @self.module.lru_cache
        def test_method(self, x):
            return (self, x)

        @staticmethod
        @self.module.lru_cache
        def test_staticmethod(x):
            # `self` here is the enclosing test case, captured by closure;
            # a static method receives no instance.
            return (self, x)

    # weakref.ref() raises TypeError for objects that do not support weak
    # references, so merely building this list checks weakrefability.
    refs = [weakref.ref(test_function),
            weakref.ref(A.test_method),
            weakref.ref(A.test_staticmethod)]

    for ref in refs:
        self.assertIsNotNone(ref())

    del A
    del test_function
    # Classes take part in reference cycles, so an explicit collection is
    # needed before the method wrappers can be freed.
    gc.collect()

    for ref in refs:
        self.assertIsNone(ref())
1731
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001732
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level cached function for the pure-Python tests; it lives at
    # top level so pickle can locate it by its qualified name.
    scaled = 3 * x
    return scaled + y
1736
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level cached function for the C-accelerated tests; it lives
    # at top level so pickle can locate it by its qualified name.
    scaled = 3 * x
    return scaled + y
1740
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001741
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against the pure-Python functools.
    module = py_functools
    # Stored inside a 1-tuple so the cached function is not turned into a
    # bound method when looked up through the class or an instance.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1754
1755
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU suite against the C-accelerated functools.
    module = c_functools
    # Stored inside a 1-tuple so the cached function is not turned into a
    # bound method when looked up through the class or an instance.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001768
Raymond Hettinger03923422013-03-04 02:52:50 -05001769
Łukasz Langa6f692512013-06-05 12:20:24 +02001770class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
    # The implementation registered for int handles ints; every other
    # type falls back to the base function.
    @functools.singledispatch
    def g(obj):
        return "base"

    def g_int(i):
        return "integer"

    g.register(int, g_int)
    cases = (("str", "base"), (1, "integer"), ([1, 2, 3], "base"))
    for arg, expected in cases:
        self.assertEqual(g(arg), expected)
1781
def test_mro(self):
    # Diamond hierarchy: D(C, B) with C(A) and B(A) gives D the MRO
    # D, C, B, A, object, so D dispatches to B's implementation even
    # though C (listed first) only inherits A's.
    @functools.singledispatch
    def g(obj):
        return "base"
    class A:
        pass
    class C(A):
        pass
    class B(A):
        pass
    class D(C, B):
        pass

    def g_A(a):
        return "A"
    def g_B(b):
        return "B"
    for cls, impl in ((A, g_A), (B, g_B)):
        g.register(cls, impl)

    for cls, expected in ((A, "A"), (B, "B"), (C, "A"), (D, "B")):
        self.assertEqual(g(cls()), expected)
Łukasz Langa6f692512013-06-05 12:20:24 +02001804
def test_register_decorator(self):
    # register(cls) works as a decorator and returns the implementation
    # unchanged, which is what dispatch() then reports for that class.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(int)
    def g_int(i):
        return "int %s" % (i,)

    self.assertEqual(g(""), "base")
    self.assertEqual(g(12), "int 12")
    self.assertIs(g.dispatch(int), g_int)
    # Unregistered types dispatch to the original base function, not to
    # the wrapper that @singledispatch bound to the name g, so the two
    # dispatch results are compared with each other rather than with g.
    self.assertIs(g.dispatch(object), g.dispatch(str))
1818
def test_wrapping_attributes(self):
    # The singledispatch wrapper keeps the wrapped function's __name__,
    # and its __doc__ unless docstrings are stripped (optimize level 2).
    @functools.singledispatch
    def g(obj):
        "Simple test"
        return "Test"

    self.assertEqual(g.__name__, "g")
    docstrings_kept = sys.flags.optimize < 2
    if docstrings_kept:
        self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001827
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes implemented in C (_decimal).  A
    # handler registered for the more specific Subnormal takes over for
    # Subnormal instances only; Rounded keeps using the generic handler.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def on_decimal_exception(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def on_subnormal(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1846
def test_compose_mro(self):
    # None of the examples in this test depend on haystack ordering, so
    # every example is checked against every permutation of its haystack.
    # The defaultdict and D loops previously passed a fixed list instead of
    # `haystack`, so their permutations were never exercised.
    c = collections.abc
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
    for haystack in permutations(bases):
        m = mro(collections.ChainMap, haystack)
        self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        m = mro(collections.defaultdict, haystack)
        self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                             c.Container, object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(collections.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                             collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(collections.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001900
def test_register_abc(self):
    # Register ABCs and concrete classes one at a time.  After each
    # registration, every sample object must dispatch to the most specific
    # implementation that applies to it.
    c = collections.abc
    mapping = {"a": "b"}
    sequence = [1, 2, 3]
    mutable_set = {object(), None}
    frozen = frozenset(mutable_set)
    record = (1, 2, 3)
    samples = (mapping, sequence, mutable_set, frozen, record)

    @functools.singledispatch
    def g(obj):
        return "base"

    def constant(result):
        # Factory for an implementation that ignores its argument; using a
        # factory avoids the late-binding closure pitfall in the loop.
        return lambda obj: result

    def dispatched():
        return tuple(g(obj) for obj in samples)

    self.assertEqual(dispatched(), ("base",) * 5)

    # Each step: (class to register, result its implementation returns,
    # expected results for (dict, list, set, frozenset, tuple) afterwards).
    steps = [
        (c.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # irrelevant ABCs registered
        (collections.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        # Mapping is not specific enough to beat MutableMapping for dict
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set", "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    ]
    for cls, result, expected in steps:
        g.register(cls, constant(result))
        self.assertEqual(dispatched(), expected)
1994
def test_c3_abc(self):
    # _c3_mro() inserts into X's linearization only the ABCs that X (or
    # one of its bases) implements, whether through inheritance,
    # explicit registration, or implicitly through special methods.
    c = collections.abc
    c3_mro = functools._c3_mro
    class A(object):
        pass
    class B(A):
        def __len__(self):
            return 0   # implies Sized
    @c.Container.register
    class C(object):
        pass
    class D(object):
        pass   # unrelated
    class X(D, C, B):
        def __call__(self):
            pass   # implies Callable
    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    # The order in which the relevant ABCs are supplied does not matter...
    for abcs in permutations([c.Sized, c.Callable, c.Container]):
        self.assertEqual(c3_mro(X, abcs=abcs), expected)
    # ...and ABCs that X does not implement are left out of the result.
    extra_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
    self.assertEqual(c3_mro(X, abcs=extra_abcs), expected)
2017
def test_false_meta(self):
    # see issue23572: a metaclass whose __len__ returns 0 makes the class
    # objects falsy; dispatch for a subclass instance must still find the
    # implementation registered for A.
    class MetaA(type):
        def __len__(self):
            return 0
    class A(metaclass=MetaA):
        pass
    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    @fun.register(A)
    def handle_a(a):
        return 'fun A'

    self.assertEqual(fun(AA()), 'fun A')
2035
def test_mro_conflicts(self):
    # How singledispatch resolves (or refuses to resolve) implementations
    # when several ABCs apply to a class.  The ABCs may be explicit bases,
    # registered later with ABC.register(), or inferred from special
    # methods such as __len__.  Note that the register() calls below
    # mutate the global collections.abc ABCs for these local classes.
    c = collections.abc
    @functools.singledispatch
    def g(arg):
        return "base"
    class O(c.Sized):
        def __len__(self):
            return 0
    o = O()
    # Nothing is registered yet, so everything goes to the base function.
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
    c.Set.register(O)
    self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Container
    class P:
        pass
    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    # Two unrelated registered ABCs with implementations: ambiguous.
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    # The error message may list the two ABCs in either order.
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )
    class Q(c.Sized):
        def __len__(self):
            return 0
    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
    c.Set.register(Q)
    self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Iterable
    @functools.singledispatch
    def h(arg):
        return "base"
    @h.register(c.Sized)
    def _(arg):
        return "sized"
    @h.register(c.Container)
    def _(arg):
        return "container"
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(collections.defaultdict(lambda: 0))
    self.assertIn(
        str(re_two.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    class R(collections.defaultdict):
        pass
    # Registering directly on R puts MutableSequence ahead of the
    # MutableMapping that R only gets through defaultdict.
    c.MutableSequence.register(R)
    @functools.singledispatch
    def i(arg):
        return "base"
    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"
    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"
    r = R()
    self.assertEqual(i(r), "sequence")
    class S:
        pass
    class T(S, c.Sized):
        def __len__(self):
            return 0
    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
    class U:
        def __len__(self):
            return 0
    u = U()
    self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                      # from the existence of __len__()
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(
        str(re_three.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    class V(c.Sized, S):
        def __len__(self):
            return 0
    @functools.singledispatch
    def j(arg):
        return "base"
    @j.register(S)
    def _(arg):
        return "s"
    @j.register(c.Container)
    def _(arg):
        return "container"
    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    self.assertEqual(j(v), "container")   # because it ends up right after
                                          # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002163
def test_cache_invalidation(self):
    """Check when singledispatch fills, reuses and empties its dispatch cache.

    The cache is observed through ``TracingDict``, which records the keys of
    successful reads in ``get_ops`` and of all writes in ``set_ops``.
    """
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # Mapping that logs keys of successful lookups and of all stores.
        def __init__(self, *args, **kwargs):
            super(TracingDict, self).__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            # self.data[key] raises KeyError on a miss before the append,
            # so get_ops only ever records cache hits.
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            # Clearing is deliberately not traced; it shows up as len() == 0.
            self.data.clear()

    td = TracingDict()
    # While swapped, weakref.WeakKeyDictionary() returns `td`.  The
    # assertions below rely on singledispatch picking `td` up as the
    # dispatch cache of `g`, which is decorated inside this block.
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First dispatch on each type is a miss: one store, no read.
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated dispatch is served from the cache: reads, no new stores.
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # Registering a new implementation empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        # Registering an ABC implementation also empties the cache.
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # g.dispatch() consults the same cache (two more recorded hits).
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        # Invalidation is lazy: the entries stay until the next dispatch,
        # which leaves only the freshly re-cached `list` entry.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        # The private helper empties the cache explicitly.
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002265
Łukasz Langae5697532017-12-11 13:56:31 -08002266 def test_annotations(self):
2267 @functools.singledispatch
2268 def i(arg):
2269 return "base"
2270 @i.register
2271 def _(arg: collections.abc.Mapping):
2272 return "mapping"
2273 @i.register
2274 def _(arg: "collections.abc.Sequence"):
2275 return "sequence"
2276 self.assertEqual(i(None), "base")
2277 self.assertEqual(i({"a": 1}), "mapping")
2278 self.assertEqual(i([1, 2, 3]), "sequence")
2279 self.assertEqual(i((1, 2, 3)), "sequence")
2280 self.assertEqual(i("str"), "sequence")
2281
2282 # Registering classes as callables doesn't work with annotations,
2283 # you need to pass the type explicitly.
2284 @i.register(str)
2285 class _:
2286 def __init__(self, arg):
2287 self.arg = arg
2288
2289 def __eq__(self, other):
2290 return self.arg == other
2291 self.assertEqual(i("str"), "str")
2292
Ethan Smithc6512752018-05-26 16:38:33 -04002293 def test_method_register(self):
2294 class A:
2295 @functools.singledispatchmethod
2296 def t(self, arg):
2297 self.arg = "base"
2298 @t.register(int)
2299 def _(self, arg):
2300 self.arg = "int"
2301 @t.register(str)
2302 def _(self, arg):
2303 self.arg = "str"
2304 a = A()
2305
2306 a.t(0)
2307 self.assertEqual(a.arg, "int")
2308 aa = A()
2309 self.assertFalse(hasattr(aa, 'arg'))
2310 a.t('')
2311 self.assertEqual(a.arg, "str")
2312 aa = A()
2313 self.assertFalse(hasattr(aa, 'arg'))
2314 a.t(0.0)
2315 self.assertEqual(a.arg, "base")
2316 aa = A()
2317 self.assertFalse(hasattr(aa, 'arg'))
2318
2319 def test_staticmethod_register(self):
2320 class A:
2321 @functools.singledispatchmethod
2322 @staticmethod
2323 def t(arg):
2324 return arg
2325 @t.register(int)
2326 @staticmethod
2327 def _(arg):
2328 return isinstance(arg, int)
2329 @t.register(str)
2330 @staticmethod
2331 def _(arg):
2332 return isinstance(arg, str)
2333 a = A()
2334
2335 self.assertTrue(A.t(0))
2336 self.assertTrue(A.t(''))
2337 self.assertEqual(A.t(0.0), 0.0)
2338
2339 def test_classmethod_register(self):
2340 class A:
2341 def __init__(self, arg):
2342 self.arg = arg
2343
2344 @functools.singledispatchmethod
2345 @classmethod
2346 def t(cls, arg):
2347 return cls("base")
2348 @t.register(int)
2349 @classmethod
2350 def _(cls, arg):
2351 return cls("int")
2352 @t.register(str)
2353 @classmethod
2354 def _(cls, arg):
2355 return cls("str")
2356
2357 self.assertEqual(A.t(0).arg, "int")
2358 self.assertEqual(A.t('').arg, "str")
2359 self.assertEqual(A.t(0.0).arg, "base")
2360
2361 def test_callable_register(self):
2362 class A:
2363 def __init__(self, arg):
2364 self.arg = arg
2365
2366 @functools.singledispatchmethod
2367 @classmethod
2368 def t(cls, arg):
2369 return cls("base")
2370
2371 @A.t.register(int)
2372 @classmethod
2373 def _(cls, arg):
2374 return cls("int")
2375 @A.t.register(str)
2376 @classmethod
2377 def _(cls, arg):
2378 return cls("str")
2379
2380 self.assertEqual(A.t(0).arg, "int")
2381 self.assertEqual(A.t('').arg, "str")
2382 self.assertEqual(A.t(0.0).arg, "base")
2383
2384 def test_abstractmethod_register(self):
2385 class Abstract(abc.ABCMeta):
2386
2387 @functools.singledispatchmethod
2388 @abc.abstractmethod
2389 def add(self, x, y):
2390 pass
2391
2392 self.assertTrue(Abstract.add.__isabstractmethod__)
2393
2394 def test_type_ann_register(self):
2395 class A:
2396 @functools.singledispatchmethod
2397 def t(self, arg):
2398 return "base"
2399 @t.register
2400 def _(self, arg: int):
2401 return "int"
2402 @t.register
2403 def _(self, arg: str):
2404 return "str"
2405 a = A()
2406
2407 self.assertEqual(a.t(0), "int")
2408 self.assertEqual(a.t(''), "str")
2409 self.assertEqual(a.t(0.0), "base")
2410
Łukasz Langae5697532017-12-11 13:56:31 -08002411 def test_invalid_registrations(self):
2412 msg_prefix = "Invalid first argument to `register()`: "
2413 msg_suffix = (
2414 ". Use either `@register(some_class)` or plain `@register` on an "
2415 "annotated function."
2416 )
2417 @functools.singledispatch
2418 def i(arg):
2419 return "base"
2420 with self.assertRaises(TypeError) as exc:
2421 @i.register(42)
2422 def _(arg):
2423 return "I annotated with a non-type"
2424 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2425 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2426 with self.assertRaises(TypeError) as exc:
2427 @i.register
2428 def _(arg):
2429 return "I forgot to annotate"
2430 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2431 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2432 ))
2433 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2434
Łukasz Langae5697532017-12-11 13:56:31 -08002435 with self.assertRaises(TypeError) as exc:
2436 @i.register
2437 def _(arg: typing.Iterable[str]):
2438 # At runtime, dispatching on generics is impossible.
2439 # When registering implementations with singledispatch, avoid
2440 # types from `typing`. Instead, annotate with regular types
2441 # or ABCs.
2442 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002443 self.assertTrue(str(exc.exception).startswith(
2444 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002445 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002446 self.assertTrue(str(exc.exception).endswith(
2447 'typing.Iterable[str] is not a class.'
2448 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002449
Dong-hee Na445f1b32018-07-10 16:26:36 +09002450 def test_invalid_positional_argument(self):
2451 @functools.singledispatch
2452 def f(*args):
2453 pass
2454 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002455 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002456 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002457
Carl Meyerd658dea2018-08-28 01:11:56 -06002458
class CachedCostItem:
    """Item whose ``cost`` is computed once per instance, then cached."""

    # Class-level starting value; the getter shadows it per instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2471
2472
class OptionallyCachedCostItem:
    """Exposes one computation both uncached and cached.

    ``get_cost`` recomputes on every call.  ``cached_cost`` wraps the same
    function under a different attribute name and keeps its first result.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2482
2483
class CachedCostItemWait:
    """Item whose cached ``cost`` getter first waits on *event*.

    The wait lasts at most one second.  It lets concurrent callers pile up
    inside the getter before any of them computes the value.
    """

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2497
2498
class CachedCostItemWithSlots:
    """Item with ``__slots__`` and hence no instance ``__dict__``.

    ``cached_property`` has nowhere to store its value, so accessing
    ``cost`` raises TypeError and the getter body never runs.
    """

    # Note the trailing comma: without it this is just the string '_cost'
    # in parentheses, which __slots__ happens to accept but is not a tuple.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2508
2509
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python ``cached_property`` implementation."""

    def test_cached(self):
        item = CachedCostItem()
        first = item.cost
        self.assertEqual(first, 2)
        self.assertEqual(item.cost, 2)  # served from the cache, not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # The plain method keeps counting...
        self.assertEqual(item.get_cost(), 2)
        # ...while the cached view freezes the value it first saw.
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        saved_interval = sys.getswitchinterval()
        # A tiny switch interval makes the threads interleave aggressively.
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with threading_helper.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Only a single thread may have run the getter.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = ("No '__dict__' attribute on 'CachedCostItemWithSlots' "
                    "instance to cache 'cost' property.")
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = ("The '__dict__' attribute on 'MyMeta' instance does not "
                    "support item assignment for caching 'prop' property.")
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = TypeError(
            "Cannot assign the same cached_property to two different names "
            "('a' and 'b')."
        )
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attached after class creation, so __set_name__ never runs.
        Foo.cp = cp

        expected = ("Cannot use cached_property instance without calling "
                    "__set_name__ on it.")
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access yields the descriptor object itself.
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        # The getter's docstring is exposed on the descriptor.
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2622
2623
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()