blob: fbf5578872e6b0519f2be5aa7646ea89d6e39fef [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shi3ddc6342020-06-30 21:46:06 +080022from test.support import import_helper
Hai Shie80697d2020-05-28 06:10:27 +080023from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000024from test.support.script_helper import assert_python_ok
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026import functools
27
Hai Shi3ddc6342020-06-30 21:46:06 +080028py_functools = import_helper.import_fresh_module('functools',
29 blocked=['_functools'])
Hai Shidd391232020-12-29 20:45:07 +080030c_functools = import_helper.import_fresh_module('functools')
Antoine Pitroub5b37142012-11-13 21:35:40 +010031
Hai Shi3ddc6342020-06-30 21:46:06 +080032decimal = import_helper.import_fresh_module('decimal', fresh=['_decimal'])
Łukasz Langa6f692512013-06-05 12:20:24 +020033
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previously registered entry is restored when the ``with`` block
    exits, whether it finishes normally or by raising.  *name* must already
    be present in ``sys.modules``; otherwise KeyError is raised on entry.
    """
    modules = sys.modules
    saved, modules[name] = modules[name], replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the stand-in.
        modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020042
def capture(*args, **kw):
    """Echo the call back: return the ``(args, kw)`` pair exactly as received."""
    received = (args, kw)
    return received
46
Łukasz Langa6f692512013-06-05 12:20:24 +020047
Jack Diederiche0cbd692009-04-01 04:27:09 +000048def signature(part):
49 """ return the signature of a partial object """
50 return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000051
class MyTuple(tuple):
    """Plain tuple subclass with no overridden behaviour (used by the setstate tests)."""
54
class BadTuple(tuple):
    """Tuple subclass whose ``+`` deliberately returns a list instead of a tuple."""

    def __add__(self, other):
        return [*self, *other]
58
class MyDict(dict):
    """Plain dict subclass with no overridden behaviour (used by the setstate tests)."""
61
Łukasz Langa6f692512013-06-05 12:20:24 +020062
class TestPartial:
    """Behavioural tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase.  Concrete subclasses combine it with
    ``unittest.TestCase`` and must provide two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class used around every pickling
      round trip.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertEqual calls (instead of assertTrue on a
                # combined comparison) report the mismatching values.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a,'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last strong reference frees the object immediately
        # only under reference counting; force a collection so the proxy is
        # also dead on implementations with other garbage collectors.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial is expected to be flattened into one.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not asserted; accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case builds a reference cycle and breaks it in ``finally``.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the args, keywords and attribute objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates every container, down to nested values.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None namespace is accepted and shows up as an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None keywords behave like no stored keywords at call time.
        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check is disabled; the reason is not
        # recorded here -- confirm before re-enabling it.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # State must be a 4-tuple of (callable, tuple, dict, dict-or-None).
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to the exact
        # builtin types, and a subclass __add__ must not leak into calls.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references in keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000386
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing is swapped out for the C type."""

        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # The instance __dict__ cannot be deleted either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr() must still work, but calling rejects the non-string key.
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
437
438
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Install the pure-Python functools as ``sys.modules['functools']``.

        Presumably needed so that pickling, which looks classes up by module
        name, resolves ``functools.partial`` to the Python class under test.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            enter = self._swap.__enter__
            return enter()

        def __exit__(self, type, value, tb):
            leave = self._swap.__exit__
            return leave(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial user subclass of the C ``partial`` type.

        Only defined when the C implementation could be imported.
        """
Antoine Pitrou33543272012-11-13 21:36:21 +0100453
class PyPartialSubclass(py_functools.partial):
    """Trivial user subclass of the pure-Python ``partial`` class."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200456
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a user-defined subclass."""

    if c_functools:
        partial = CPartialSubclass

    # Subclasses do not get the nested-partial flattening, so this
    # inherited test is disabled by shadowing it with None.
    test_nested_optimization = None
464
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a user-defined subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200467
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be usable as ordinary stored keywords.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Non-callable func, missing func, and func passed by keyword are
        # all rejected when the class body runs.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract has to be an abstract base class, i.e. a class *created
        # by* ABCMeta, not a subclass of the metaclass itself.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta itself must also treat the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__,
                         frozenset({'add', 'add5'}))

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
596
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000597
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000598class TestUpdateWrapper(unittest.TestCase):
599
600 def check_wrapper(self, wrapper, wrapped,
601 assigned=functools.WRAPPER_ASSIGNMENTS,
602 updated=functools.WRAPPER_UPDATES):
603 # Check attributes were assigned
604 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000605 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000606 # Check attributes were updated
607 for name in updated:
608 wrapper_attr = getattr(wrapper, name)
609 wrapped_attr = getattr(wrapped, name)
610 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000611 if name == "__dict__" and key == "__wrapped__":
612 # __wrapped__ is overwritten by the update code
613 continue
614 self.assertIs(wrapped_attr[key], wrapper_attr[key])
615 # Check __wrapped__
616 self.assertIs(wrapper.__wrapped__, wrapped)
617
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618
R. David Murray378c0cf2010-02-24 01:46:21 +0000619 def _default_update(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +0100620 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000621 """This is a test"""
622 pass
623 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000624 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000625 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000626 pass
627 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000628 return wrapper, f
629
630 def test_default_update(self):
631 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000633 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000634 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600635 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636 self.assertEqual(wrapper.attr, 'This is also a test')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100637 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000638 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000639
R. David Murray378c0cf2010-02-24 01:46:21 +0000640 @unittest.skipIf(sys.flags.optimize >= 2,
641 "Docstrings are omitted with -O2 and above")
642 def test_default_update_doc(self):
643 wrapper, f = self._default_update()
644 self.assertEqual(wrapper.__doc__, 'This is a test')
645
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000646 def test_no_update(self):
647 def f():
648 """This is a test"""
649 pass
650 f.attr = 'This is also a test'
651 def wrapper():
652 pass
653 functools.update_wrapper(wrapper, f, (), ())
654 self.check_wrapper(wrapper, f, (), ())
655 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600656 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000658 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000659 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000660
661 def test_selective_update(self):
662 def f():
663 pass
664 f.attr = 'This is a different test'
665 f.dict_attr = dict(a=1, b=2, c=3)
666 def wrapper():
667 pass
668 wrapper.dict_attr = {}
669 assign = ('attr',)
670 update = ('dict_attr',)
671 functools.update_wrapper(wrapper, f, assign, update)
672 self.check_wrapper(wrapper, f, assign, update)
673 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600674 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000675 self.assertEqual(wrapper.__doc__, None)
676 self.assertEqual(wrapper.attr, 'This is a different test')
677 self.assertEqual(wrapper.dict_attr, f.dict_attr)
678
Nick Coghlan98876832010-08-17 06:17:18 +0000679 def test_missing_attributes(self):
680 def f():
681 pass
682 def wrapper():
683 pass
684 wrapper.dict_attr = {}
685 assign = ('attr',)
686 update = ('dict_attr',)
687 # Missing attributes on wrapped object are ignored
688 functools.update_wrapper(wrapper, f, assign, update)
689 self.assertNotIn('attr', wrapper.__dict__)
690 self.assertEqual(wrapper.dict_attr, {})
691 # Wrapper must have expected attributes for updating
692 del wrapper.dict_attr
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695 wrapper.dict_attr = 1
696 with self.assertRaises(AttributeError):
697 functools.update_wrapper(wrapper, f, assign, update)
698
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200699 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000700 @unittest.skipIf(sys.flags.optimize >= 2,
701 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000702 def test_builtin_update(self):
703 # Test for bug #1576241
704 def wrapper():
705 pass
706 functools.update_wrapper(wrapper, max)
707 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000708 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000709 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000710
Łukasz Langa6f692512013-06-05 12:20:24 +0200711
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() scenarios through the functools.wraps decorator."""

    def _default_update(self):
        # Decorate a wrapper around a function that carries a docstring, a
        # custom attribute and a pre-existing (bogus) __wrapped__ value.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # NOTE(review): presumably lets check_wrapper() confirm that wraps()
        # points __wrapped__ at f itself rather than copying f.__wrapped__.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # With the defaults, name, qualname and custom attributes come from f.
        expectations = (
            ('__name__', 'f'),
            ('__qualname__', f.__qualname__),
            ('attr', 'This is also a test'),
        )
        for attr_name, expected in expectations:
            self.assertEqual(getattr(wrapper, attr_name), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps(f, (), ()) copies nothing: the wrapper keeps its own identity.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is merged.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # The wrapper must already own a dict_attr for wraps() to merge into.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
772
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773
madman-bobe25d5fc2018-10-25 15:02:10 +0100774class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000775 def test_reduce(self):
776 class Squares:
777 def __init__(self, max):
778 self.max = max
779 self.sofar = []
780
781 def __len__(self):
782 return len(self.sofar)
783
784 def __getitem__(self, i):
785 if not 0 <= i < self.max: raise IndexError
786 n = len(self.sofar)
787 while n <= i:
788 self.sofar.append(n*n)
789 n += 1
790 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000791 def add(x, y):
792 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 ['a','c','d','w']
797 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000801 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000802 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100803 self.assertEqual(self.reduce(add, Squares(10)), 285)
804 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
805 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
806 self.assertRaises(TypeError, self.reduce)
807 self.assertRaises(TypeError, self.reduce, 42, 42)
808 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
809 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
810 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
811 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
812 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
813 self.assertRaises(TypeError, self.reduce, add, "")
814 self.assertRaises(TypeError, self.reduce, add, ())
815 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000816
817 class TestFailingIter:
818 def __iter__(self):
819 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000821
madman-bobe25d5fc2018-10-25 15:02:10 +0100822 self.assertEqual(self.reduce(add, [], None), None)
823 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000824
825 class BadSeq:
826 def __getitem__(self, index):
827 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100828 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000829
830 # Test reduce()'s use of iterators.
831 def test_iterator_usage(self):
832 class SequenceClass:
833 def __init__(self, n):
834 self.n = n
835 def __getitem__(self, i):
836 if 0 <= i < self.n:
837 return i
838 else:
839 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000840
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000841 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100842 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
843 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
844 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
845 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
846 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
847 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000848
849 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100850 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
851
852
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""

    # Only bind the attribute when the accelerated module was importable;
    # otherwise the skip decorator above keeps the class from running.
    if c_functools is not None:
        reduce = c_functools.reduce
857
858
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python implementation."""

    # staticmethod stops the plain Python function from binding to the test
    # instance, so self.reduce(f, xs) calls reduce(f, xs).
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000861
Łukasz Langa6f692512013-06-05 12:20:24 +0200862
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200863class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000865 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 def cmp1(x, y):
867 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 self.assertEqual(key(3), key(3))
870 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100871 self.assertGreaterEqual(key(3), key(3))
872
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 def cmp2(x, y):
874 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 self.assertEqual(key(4.0), key('4'))
877 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100878 self.assertLessEqual(key(2), key('35'))
879 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700880
881 def test_cmp_to_key_arguments(self):
882 def cmp1(x, y):
883 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 self.assertEqual(key(obj=3), key(obj=3))
886 self.assertGreater(key(obj=3), key(obj=1))
887 with self.assertRaises((TypeError, AttributeError)):
888 key(3) > 1 # rhs is not a K object
889 with self.assertRaises((TypeError, AttributeError)):
890 1 < key(3) # lhs is not a K object
891 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200894 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 with self.assertRaises(TypeError):
897 key() # too few args
898 with self.assertRaises(TypeError):
899 key(None, None) # too many args
900
901 def test_bad_cmp(self):
902 def cmp1(x, y):
903 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100904 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700905 with self.assertRaises(ZeroDivisionError):
906 key(3) > key(1)
907
908 class BadCmp:
909 def __lt__(self, other):
910 raise ZeroDivisionError
911 def cmp1(x, y):
912 return BadCmp()
913 with self.assertRaises(ZeroDivisionError):
914 key(3) > key(1)
915
916 def test_obj_field(self):
917 def cmp1(x, y):
918 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700920 self.assertEqual(key(50).obj, 50)
921
922 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 def mycmp(x, y):
924 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100925 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000927
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700928 def test_sort_int_str(self):
929 def mycmp(x, y):
930 x, y = int(x), int(y)
931 return (x > y) - (x < y)
932 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100933 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700934 self.assertEqual([int(value) for value in values],
935 [0, 1, 1, 2, 3, 4, 5, 7, 10])
936
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000937 def test_hash(self):
938 def mycmp(x, y):
939 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100940 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700942 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300943 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    # Bound without staticmethod (unlike TestCmpToKeyPy); presumably the C
    # function is a builtin and so does not bind to the test instance.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key

    @support.cpython_only
    def test_disallow_instantiation(self):
        # The key type produced by cmp_to_key() must not be directly
        # instantiable (bpo-43916).
        key_type = type(c_functools.cmp_to_key(None))
        support.check_disallow_instantiation(self, key_type)
Erlend Egeberg Aasland9746cda2021-04-30 16:04:57 +0200957
Łukasz Langa6f692512013-06-05 12:20:24 +0200958
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python implementation."""

    # staticmethod stops the plain Python function from binding to the test
    # instance, so self.cmp_to_key(f) calls cmp_to_key(f).
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
961
Łukasz Langa6f692512013-06-05 12:20:24 +0200962
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000963class TestTotalOrdering(unittest.TestCase):
964
965 def test_total_ordering_lt(self):
966 @functools.total_ordering
967 class A:
968 def __init__(self, value):
969 self.value = value
970 def __lt__(self, other):
971 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000972 def __eq__(self, other):
973 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000974 self.assertTrue(A(1) < A(2))
975 self.assertTrue(A(2) > A(1))
976 self.assertTrue(A(1) <= A(2))
977 self.assertTrue(A(2) >= A(1))
978 self.assertTrue(A(2) <= A(2))
979 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000980 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000981
982 def test_total_ordering_le(self):
983 @functools.total_ordering
984 class A:
985 def __init__(self, value):
986 self.value = value
987 def __le__(self, other):
988 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000989 def __eq__(self, other):
990 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000991 self.assertTrue(A(1) < A(2))
992 self.assertTrue(A(2) > A(1))
993 self.assertTrue(A(1) <= A(2))
994 self.assertTrue(A(2) >= A(1))
995 self.assertTrue(A(2) <= A(2))
996 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000997 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000998
999 def test_total_ordering_gt(self):
1000 @functools.total_ordering
1001 class A:
1002 def __init__(self, value):
1003 self.value = value
1004 def __gt__(self, other):
1005 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001006 def __eq__(self, other):
1007 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001008 self.assertTrue(A(1) < A(2))
1009 self.assertTrue(A(2) > A(1))
1010 self.assertTrue(A(1) <= A(2))
1011 self.assertTrue(A(2) >= A(1))
1012 self.assertTrue(A(2) <= A(2))
1013 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001014 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001015
1016 def test_total_ordering_ge(self):
1017 @functools.total_ordering
1018 class A:
1019 def __init__(self, value):
1020 self.value = value
1021 def __ge__(self, other):
1022 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001023 def __eq__(self, other):
1024 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001025 self.assertTrue(A(1) < A(2))
1026 self.assertTrue(A(2) > A(1))
1027 self.assertTrue(A(1) <= A(2))
1028 self.assertTrue(A(2) >= A(1))
1029 self.assertTrue(A(2) <= A(2))
1030 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001031 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001032
1033 def test_total_ordering_no_overwrite(self):
1034 # new methods should not overwrite existing
1035 @functools.total_ordering
1036 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001037 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001038 self.assertTrue(A(1) < A(2))
1039 self.assertTrue(A(2) > A(1))
1040 self.assertTrue(A(1) <= A(2))
1041 self.assertTrue(A(2) >= A(1))
1042 self.assertTrue(A(2) <= A(2))
1043 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001044
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001045 def test_no_operations_defined(self):
1046 with self.assertRaises(ValueError):
1047 @functools.total_ordering
1048 class A:
1049 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001050
Nick Coghlanf05d9812013-10-02 00:02:03 +10001051 def test_type_error_when_not_implemented(self):
1052 # bug 10042; ensure stack overflow does not occur
1053 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001054 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001055 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001056 def __init__(self, value):
1057 self.value = value
1058 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001059 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001060 return self.value == other.value
1061 return False
1062 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001063 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001064 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001065 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001066
Nick Coghlanf05d9812013-10-02 00:02:03 +10001067 @functools.total_ordering
1068 class ImplementsGreaterThan:
1069 def __init__(self, value):
1070 self.value = value
1071 def __eq__(self, other):
1072 if isinstance(other, ImplementsGreaterThan):
1073 return self.value == other.value
1074 return False
1075 def __gt__(self, other):
1076 if isinstance(other, ImplementsGreaterThan):
1077 return self.value > other.value
1078 return NotImplemented
1079
1080 @functools.total_ordering
1081 class ImplementsLessThanEqualTo:
1082 def __init__(self, value):
1083 self.value = value
1084 def __eq__(self, other):
1085 if isinstance(other, ImplementsLessThanEqualTo):
1086 return self.value == other.value
1087 return False
1088 def __le__(self, other):
1089 if isinstance(other, ImplementsLessThanEqualTo):
1090 return self.value <= other.value
1091 return NotImplemented
1092
1093 @functools.total_ordering
1094 class ImplementsGreaterThanEqualTo:
1095 def __init__(self, value):
1096 self.value = value
1097 def __eq__(self, other):
1098 if isinstance(other, ImplementsGreaterThanEqualTo):
1099 return self.value == other.value
1100 return False
1101 def __ge__(self, other):
1102 if isinstance(other, ImplementsGreaterThanEqualTo):
1103 return self.value >= other.value
1104 return NotImplemented
1105
1106 @functools.total_ordering
1107 class ComparatorNotImplemented:
1108 def __init__(self, value):
1109 self.value = value
1110 def __eq__(self, other):
1111 if isinstance(other, ComparatorNotImplemented):
1112 return self.value == other.value
1113 return False
1114 def __lt__(self, other):
1115 return NotImplemented
1116
1117 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1118 ImplementsLessThan(-1) < 1
1119
1120 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1121 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1122
1123 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1124 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1125
1126 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1127 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1128
1129 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1130 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1131
1132 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1133 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1134
1135 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1136 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1137
1138 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1139 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1140
1141 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1142 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1143
1144 with self.subTest("GE when equal"):
1145 a = ComparatorNotImplemented(8)
1146 b = ComparatorNotImplemented(8)
1147 self.assertEqual(a, b)
1148 with self.assertRaises(TypeError):
1149 a >= b
1150
1151 with self.subTest("LE when equal"):
1152 a = ComparatorNotImplemented(9)
1153 b = ComparatorNotImplemented(9)
1154 self.assertEqual(a, b)
1155 with self.assertRaises(TypeError):
1156 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001157
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001158 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001159 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001160 for name in '__lt__', '__gt__', '__le__', '__ge__':
1161 with self.subTest(method=name, proto=proto):
1162 method = getattr(Orderable_LT, name)
1163 method_copy = pickle.loads(pickle.dumps(method, proto))
1164 self.assertIs(method_copy, method)
1165
Miss Islington (bot)66dd1a02021-08-06 13:11:44 -07001166
1167 def test_total_ordering_for_metaclasses_issue_44605(self):
1168
1169 @functools.total_ordering
1170 class SortableMeta(type):
1171 def __new__(cls, name, bases, ns):
1172 return super().__new__(cls, name, bases, ns)
1173
1174 def __lt__(self, other):
1175 if not isinstance(other, SortableMeta):
1176 pass
1177 return self.__name__ < other.__name__
1178
1179 def __eq__(self, other):
1180 if not isinstance(other, SortableMeta):
1181 pass
1182 return self.__name__ == other.__name__
1183
1184 class B(metaclass=SortableMeta):
1185 pass
1186
1187 class A(metaclass=SortableMeta):
1188 pass
1189
1190 self.assertTrue(A < B)
1191 self.assertFalse(A > B)
1192
1193
@functools.total_ordering
class Orderable_LT:
    """Order instances by ``value``; total_ordering supplies >, <= and >=.

    Kept at module level, which TestTotalOrdering.test_pickle relies on to
    round-trip this class's comparison methods through pickle.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1202
1203
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001204class TestCache:
1205 # This tests that the pass-through is working as designed.
1206 # The underlying functionality is tested in TestLRU.
1207
1208 def test_cache(self):
1209 @self.module.cache
1210 def fib(n):
1211 if n < 2:
1212 return n
1213 return fib(n-1) + fib(n-2)
1214 self.assertEqual([fib(n) for n in range(16)],
1215 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1216 self.assertEqual(fib.cache_info(),
1217 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1218 fib.cache_clear()
1219 self.assertEqual(fib.cache_info(),
1220 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1221
1222
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001223class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001224
1225 def test_lru(self):
1226 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001227 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001228 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001229 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001230 self.assertEqual(maxsize, 20)
1231 self.assertEqual(currsize, 0)
1232 self.assertEqual(hits, 0)
1233 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001234
1235 domain = range(5)
1236 for i in range(1000):
1237 x, y = choice(domain), choice(domain)
1238 actual = f(x, y)
1239 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001240 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001241 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001242 self.assertTrue(hits > misses)
1243 self.assertEqual(hits + misses, 1000)
1244 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001245
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001246 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001247 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001248 self.assertEqual(hits, 0)
1249 self.assertEqual(misses, 0)
1250 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001251 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001252 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001253 self.assertEqual(hits, 0)
1254 self.assertEqual(misses, 1)
1255 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001256
Nick Coghlan98876832010-08-17 06:17:18 +00001257 # Test bypassing the cache
1258 self.assertIs(f.__wrapped__, orig)
1259 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001260 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001261 self.assertEqual(hits, 0)
1262 self.assertEqual(misses, 1)
1263 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001264
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001265 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001266 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001267 def f():
1268 nonlocal f_cnt
1269 f_cnt += 1
1270 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001271 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001272 f_cnt = 0
1273 for i in range(5):
1274 self.assertEqual(f(), 20)
1275 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001276 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001277 self.assertEqual(hits, 0)
1278 self.assertEqual(misses, 5)
1279 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001280
1281 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001282 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001283 def f():
1284 nonlocal f_cnt
1285 f_cnt += 1
1286 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001287 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001288 f_cnt = 0
1289 for i in range(5):
1290 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001291 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001292 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001293 self.assertEqual(hits, 4)
1294 self.assertEqual(misses, 1)
1295 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001296
Raymond Hettingerf3098282010-08-15 03:30:45 +00001297 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001298 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001299 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001300 nonlocal f_cnt
1301 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001302 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001303 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001304 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001305 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1306 # * * * *
1307 self.assertEqual(f(x), x*10)
1308 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001309 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001310 self.assertEqual(hits, 12)
1311 self.assertEqual(misses, 4)
1312 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001313
Raymond Hettingerb8218682019-05-26 11:27:35 -07001314 def test_lru_no_args(self):
1315 @self.module.lru_cache
1316 def square(x):
1317 return x ** 2
1318
1319 self.assertEqual(list(map(square, [10, 20, 10])),
1320 [100, 400, 100])
1321 self.assertEqual(square.cache_info().hits, 1)
1322 self.assertEqual(square.cache_info().misses, 2)
1323 self.assertEqual(square.cache_info().maxsize, 128)
1324 self.assertEqual(square.cache_info().currsize, 2)
1325
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001326 def test_lru_bug_35780(self):
1327 # C version of the lru_cache was not checking to see if
1328 # the user function call has already modified the cache
1329 # (this arises in recursive calls and in multi-threading).
1330 # This cause the cache to have orphan links not referenced
1331 # by the cache dictionary.
1332
1333 once = True # Modified by f(x) below
1334
1335 @self.module.lru_cache(maxsize=10)
1336 def f(x):
1337 nonlocal once
1338 rv = f'.{x}.'
1339 if x == 20 and once:
1340 once = False
1341 rv = f(x)
1342 return rv
1343
1344 # Fill the cache
1345 for x in range(15):
1346 self.assertEqual(f(x), f'.{x}.')
1347 self.assertEqual(f.cache_info().currsize, 10)
1348
1349 # Make a recursive call and make sure the cache remains full
1350 self.assertEqual(f(20), '.20.')
1351 self.assertEqual(f.cache_info().currsize, 10)
1352
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001353 def test_lru_bug_36650(self):
1354 # C version of lru_cache was treating a call with an empty **kwargs
1355 # dictionary as being distinct from a call with no keywords at all.
1356 # This did not result in an incorrect answer, but it did trigger
1357 # an unexpected cache miss.
1358
1359 @self.module.lru_cache()
1360 def f(x):
1361 pass
1362
1363 f(0)
1364 f(0, **{})
1365 self.assertEqual(f.cache_info().hits, 1)
1366
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001367 def test_lru_hash_only_once(self):
1368 # To protect against weird reentrancy bugs and to improve
1369 # efficiency when faced with slow __hash__ methods, the
1370 # LRU cache guarantees that it will only call __hash__
1371 # only once per use as an argument to the cached function.
1372
1373 @self.module.lru_cache(maxsize=1)
1374 def f(x, y):
1375 return x * 3 + y
1376
1377 # Simulate the integer 5
1378 mock_int = unittest.mock.Mock()
1379 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1380 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1381
1382 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001383 self.assertEqual(f(mock_int, 1), 16)
1384 self.assertEqual(mock_int.__hash__.call_count, 1)
1385 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001386
1387 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001388 self.assertEqual(f(mock_int, 1), 16)
1389 self.assertEqual(mock_int.__hash__.call_count, 2)
1390 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001391
Ville Skyttä49b27342017-08-03 09:00:59 +03001392 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001393 self.assertEqual(f(6, 2), 20)
1394 self.assertEqual(mock_int.__hash__.call_count, 2)
1395 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001396
1397 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001398 self.assertEqual(f(mock_int, 1), 16)
1399 self.assertEqual(mock_int.__hash__.call_count, 3)
1400 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001401
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001402 def test_lru_reentrancy_with_len(self):
1403 # Test to make sure the LRU cache code isn't thrown-off by
1404 # caching the built-in len() function. Since len() can be
1405 # cached, we shouldn't use it inside the lru code itself.
1406 old_len = builtins.len
1407 try:
1408 builtins.len = self.module.lru_cache(4)(len)
1409 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1410 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1411 finally:
1412 builtins.len = old_len
1413
Raymond Hettinger605a4472017-01-09 07:50:19 -08001414 def test_lru_star_arg_handling(self):
1415 # Test regression that arose in ea064ff3c10f
1416 @functools.lru_cache()
1417 def f(*args):
1418 return args
1419
1420 self.assertEqual(f(1, 2), (1, 2))
1421 self.assertEqual(f((1, 2)), ((1, 2),))
1422
Yury Selivanov46a02db2016-11-09 18:55:45 -05001423 def test_lru_type_error(self):
1424 # Regression test for issue #28653.
1425 # lru_cache was leaking when one of the arguments
1426 # wasn't cacheable.
1427
1428 @functools.lru_cache(maxsize=None)
1429 def infinite_cache(o):
1430 pass
1431
1432 @functools.lru_cache(maxsize=10)
1433 def limited_cache(o):
1434 pass
1435
1436 with self.assertRaises(TypeError):
1437 infinite_cache([])
1438
1439 with self.assertRaises(TypeError):
1440 limited_cache([])
1441
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001442 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001443 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001444 def fib(n):
1445 if n < 2:
1446 return n
1447 return fib(n-1) + fib(n-2)
1448 self.assertEqual([fib(n) for n in range(16)],
1449 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1450 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001451 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001452 fib.cache_clear()
1453 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001454 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1455
1456 def test_lru_with_maxsize_negative(self):
1457 @self.module.lru_cache(maxsize=-10)
1458 def eq(n):
1459 return n
1460 for i in (0, 1):
1461 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1462 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001463 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001464
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001465 def test_lru_with_exceptions(self):
1466 # Verify that user_function exceptions get passed through without
1467 # creating a hard-to-read chained exception.
1468 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001469 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001470 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001471 def func(i):
1472 return 'abc'[i]
1473 self.assertEqual(func(0), 'a')
1474 with self.assertRaises(IndexError) as cm:
1475 func(15)
1476 self.assertIsNone(cm.exception.__context__)
1477 # Verify that the previous exception did not result in a cached entry
1478 with self.assertRaises(IndexError):
1479 func(15)
1480
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001481 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001482 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001483 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001484 def square(x):
1485 return x * x
1486 self.assertEqual(square(3), 9)
1487 self.assertEqual(type(square(3)), type(9))
1488 self.assertEqual(square(3.0), 9.0)
1489 self.assertEqual(type(square(3.0)), type(9.0))
1490 self.assertEqual(square(x=3), 9)
1491 self.assertEqual(type(square(x=3)), type(9))
1492 self.assertEqual(square(x=3.0), 9.0)
1493 self.assertEqual(type(square(x=3.0)), type(9.0))
1494 self.assertEqual(square.cache_info().hits, 4)
1495 self.assertEqual(square.cache_info().misses, 4)
1496
Antoine Pitroub5b37142012-11-13 21:35:40 +01001497 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001498 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001499 def fib(n):
1500 if n < 2:
1501 return n
1502 return fib(n=n-1) + fib(n=n-2)
1503 self.assertEqual(
1504 [fib(n=number) for number in range(16)],
1505 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1506 )
1507 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001508 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001509 fib.cache_clear()
1510 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001511 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001512
1513 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001514 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001515 def fib(n):
1516 if n < 2:
1517 return n
1518 return fib(n=n-1) + fib(n=n-2)
1519 self.assertEqual([fib(n=number) for number in range(16)],
1520 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1521 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001522 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001523 fib.cache_clear()
1524 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001525 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1526
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001527 def test_kwargs_order(self):
1528 # PEP 468: Preserving Keyword Argument Order
1529 @self.module.lru_cache(maxsize=10)
1530 def f(**kwargs):
1531 return list(kwargs.items())
1532 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1533 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1534 self.assertEqual(f.cache_info(),
1535 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1536
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001537 def test_lru_cache_decoration(self):
1538 def f(zomg: 'zomg_annotation'):
1539 """f doc string"""
1540 return 42
1541 g = self.module.lru_cache()(f)
1542 for attr in self.module.WRAPPER_ASSIGNMENTS:
1543 self.assertEqual(getattr(g, attr), getattr(f, attr))
1544
def test_lru_cache_threaded(self):
    # Exercise one bounded cache from several threads.  In phase 1 the
    # threads only fill it; in phase 2 another thread clears it
    # repeatedly while they run.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(f.cache_info().currsize, 0)

    start = threading.Event()

    def fill_worker(k):
        # Each worker looks up its own key m times and checks the result.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear_worker():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    saved_interval = sys.getswitchinterval()
    # A tiny switch interval makes thread interleavings far more likely.
    support.setswitchinterval(1e-6)
    try:
        # Phase 1: n threads fill the cache with n distinct keys.
        workers = [threading.Thread(target=fill_worker, args=[k])
                   for k in range(n)]
        with threading_helper.start_threads(workers):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: the pure-Python statistics are not always exact here;
            # the reason has not been established.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: n filling threads race one clearing thread; every
        # lookup made by the fillers must still return the right value.
        workers = [threading.Thread(target=clear_worker)]
        workers += [threading.Thread(target=fill_worker, args=[k])
                    for k in range(n)]
        start.clear()
        with threading_helper.start_threads(workers):
            start.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001592
def test_lru_cache_threaded2(self):
    # n threads call f() with the same argument at the same time.  f()
    # blocks on the `pause` barrier until all n are inside it, so all of
    # them miss: each round adds n misses, 0 hits and one cache entry.
    n, m = 5, 7
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for round_no in range(m):
            start.wait(10)
            self.assertEqual(f(round_no), 3 * round_no)
            stop.wait(10)

    workers = [threading.Thread(target=worker) for _ in range(n)]
    with threading_helper.start_threads(workers):
        # The main thread is the (n+1)-th party of every barrier and resets
        # each one once it is safe to reuse it in the next round.
        for round_no in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            expected_misses = (round_no+1)*n
            self.assertEqual(f.cache_info(),
                             (0, expected_misses, m*n, round_no+1))
1619
def test_lru_cache_threaded3(self):
    # Concurrent calls on a two-entry cache, with repeated keys, must all
    # return the correct value.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(index, value):
        with self.subTest(thread=index):
            self.assertEqual(f(value), 3 * value, index)

    workers = []
    for index, value in enumerate([1, 2, 2, 3, 2]):
        workers.append(threading.Thread(target=check, args=(index, value)))
    with threading_helper.start_threads(workers):
        pass
1632
Raymond Hettinger03923422013-03-04 02:52:50 -05001633 def test_need_for_rlock(self):
1634 # This will deadlock on an LRU cache that uses a regular lock
1635
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001636 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001637 def test_func(x):
1638 'Used to demonstrate a reentrant lru_cache call within a single thread'
1639 return x
1640
1641 class DoubleEq:
1642 'Demonstrate a reentrant lru_cache call within a single thread'
1643 def __init__(self, x):
1644 self.x = x
1645 def __hash__(self):
1646 return self.x
1647 def __eq__(self, other):
1648 if self.x == 2:
1649 test_func(DoubleEq(1))
1650 return self.x == other.x
1651
1652 test_func(DoubleEq(1)) # Load the cache
1653 test_func(DoubleEq(2)) # Load the cache
1654 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1655 DoubleEq(2)) # Verify the correct return value
1656
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001657 def test_lru_method(self):
1658 class X(int):
1659 f_cnt = 0
1660 @self.module.lru_cache(2)
1661 def f(self, x):
1662 self.f_cnt += 1
1663 return x*10+self
1664 a = X(5)
1665 b = X(5)
1666 c = X(7)
1667 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1668
1669 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1670 self.assertEqual(a.f(x), x*10 + 5)
1671 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1672 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1673
1674 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1675 self.assertEqual(b.f(x), x*10 + 5)
1676 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1677 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1678
1679 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1680 self.assertEqual(c.f(x), x*10 + 7)
1681 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1682 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1683
1684 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1685 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1686 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1687
def test_pickle(self):
    # Round-tripping a cached function through pickle, with any protocol,
    # must give back the very same object.
    owner = type(self)
    candidates = (owner.cached_func[0], owner.cached_meth,
                  owner.cached_staticmeth)
    for f in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=f):
                restored = pickle.loads(pickle.dumps(f, proto))
                self.assertIs(restored, f)
1695
1696 def test_copy(self):
1697 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001698 def orig(x, y):
1699 return 3 * x + y
1700 part = self.module.partial(orig, 2)
1701 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1702 self.module.lru_cache(2)(part))
1703 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001704 with self.subTest(func=f):
1705 f_copy = copy.copy(f)
1706 self.assertIs(f_copy, f)
1707
1708 def test_deepcopy(self):
1709 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001710 def orig(x, y):
1711 return 3 * x + y
1712 part = self.module.partial(orig, 2)
1713 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1714 self.module.lru_cache(2)(part))
1715 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001716 with self.subTest(func=f):
1717 f_copy = copy.deepcopy(f)
1718 self.assertIs(f_copy, f)
1719
Manjusaka051ff522019-11-12 15:30:18 +08001720 def test_lru_cache_parameters(self):
1721 @self.module.lru_cache(maxsize=2)
1722 def f():
1723 return 1
1724 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1725
1726 @self.module.lru_cache(maxsize=1000, typed=True)
1727 def f():
1728 return 1
1729 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1730
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001731 def test_lru_cache_weakrefable(self):
1732 @self.module.lru_cache
1733 def test_function(x):
1734 return x
1735
1736 class A:
1737 @self.module.lru_cache
1738 def test_method(self, x):
1739 return (self, x)
1740
1741 @staticmethod
1742 @self.module.lru_cache
1743 def test_staticmethod(x):
1744 return (self, x)
1745
1746 refs = [weakref.ref(test_function),
1747 weakref.ref(A.test_method),
1748 weakref.ref(A.test_staticmethod)]
1749
1750 for ref in refs:
1751 self.assertIsNotNone(ref())
1752
1753 del A
1754 del test_function
1755 gc.collect()
1756
1757 for ref in refs:
1758 self.assertIsNone(ref())
1759
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001760
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Module-level function cached with the pure-Python lru_cache.

    Used through TestLRUPy.cached_func by the pickle/copy tests; presumably
    defined at module level so pickle can find it by name.
    """
    scaled = 3 * x
    return scaled + y
1764
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Module-level function cached with the C-accelerated lru_cache.

    Used through TestLRUC.cached_func by the pickle/copy tests; presumably
    defined at module level so pickle can find it by name.
    """
    scaled = 3 * x
    return scaled + y
1768
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001769
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # Kept inside a 1-tuple and read back as cached_func[0]; presumably so
    # the function is not treated as a method of the test class.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1782
1783
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # Kept inside a 1-tuple and read back as cached_func[0]; presumably so
    # the function is not treated as a method of the test class.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001796
Raymond Hettinger03923422013-03-04 02:52:50 -05001797
Łukasz Langa6f692512013-06-05 12:20:24 +02001798class TestSingleDispatch(unittest.TestCase):
1799 def test_simple_overloads(self):
1800 @functools.singledispatch
1801 def g(obj):
1802 return "base"
1803 def g_int(i):
1804 return "integer"
1805 g.register(int, g_int)
1806 self.assertEqual(g("str"), "base")
1807 self.assertEqual(g(1), "integer")
1808 self.assertEqual(g([1,2,3]), "base")
1809
1810 def test_mro(self):
1811 @functools.singledispatch
1812 def g(obj):
1813 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001814 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001815 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001816 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001817 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001818 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001819 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001820 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001821 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001822 def g_A(a):
1823 return "A"
1824 def g_B(b):
1825 return "B"
1826 g.register(A, g_A)
1827 g.register(B, g_B)
1828 self.assertEqual(g(A()), "A")
1829 self.assertEqual(g(B()), "B")
1830 self.assertEqual(g(C()), "A")
1831 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001832
1833 def test_register_decorator(self):
1834 @functools.singledispatch
1835 def g(obj):
1836 return "base"
1837 @g.register(int)
1838 def g_int(i):
1839 return "int %s" % (i,)
1840 self.assertEqual(g(""), "base")
1841 self.assertEqual(g(12), "int 12")
1842 self.assertIs(g.dispatch(int), g_int)
1843 self.assertIs(g.dispatch(object), g.dispatch(str))
1844 # Note: in the assert above this is not g.
1845 # @singledispatch returns the wrapper.
1846
1847 def test_wrapping_attributes(self):
1848 @functools.singledispatch
1849 def g(obj):
1850 "Simple test"
1851 return "Test"
1852 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001853 if sys.flags.optimize < 2:
1854 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001855
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on the decimal exception classes from the freshly imported
    # _decimal module.  A later registration for the more specific
    # Subnormal class overrides the DecimalException one only for
    # Subnormal instances.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def _(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1874
1875 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001876 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001877 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001878 mro = functools._compose_mro
1879 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1880 for haystack in permutations(bases):
1881 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001882 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1883 c.Collection, c.Sized, c.Iterable,
1884 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001885 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001886 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001887 m = mro(collections.ChainMap, haystack)
1888 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001889 c.Collection, c.Sized, c.Iterable,
1890 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001891
1892 # If there's a generic function with implementations registered for
1893 # both Sized and Container, passing a defaultdict to it results in an
1894 # ambiguous dispatch which will cause a RuntimeError (see
1895 # test_mro_conflicts).
1896 bases = [c.Container, c.Sized, str]
1897 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001898 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1899 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1900 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001901
1902 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001903 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001904 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001905 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001906 pass
1907 c.MutableSequence.register(D)
1908 bases = [c.MutableSequence, c.MutableMapping]
1909 for haystack in permutations(bases):
1910 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001911 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001912 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001913 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001914 object])
1915
1916 # Container and Callable are registered on different base classes and
1917 # a generic function supporting both should always pick the Callable
1918 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001919 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001920 def __call__(self):
1921 pass
1922 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1923 for haystack in permutations(bases):
1924 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001925 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001926 c.Collection, c.Sized, c.Iterable,
1927 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001928
1929 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001930 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001931 d = {"a": "b"}
1932 l = [1, 2, 3]
1933 s = {object(), None}
1934 f = frozenset(s)
1935 t = (1, 2, 3)
1936 @functools.singledispatch
1937 def g(obj):
1938 return "base"
1939 self.assertEqual(g(d), "base")
1940 self.assertEqual(g(l), "base")
1941 self.assertEqual(g(s), "base")
1942 self.assertEqual(g(f), "base")
1943 self.assertEqual(g(t), "base")
1944 g.register(c.Sized, lambda obj: "sized")
1945 self.assertEqual(g(d), "sized")
1946 self.assertEqual(g(l), "sized")
1947 self.assertEqual(g(s), "sized")
1948 self.assertEqual(g(f), "sized")
1949 self.assertEqual(g(t), "sized")
1950 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1951 self.assertEqual(g(d), "mutablemapping")
1952 self.assertEqual(g(l), "sized")
1953 self.assertEqual(g(s), "sized")
1954 self.assertEqual(g(f), "sized")
1955 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001956 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001957 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1958 self.assertEqual(g(l), "sized")
1959 self.assertEqual(g(s), "sized")
1960 self.assertEqual(g(f), "sized")
1961 self.assertEqual(g(t), "sized")
1962 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1963 self.assertEqual(g(d), "mutablemapping")
1964 self.assertEqual(g(l), "mutablesequence")
1965 self.assertEqual(g(s), "sized")
1966 self.assertEqual(g(f), "sized")
1967 self.assertEqual(g(t), "sized")
1968 g.register(c.MutableSet, lambda obj: "mutableset")
1969 self.assertEqual(g(d), "mutablemapping")
1970 self.assertEqual(g(l), "mutablesequence")
1971 self.assertEqual(g(s), "mutableset")
1972 self.assertEqual(g(f), "sized")
1973 self.assertEqual(g(t), "sized")
1974 g.register(c.Mapping, lambda obj: "mapping")
1975 self.assertEqual(g(d), "mutablemapping") # not specific enough
1976 self.assertEqual(g(l), "mutablesequence")
1977 self.assertEqual(g(s), "mutableset")
1978 self.assertEqual(g(f), "sized")
1979 self.assertEqual(g(t), "sized")
1980 g.register(c.Sequence, lambda obj: "sequence")
1981 self.assertEqual(g(d), "mutablemapping")
1982 self.assertEqual(g(l), "mutablesequence")
1983 self.assertEqual(g(s), "mutableset")
1984 self.assertEqual(g(f), "sized")
1985 self.assertEqual(g(t), "sequence")
1986 g.register(c.Set, lambda obj: "set")
1987 self.assertEqual(g(d), "mutablemapping")
1988 self.assertEqual(g(l), "mutablesequence")
1989 self.assertEqual(g(s), "mutableset")
1990 self.assertEqual(g(f), "set")
1991 self.assertEqual(g(t), "sequence")
1992 g.register(dict, lambda obj: "dict")
1993 self.assertEqual(g(d), "dict")
1994 self.assertEqual(g(l), "mutablesequence")
1995 self.assertEqual(g(s), "mutableset")
1996 self.assertEqual(g(f), "set")
1997 self.assertEqual(g(t), "sequence")
1998 g.register(list, lambda obj: "list")
1999 self.assertEqual(g(d), "dict")
2000 self.assertEqual(g(l), "list")
2001 self.assertEqual(g(s), "mutableset")
2002 self.assertEqual(g(f), "set")
2003 self.assertEqual(g(t), "sequence")
2004 g.register(set, lambda obj: "concrete-set")
2005 self.assertEqual(g(d), "dict")
2006 self.assertEqual(g(l), "list")
2007 self.assertEqual(g(s), "concrete-set")
2008 self.assertEqual(g(f), "set")
2009 self.assertEqual(g(t), "sequence")
2010 g.register(frozenset, lambda obj: "frozen-set")
2011 self.assertEqual(g(d), "dict")
2012 self.assertEqual(g(l), "list")
2013 self.assertEqual(g(s), "concrete-set")
2014 self.assertEqual(g(f), "frozen-set")
2015 self.assertEqual(g(t), "sequence")
2016 g.register(tuple, lambda obj: "tuple")
2017 self.assertEqual(g(d), "dict")
2018 self.assertEqual(g(l), "list")
2019 self.assertEqual(g(s), "concrete-set")
2020 self.assertEqual(g(f), "frozen-set")
2021 self.assertEqual(g(t), "tuple")
2022
Łukasz Langa3720c772013-07-01 16:00:38 +02002023 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002024 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02002025 mro = functools._c3_mro
2026 class A(object):
2027 pass
2028 class B(A):
2029 def __len__(self):
2030 return 0 # implies Sized
2031 @c.Container.register
2032 class C(object):
2033 pass
2034 class D(object):
2035 pass # unrelated
2036 class X(D, C, B):
2037 def __call__(self):
2038 pass # implies Callable
2039 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2040 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2041 self.assertEqual(mro(X, abcs=abcs), expected)
2042 # unrelated ABCs don't appear in the resulting MRO
2043 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2044 self.assertEqual(mro(X, abcs=many_abcs), expected)
2045
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002046 def test_false_meta(self):
2047 # see issue23572
2048 class MetaA(type):
2049 def __len__(self):
2050 return 0
2051 class A(metaclass=MetaA):
2052 pass
2053 class AA(A):
2054 pass
2055 @functools.singledispatch
2056 def fun(a):
2057 return 'base A'
2058 @fun.register(A)
2059 def _(a):
2060 return 'fun A'
2061 aa = AA()
2062 self.assertEqual(fun(aa), 'fun A')
2063
Łukasz Langa6f692512013-06-05 12:20:24 +02002064 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002065 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002066 @functools.singledispatch
2067 def g(arg):
2068 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02002069 class O(c.Sized):
2070 def __len__(self):
2071 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002072 o = O()
2073 self.assertEqual(g(o), "base")
2074 g.register(c.Iterable, lambda arg: "iterable")
2075 g.register(c.Container, lambda arg: "container")
2076 g.register(c.Sized, lambda arg: "sized")
2077 g.register(c.Set, lambda arg: "set")
2078 self.assertEqual(g(o), "sized")
2079 c.Iterable.register(O)
2080 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
2081 c.Container.register(O)
2082 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02002083 c.Set.register(O)
2084 self.assertEqual(g(o), "set") # because c.Set is a subclass of
2085 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02002086 class P:
2087 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02002088 p = P()
2089 self.assertEqual(g(p), "base")
2090 c.Iterable.register(P)
2091 self.assertEqual(g(p), "iterable")
2092 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02002093 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02002094 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002095 self.assertIn(
2096 str(re_one.exception),
2097 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2098 "or <class 'collections.abc.Iterable'>"),
2099 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
2100 "or <class 'collections.abc.Container'>")),
2101 )
Łukasz Langa6f692512013-06-05 12:20:24 +02002102 class Q(c.Sized):
2103 def __len__(self):
2104 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002105 q = Q()
2106 self.assertEqual(g(q), "sized")
2107 c.Iterable.register(Q)
2108 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2109 c.Set.register(Q)
2110 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002111 # c.Sized and c.Iterable
2112 @functools.singledispatch
2113 def h(arg):
2114 return "base"
2115 @h.register(c.Sized)
2116 def _(arg):
2117 return "sized"
2118 @h.register(c.Container)
2119 def _(arg):
2120 return "container"
2121 # Even though Sized and Container are explicit bases of MutableMapping,
2122 # this ABC is implicitly registered on defaultdict which makes all of
2123 # MutableMapping's bases implicit as well from defaultdict's
2124 # perspective.
2125 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002126 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002127 self.assertIn(
2128 str(re_two.exception),
2129 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2130 "or <class 'collections.abc.Sized'>"),
2131 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2132 "or <class 'collections.abc.Container'>")),
2133 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002134 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002135 pass
2136 c.MutableSequence.register(R)
2137 @functools.singledispatch
2138 def i(arg):
2139 return "base"
2140 @i.register(c.MutableMapping)
2141 def _(arg):
2142 return "mapping"
2143 @i.register(c.MutableSequence)
2144 def _(arg):
2145 return "sequence"
2146 r = R()
2147 self.assertEqual(i(r), "sequence")
2148 class S:
2149 pass
2150 class T(S, c.Sized):
2151 def __len__(self):
2152 return 0
2153 t = T()
2154 self.assertEqual(h(t), "sized")
2155 c.Container.register(T)
2156 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2157 class U:
2158 def __len__(self):
2159 return 0
2160 u = U()
2161 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2162 # from the existence of __len__()
2163 c.Container.register(U)
2164 # There is no preference for registered versus inferred ABCs.
2165 with self.assertRaises(RuntimeError) as re_three:
2166 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002167 self.assertIn(
2168 str(re_three.exception),
2169 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2170 "or <class 'collections.abc.Sized'>"),
2171 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2172 "or <class 'collections.abc.Container'>")),
2173 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002174 class V(c.Sized, S):
2175 def __len__(self):
2176 return 0
2177 @functools.singledispatch
2178 def j(arg):
2179 return "base"
2180 @j.register(S)
2181 def _(arg):
2182 return "s"
2183 @j.register(c.Container)
2184 def _(arg):
2185 return "container"
2186 v = V()
2187 self.assertEqual(j(v), "s")
2188 c.Container.register(V)
2189 self.assertEqual(j(v), "container") # because it ends up right after
2190 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002191
def test_cache_invalidation(self):
    """Trace when singledispatch's dispatch cache is filled, read and cleared.

    The cache that ``functools.singledispatch`` builds via
    ``weakref.WeakKeyDictionary()`` is swapped for a ``TracingDict`` that
    records every key read (``get_ops``) and written (``set_ops``), so the
    assertions below pin down exactly which calls hit the cache, which
    populate it, and which registrations invalidate it.
    """
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # Dict that logs the keys of every __getitem__/__setitem__ call.
        def __init__(self, *args, **kwargs):
            super(TracingDict, self).__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            result = self.data[key]
            # Only recorded after a successful lookup (a miss raises first).
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            # Clears the storage directly; presumably overridden so that
            # clearing does not route through __getitem__ and pollute
            # get_ops -- TODO confirm against MutableMapping.clear.
            self.data.clear()

    td = TracingDict()
    # While patched, WeakKeyDictionary() returns td, so the singledispatch
    # function created inside this block uses td as its dispatch cache.
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First call per type: cache miss, result stored (set_ops only).
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated calls: cache hits (get_ops grows, set_ops does not).
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # Registering a new implementation empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)  # Will not invalidate the cache,
                                      # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        # Registering an ABC implementation also empties the cache.
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() consults the same cache as a call does.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # With an ABC (Sized) in the registry, an ABC registration does
        # invalidate the cache -- but lazily: the old entries stay until
        # the next call, which clears the cache and refills one entry.
        c.MutableSet.register(X)  # Will invalidate the cache.
        self.assertEqual(len(td), 2)  # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        # The private _clear_cache() hook empties the cache explicitly.
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002293
Łukasz Langae5697532017-12-11 13:56:31 -08002294 def test_annotations(self):
2295 @functools.singledispatch
2296 def i(arg):
2297 return "base"
2298 @i.register
2299 def _(arg: collections.abc.Mapping):
2300 return "mapping"
2301 @i.register
2302 def _(arg: "collections.abc.Sequence"):
2303 return "sequence"
2304 self.assertEqual(i(None), "base")
2305 self.assertEqual(i({"a": 1}), "mapping")
2306 self.assertEqual(i([1, 2, 3]), "sequence")
2307 self.assertEqual(i((1, 2, 3)), "sequence")
2308 self.assertEqual(i("str"), "sequence")
2309
2310 # Registering classes as callables doesn't work with annotations,
2311 # you need to pass the type explicitly.
2312 @i.register(str)
2313 class _:
2314 def __init__(self, arg):
2315 self.arg = arg
2316
2317 def __eq__(self, other):
2318 return self.arg == other
2319 self.assertEqual(i("str"), "str")
2320
Ethan Smithc6512752018-05-26 16:38:33 -04002321 def test_method_register(self):
2322 class A:
2323 @functools.singledispatchmethod
2324 def t(self, arg):
2325 self.arg = "base"
2326 @t.register(int)
2327 def _(self, arg):
2328 self.arg = "int"
2329 @t.register(str)
2330 def _(self, arg):
2331 self.arg = "str"
2332 a = A()
2333
2334 a.t(0)
2335 self.assertEqual(a.arg, "int")
2336 aa = A()
2337 self.assertFalse(hasattr(aa, 'arg'))
2338 a.t('')
2339 self.assertEqual(a.arg, "str")
2340 aa = A()
2341 self.assertFalse(hasattr(aa, 'arg'))
2342 a.t(0.0)
2343 self.assertEqual(a.arg, "base")
2344 aa = A()
2345 self.assertFalse(hasattr(aa, 'arg'))
2346
2347 def test_staticmethod_register(self):
2348 class A:
2349 @functools.singledispatchmethod
2350 @staticmethod
2351 def t(arg):
2352 return arg
2353 @t.register(int)
2354 @staticmethod
2355 def _(arg):
2356 return isinstance(arg, int)
2357 @t.register(str)
2358 @staticmethod
2359 def _(arg):
2360 return isinstance(arg, str)
2361 a = A()
2362
2363 self.assertTrue(A.t(0))
2364 self.assertTrue(A.t(''))
2365 self.assertEqual(A.t(0.0), 0.0)
2366
2367 def test_classmethod_register(self):
2368 class A:
2369 def __init__(self, arg):
2370 self.arg = arg
2371
2372 @functools.singledispatchmethod
2373 @classmethod
2374 def t(cls, arg):
2375 return cls("base")
2376 @t.register(int)
2377 @classmethod
2378 def _(cls, arg):
2379 return cls("int")
2380 @t.register(str)
2381 @classmethod
2382 def _(cls, arg):
2383 return cls("str")
2384
2385 self.assertEqual(A.t(0).arg, "int")
2386 self.assertEqual(A.t('').arg, "str")
2387 self.assertEqual(A.t(0.0).arg, "base")
2388
2389 def test_callable_register(self):
2390 class A:
2391 def __init__(self, arg):
2392 self.arg = arg
2393
2394 @functools.singledispatchmethod
2395 @classmethod
2396 def t(cls, arg):
2397 return cls("base")
2398
2399 @A.t.register(int)
2400 @classmethod
2401 def _(cls, arg):
2402 return cls("int")
2403 @A.t.register(str)
2404 @classmethod
2405 def _(cls, arg):
2406 return cls("str")
2407
2408 self.assertEqual(A.t(0).arg, "int")
2409 self.assertEqual(A.t('').arg, "str")
2410 self.assertEqual(A.t(0.0).arg, "base")
2411
2412 def test_abstractmethod_register(self):
2413 class Abstract(abc.ABCMeta):
2414
2415 @functools.singledispatchmethod
2416 @abc.abstractmethod
2417 def add(self, x, y):
2418 pass
2419
2420 self.assertTrue(Abstract.add.__isabstractmethod__)
2421
2422 def test_type_ann_register(self):
2423 class A:
2424 @functools.singledispatchmethod
2425 def t(self, arg):
2426 return "base"
2427 @t.register
2428 def _(self, arg: int):
2429 return "int"
2430 @t.register
2431 def _(self, arg: str):
2432 return "str"
2433 a = A()
2434
2435 self.assertEqual(a.t(0), "int")
2436 self.assertEqual(a.t(''), "str")
2437 self.assertEqual(a.t(0.0), "base")
2438
def test_invalid_registrations(self):
    """register() rejects non-types, missing annotations and generics."""
    prefix = "Invalid first argument to `register()`: "
    suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    @functools.singledispatch
    def dispatcher(arg):
        return "base"

    # 1. An object that is not a class, passed explicitly.
    with self.assertRaises(TypeError) as cm:
        @dispatcher.register(42)
        def _(arg):
            return "I annotated with a non-type"
    message = str(cm.exception)
    self.assertTrue(message.startswith(prefix + "42"))
    self.assertTrue(message.endswith(suffix))

    # 2. Bare @register on a function without any annotation; the message
    #    embeds the repr of the offending function.
    with self.assertRaises(TypeError) as cm:
        @dispatcher.register
        def _(arg):
            return "I forgot to annotate"
    message = str(cm.exception)
    expected_start = (
        prefix
        + "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    )
    self.assertTrue(message.startswith(expected_start))
    self.assertTrue(message.endswith(suffix))

    # 3. A parametrized generic from `typing`: dispatching on generics is
    #    impossible at runtime, so register implementations with regular
    #    types or ABCs instead.
    with self.assertRaises(TypeError) as cm:
        @dispatcher.register
        def _(arg: typing.Iterable[str]):
            return "I annotated with a generic collection"
    message = str(cm.exception)
    self.assertTrue(message.startswith(
        "Invalid annotation for 'arg'."
    ))
    self.assertTrue(message.endswith(
        'typing.Iterable[str] is not a class.'
    ))
Łukasz Langae5697532017-12-11 13:56:31 -08002477
Dong-hee Na445f1b32018-07-10 16:26:36 +09002478 def test_invalid_positional_argument(self):
2479 @functools.singledispatch
2480 def f(*args):
2481 pass
2482 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002483 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002484 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002485
Carl Meyerd658dea2018-08-28 01:11:56 -06002486
class CachedCostItem:
    # Class-level starting value; the first computation stores an instance
    # attribute that shadows it.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost += 1
            current = self._cost
        return current
2499
2500
class OptionallyCachedCostItem:
    # The same computation is exposed twice: directly (get_cost, recomputed
    # on every call) and through a cached_property bound to another name.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The cached value lives under "cached_cost", so later get_cost() calls
    # do not affect it.
    cached_cost = py_functools.cached_property(get_cost)
2510
2511
class CachedCostItemWait:
    # Computing ``cost`` first waits (at most one second) on ``event``, so
    # several threads can be parked inside the getter at the same time.

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            updated = self._cost
        return updated
2525
2526
class CachedCostItemWithSlots:
    # A real one-element tuple. The previous ``('_cost')`` lacked the
    # trailing comma and was just the string '_cost'. Python accepts a bare
    # string as a single slot name, but the tuple states the intent and
    # keeps working if more slots are added.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    # With __slots__ there is no instance __dict__ for cached_property to
    # store into, so accessing ``cost`` raises TypeError before this
    # function is ever called.
    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2536
2537
class TestCachedProperty(unittest.TestCase):
    def test_cached(self):
        item = CachedCostItem()
        # The first access computes 1 -> 2; later accesses hit the cache.
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # Direct calls keep counting while the cached value stays put.
        for direct in (2, 4):
            self.assertEqual(item.get_cost(), direct)
            self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        saved_interval = sys.getswitchinterval()
        # A tiny switch interval makes thread interleaving inside the
        # getter as likely as possible.
        sys.setswitchinterval(1e-6)
        try:
            workers = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with threading_helper.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # A single increment means the value was computed exactly once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = ("No '__dict__' attribute on 'CachedCostItemWithSlots' "
                    "instance to cache 'cost' property.")
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = ("The '__dict__' attribute on 'MyMeta' instance does not "
                    "support item assignment for caching 'prop' property.")
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Binding one cached_property under two names is rejected."""
        expected = TypeError(
            "Cannot assign the same cached_property to two different "
            "names ('a' and 'b')."
        )
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError from __set_name__ surfaces as the context of the
        # RuntimeError raised while creating the class.
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Sharing one cached_property across classes under one name works."""
        calls = [0]

        @py_functools.cached_property
        def _cp(_self):
            calls[0] += 1
            return calls[0]

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attaching after class creation skips __set_name__.
        Foo.cp = cp

        expected = ("Cannot use cached_property instance without calling "
                    "__set_name__ on it.")
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        doc = CachedCostItem.cost.__doc__
        self.assertEqual(doc, "The cost of the item.")
2650
2651
# Run this module's test cases when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()