blob: e726188982bc4a40647ba069658a7b55d36c58e2 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo2f172d82020-06-01 00:41:14 +01006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shie80697d2020-05-28 06:10:27 +080022from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000023from test.support.script_helper import assert_python_ok
24
Antoine Pitroub5b37142012-11-13 21:35:40 +010025import functools
26
# Two separately imported copies of functools, so every test can run
# against both implementations: py_functools is imported with the
# _functools C accelerator blocked, and c_functools with _functools freshly
# imported.  c_functools is checked for truthiness before use further down
# (``if c_functools:`` / ``unittest.skipUnless(c_functools, ...)``), so it
# is presumably None when _functools cannot be imported.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# A fresh decimal module with its _decimal C accelerator re-imported.  It is
# not referenced in this part of the file (presumably used by later tests).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
31
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Make ``sys.modules[name]`` refer to *replacement* inside a ``with`` block.

    Whatever module was registered under *name* beforehand is restored on
    exit, whether the block completes normally or raises.  *name* must
    already be present in ``sys.modules``; otherwise KeyError is raised
    before anything is modified.
    """
    saved_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore the original registration even if the body raised.
        sys.modules[name] = saved_module
Łukasz Langa6f692512013-06-05 12:20:24 +020040
def capture(*args, **kw):
    """Echo the call back as a ``(positional_tuple, keyword_dict)`` pair."""
    result = (args, kw)
    return result
44
Łukasz Langa6f692512013-06-05 12:20:24 +020045
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000049
class MyTuple(tuple):
    """Do-nothing tuple subclass; the setstate tests check that partial
    stores such args as a plain tuple."""
52
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list.

    Used to check that partial does not rely on the args object's own
    ``__add__`` when combining stored and call-time arguments.
    """

    def __add__(self, other):
        return [*self, *other]
56
class MyDict(dict):
    """Do-nothing dict subclass; the setstate tests check that partial
    stores such keywords as a plain dict."""
59
Łukasz Langa6f692512013-06-05 12:20:24 +020060
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and supply ``partial`` (the callable under test)
    and ``AllowPickle`` (a context-manager class entered around the
    pickling tests).
    """

    def _repr_name(self):
        """Return the type name that ``repr(self.partial(...))`` starts with.

        The two stock implementations report themselves as
        ``functools.partial``; subclasses report their own ``__name__``.
        ``c_functools`` may be None when ``_functools`` could not be
        imported (the rest of the file guards it with ``if c_functools:``),
        so its ``partial`` attribute is looked up defensively instead of
        dereferencing ``c_functools.partial`` directly.
        """
        stock = (getattr(c_functools, 'partial', None), py_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # assertEqual (rather than assertTrue on a combined condition)
            # reports the mismatching values on failure.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Interpreters without reference counting do not free the referent
        # as soon as the last strong reference is dropped; force a
        # collection so the proxy is reliably dead.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept both.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes the partial reference itself, then breaks the
        # cycle in a finally clause so the object can be freed.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE: this signature check is left disabled, as in the original.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000384
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation, plus
    checks that only apply to the C type."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # Unlike TestPartialPy, no module swapping is done for the C type,
        # so this context manager does nothing and never suppresses errors.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ itself must not be deletable either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
435
436
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Register ``py_functools`` as ``sys.modules['functools']`` while active.

        NOTE(review): presumably needed so that pickle, which looks up
        ``functools.partial`` by module and name, resolves it to the
        pure-Python class under test rather than the C one.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            entered = self._swap.__enter__()
            return entered

        def __exit__(self, type, value, tb):
            suppress = self._swap.__exit__(type, value, tb)
            return suppress
Łukasz Langa6f692512013-06-05 12:20:24 +0200447
if c_functools:
    # Trivial subclass of the C partial type; only defined when the C
    # module could be imported.
    class CPartialSubclass(c_functools.partial):
        ...
Antoine Pitrou33543272012-11-13 21:36:21 +0100451
# Trivial subclass of the pure-Python partial class.
class PyPartialSubclass(py_functools.partial):
    ...
Łukasz Langa6f692512013-06-05 12:20:24 +0200454
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a trivial subclass of the C type."""

    if c_functools:
        partial = CPartialSubclass

    # Nested calls are not flattened for partial subclasses, so the
    # inherited test_nested_optimization is disabled here.
    test_nested_optimization = None
462
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200465
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        # Each attribute pre-binds a different mix of arguments to capture().
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' are accepted as ordinary pre-bound keywords.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # partialmethod layered on another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered on a plain functools.partial.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod layered on other descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # One shared instance suffices: capture() only echoes its arguments.
    a = A()

    def test_arg_combinations(self):
        obj = self.a
        check = self.assertEqual

        # No pre-bound arguments: only the instance is prepended.
        check(obj.nothing(), ((obj,), {}))
        check(obj.nothing(5), ((obj, 5), {}))
        check(obj.nothing(c=6), ((obj,), {'c': 6}))
        check(obj.nothing(5, c=6), ((obj, 5), {'c': 6}))

        # The pre-bound positional sits between the instance and call args.
        check(obj.positional(), ((obj, 1), {}))
        check(obj.positional(5), ((obj, 1, 5), {}))
        check(obj.positional(c=6), ((obj, 1), {'c': 6}))
        check(obj.positional(5, c=6), ((obj, 1, 5), {'c': 6}))

        # The pre-bound keyword is merged with call-time keywords.
        check(obj.keywords(), ((obj,), {'a': 2}))
        check(obj.keywords(5), ((obj, 5), {'a': 2}))
        check(obj.keywords(c=6), ((obj,), {'a': 2, 'c': 6}))
        check(obj.keywords(5, c=6), ((obj, 5), {'a': 2, 'c': 6}))

        # Both kinds at once.
        check(obj.both(), ((obj, 3), {'b': 4}))
        check(obj.both(5), ((obj, 3, 5), {'b': 4}))
        check(obj.both(c=6), ((obj, 3), {'b': 4, 'c': 6}))
        check(obj.both(5, c=6), ((obj, 3, 5), {'b': 4, 'c': 6}))

        # Through the class, the instance is supplied explicitly.
        check(self.A.both(obj, 5, c=6), ((obj, 3, 5), {'b': 4, 'c': 6}))

        check(obj.spec_keywords(), ((obj,), {'self': 1, 'func': 2}))

    def test_nested(self):
        obj = self.a
        check = self.assertEqual
        check(obj.nested(), ((obj, 1, 5), {}))
        check(obj.nested(6), ((obj, 1, 5, 6), {}))
        check(obj.nested(d=7), ((obj, 1, 5), {'d': 7}))
        check(obj.nested(6, d=7), ((obj, 1, 5, 6), {'d': 7}))

        check(self.A.nested(obj, 6, d=7), ((obj, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        obj = self.a
        check = self.assertEqual
        check(obj.over_partial(), ((obj, 7), {'c': 6}))
        check(obj.over_partial(5), ((obj, 7, 5), {'c': 6}))
        check(obj.over_partial(d=8), ((obj, 7), {'c': 6, 'd': 8}))
        check(obj.over_partial(5, d=8), ((obj, 7, 5), {'c': 6, 'd': 8}))

        check(self.A.over_partial(obj, 5, d=8), ((obj, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        for attr in ('both', 'nested', 'over_partial'):
            self.assertIs(getattr(obj, attr).__self__, obj)
        # The classmethod-based one binds to the class, from either side.
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        klass = self.A
        for attr in ('both', 'nested', 'over_partial', 'static'):
            self.assertFalse(hasattr(getattr(klass, attr), "__self__"))
        # The staticmethod-based one stays unbound even via an instance.
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        owner = self.A
        for target in (self.A, self.a):
            with self.subTest(obj=target):
                self.assertEqual(target.static(), ((8,), {}))
                self.assertEqual(target.static(5), ((8, 5), {}))
                self.assertEqual(target.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(target.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(target.cls(), ((owner,), {'d': 9}))
                self.assertEqual(target.cls(5), ((owner, 5), {'d': 9}))
                self.assertEqual(target.cls(c=8), ((owner,), {'c': 8, 'd': 9}))
                self.assertEqual(target.cls(5, c=8), ((owner, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        obj = self.a
        # Call-time keywords win over the pre-bound a=2.
        self.assertEqual(obj.keywords(a=3), ((obj,), {'a': 3}))
        self.assertEqual(self.A.keywords(obj, a=3), ((obj,), {'a': 3}))

    def test_invalid_args(self):
        # A non-callable target is rejected while the class body executes.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        # The target callable is mandatory...
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        # ...and passing it as the keyword func= is rejected.
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        # Read the raw descriptor from the class namespace (no __get__).
        descriptor = vars(self.A)['both']
        self.assertEqual(repr(descriptor),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract derives from abc.ABCMeta (a metaclass), not
        # abc.ABC.  The checks only read __isabstractmethod__ flags, so this
        # presumably behaves the same either way -- confirm intent.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        klass = self.A
        for method in [klass.static, klass.cls, klass.over_partial, klass.nested, klass.both]:
            self.assertFalse(getattr(method, '__isabstractmethod__', False))

    def test_positional_only(self):
        def add(a, b, /):
            return a + b

        bound = functools.partial(add, 1)
        self.assertEqual(bound(2), add(1, 2))
594
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000595
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata copied from *wrapped*.

        Attributes named in *assigned* must be the very same objects on both
        functions; every entry of each mapping named in *updated* must be
        present (as the same object) in the wrapper's mapping; and
        ``__wrapped__`` must point back at *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))

        for attr in updated:
            merged = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # __wrapped__ is overwritten by the update code, so the
                # wrapped function's own entry is not expected to survive.
                if attr == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], merged[key])

        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Apply update_wrapper() with default settings to fresh functions.

        Returns ``(wrapper, wrapped)``.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # A stale __wrapped__ on the wrapped function must not be copied over.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # Annotations are replaced wholesale, not merged.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        # With empty assigned/updated, the wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned = ('attr',)
        updated = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.check_wrapper(wrapper, f, assigned, updated)
        # Only the named attributes were touched.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned = ('attr',)
        updated = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: a missing
        # one, or a non-mapping such as an int, raises AttributeError.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: builtins can be the wrapped object too.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000708
Łukasz Langa6f692512013-06-05 12:20:24 +0200709
class TestWraps(TestUpdateWrapper):
    # Re-runs the update_wrapper() scenarios through the functools.wraps()
    # decorator form.

    def _default_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Decoy: presumably lets check_wrapper() confirm that wraps() points
        # __wrapped__ at f itself rather than copying this attribute.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty ``assigned``/``updated``: the wrapper keeps its own metadata.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # The dict named in ``updated`` must already exist on the
            # wrapper before wraps() runs.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
770
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000771
madman-bobe25d5fc2018-10-25 15:02:10 +0100772class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773 def test_reduce(self):
774 class Squares:
775 def __init__(self, max):
776 self.max = max
777 self.sofar = []
778
779 def __len__(self):
780 return len(self.sofar)
781
782 def __getitem__(self, i):
783 if not 0 <= i < self.max: raise IndexError
784 n = len(self.sofar)
785 while n <= i:
786 self.sofar.append(n*n)
787 n += 1
788 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000789 def add(x, y):
790 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100791 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000792 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 ['a','c','d','w']
795 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100796 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000797 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000799 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000800 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100801 self.assertEqual(self.reduce(add, Squares(10)), 285)
802 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
803 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
804 self.assertRaises(TypeError, self.reduce)
805 self.assertRaises(TypeError, self.reduce, 42, 42)
806 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
807 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
808 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
809 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
810 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
811 self.assertRaises(TypeError, self.reduce, add, "")
812 self.assertRaises(TypeError, self.reduce, add, ())
813 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000814
815 class TestFailingIter:
816 def __iter__(self):
817 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100818 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000819
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertEqual(self.reduce(add, [], None), None)
821 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000822
823 class BadSeq:
824 def __getitem__(self, index):
825 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100826 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000827
828 # Test reduce()'s use of iterators.
829 def test_iterator_usage(self):
830 class SequenceClass:
831 def __init__(self, n):
832 self.n = n
833 def __getitem__(self, i):
834 if 0 <= i < self.n:
835 return i
836 else:
837 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000838
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000839 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100840 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
841 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
842 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
843 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
844 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
845 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000846
847 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100848 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
849
850
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Shared reduce() tests against the C-accelerated functools.  The
    # attribute is bound only when that module was importable; otherwise the
    # decorator above skips the whole class.
    if c_functools is not None:
        # Unlike the pure-Python variant, this is not wrapped in
        # staticmethod(): the C reduce is a builtin, which does not bind
        # ``self`` when looked up through an instance.
        reduce = c_functools.reduce
855
856
class TestReducePy(TestReduce, unittest.TestCase):
    # Shared reduce() tests against the pure-Python functools.
    # staticmethod() stops the plain function from being bound as a method,
    # so self.reduce(...) forwards exactly the arguments it is given.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000859
Łukasz Langa6f692512013-06-05 12:20:24 +0200860
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200861class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700862
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000863 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864 def cmp1(x, y):
865 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100866 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700867 self.assertEqual(key(3), key(3))
868 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100869 self.assertGreaterEqual(key(3), key(3))
870
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700871 def cmp2(x, y):
872 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100873 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700874 self.assertEqual(key(4.0), key('4'))
875 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100876 self.assertLessEqual(key(2), key('35'))
877 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700878
879 def test_cmp_to_key_arguments(self):
880 def cmp1(x, y):
881 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100882 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700883 self.assertEqual(key(obj=3), key(obj=3))
884 self.assertGreater(key(obj=3), key(obj=1))
885 with self.assertRaises((TypeError, AttributeError)):
886 key(3) > 1 # rhs is not a K object
887 with self.assertRaises((TypeError, AttributeError)):
888 1 < key(3) # lhs is not a K object
889 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100890 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700891 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200892 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100893 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700894 with self.assertRaises(TypeError):
895 key() # too few args
896 with self.assertRaises(TypeError):
897 key(None, None) # too many args
898
899 def test_bad_cmp(self):
900 def cmp1(x, y):
901 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100902 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700903 with self.assertRaises(ZeroDivisionError):
904 key(3) > key(1)
905
906 class BadCmp:
907 def __lt__(self, other):
908 raise ZeroDivisionError
909 def cmp1(x, y):
910 return BadCmp()
911 with self.assertRaises(ZeroDivisionError):
912 key(3) > key(1)
913
914 def test_obj_field(self):
915 def cmp1(x, y):
916 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100917 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700918 self.assertEqual(key(50).obj, 50)
919
920 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000921 def mycmp(x, y):
922 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100923 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000924 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000925
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700926 def test_sort_int_str(self):
927 def mycmp(x, y):
928 x, y = int(x), int(y)
929 return (x > y) - (x < y)
930 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100931 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700932 self.assertEqual([int(value) for value in values],
933 [0, 1, 1, 2, 3, 4, 5, 7, 10])
934
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000935 def test_hash(self):
936 def mycmp(x, y):
937 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100938 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000939 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700940 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300941 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000942
Łukasz Langa6f692512013-06-05 12:20:24 +0200943
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests against the C-accelerated functools; the
    # whole class is skipped when that module could not be imported.
    if c_functools is not None:
        # The C cmp_to_key is a builtin, which does not bind ``self`` when
        # looked up through an instance, so no staticmethod() is needed.
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100948
Łukasz Langa6f692512013-06-05 12:20:24 +0200949
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests against the pure-Python functools;
    # staticmethod() keeps the plain function from binding ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
952
Łukasz Langa6f692512013-06-05 12:20:24 +0200953
class TestTotalOrdering(unittest.TestCase):

    def _check_derived_ordering(self, cls):
        # Shared checks: instances built from 1 and 2 must order like the
        # integers under all four rich comparisons, including the
        # non-strict ones on equal values.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        # Root method: __lt__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # Root method: __le__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # Root method: __gt__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # Root method: __ge__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing ones (all four are
        # inherited from int here)
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        # Without any ordering method there is nothing to derive from.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Comparisons between unrelated types must end in TypeError.
        mismatched = [
            ("LT < 1", lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even for equal values, a root method that always returns
        # NotImplemented surfaces as TypeError from the derived non-strict
        # operators, although the two objects compare equal.
        equal_cases = [
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        ]
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Both the user-defined root (__lt__) and the generated methods
        # must round-trip through pickle as the very same objects.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in ('__lt__', '__gt__', '__le__', '__ge__'):
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    copied = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(copied, method)
1156
@functools.total_ordering
class Orderable_LT:
    # Module-level helper defining only __eq__ and __lt__; total_ordering
    # derives the rest.  It lives at module level so that its methods can
    # be pickled by reference (see TestTotalOrdering.test_pickle).
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1165
1166
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001167class TestCache:
1168 # This tests that the pass-through is working as designed.
1169 # The underlying functionality is tested in TestLRU.
1170
1171 def test_cache(self):
1172 @self.module.cache
1173 def fib(n):
1174 if n < 2:
1175 return n
1176 return fib(n-1) + fib(n-2)
1177 self.assertEqual([fib(n) for n in range(16)],
1178 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1179 self.assertEqual(fib.cache_info(),
1180 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1181 fib.cache_clear()
1182 self.assertEqual(fib.cache_info(),
1183 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1184
1185
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001186class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001187
1188 def test_lru(self):
1189 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001190 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001191 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001192 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(maxsize, 20)
1194 self.assertEqual(currsize, 0)
1195 self.assertEqual(hits, 0)
1196 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001197
1198 domain = range(5)
1199 for i in range(1000):
1200 x, y = choice(domain), choice(domain)
1201 actual = f(x, y)
1202 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001203 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001204 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001205 self.assertTrue(hits > misses)
1206 self.assertEqual(hits + misses, 1000)
1207 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001208
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001209 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001210 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001211 self.assertEqual(hits, 0)
1212 self.assertEqual(misses, 0)
1213 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001214 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001215 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001216 self.assertEqual(hits, 0)
1217 self.assertEqual(misses, 1)
1218 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001219
Nick Coghlan98876832010-08-17 06:17:18 +00001220 # Test bypassing the cache
1221 self.assertIs(f.__wrapped__, orig)
1222 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001223 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001224 self.assertEqual(hits, 0)
1225 self.assertEqual(misses, 1)
1226 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001227
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001228 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001229 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001230 def f():
1231 nonlocal f_cnt
1232 f_cnt += 1
1233 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001234 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001235 f_cnt = 0
1236 for i in range(5):
1237 self.assertEqual(f(), 20)
1238 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001239 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001240 self.assertEqual(hits, 0)
1241 self.assertEqual(misses, 5)
1242 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001243
1244 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001245 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001246 def f():
1247 nonlocal f_cnt
1248 f_cnt += 1
1249 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001250 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001251 f_cnt = 0
1252 for i in range(5):
1253 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001254 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001255 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001256 self.assertEqual(hits, 4)
1257 self.assertEqual(misses, 1)
1258 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001259
Raymond Hettingerf3098282010-08-15 03:30:45 +00001260 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001261 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001262 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001263 nonlocal f_cnt
1264 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001265 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001266 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001267 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001268 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1269 # * * * *
1270 self.assertEqual(f(x), x*10)
1271 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001272 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001273 self.assertEqual(hits, 12)
1274 self.assertEqual(misses, 4)
1275 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001276
Raymond Hettingerb8218682019-05-26 11:27:35 -07001277 def test_lru_no_args(self):
1278 @self.module.lru_cache
1279 def square(x):
1280 return x ** 2
1281
1282 self.assertEqual(list(map(square, [10, 20, 10])),
1283 [100, 400, 100])
1284 self.assertEqual(square.cache_info().hits, 1)
1285 self.assertEqual(square.cache_info().misses, 2)
1286 self.assertEqual(square.cache_info().maxsize, 128)
1287 self.assertEqual(square.cache_info().currsize, 2)
1288
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001289 def test_lru_bug_35780(self):
1290 # C version of the lru_cache was not checking to see if
1291 # the user function call has already modified the cache
1292 # (this arises in recursive calls and in multi-threading).
1293 # This cause the cache to have orphan links not referenced
1294 # by the cache dictionary.
1295
1296 once = True # Modified by f(x) below
1297
1298 @self.module.lru_cache(maxsize=10)
1299 def f(x):
1300 nonlocal once
1301 rv = f'.{x}.'
1302 if x == 20 and once:
1303 once = False
1304 rv = f(x)
1305 return rv
1306
1307 # Fill the cache
1308 for x in range(15):
1309 self.assertEqual(f(x), f'.{x}.')
1310 self.assertEqual(f.cache_info().currsize, 10)
1311
1312 # Make a recursive call and make sure the cache remains full
1313 self.assertEqual(f(20), '.20.')
1314 self.assertEqual(f.cache_info().currsize, 10)
1315
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001316 def test_lru_bug_36650(self):
1317 # C version of lru_cache was treating a call with an empty **kwargs
1318 # dictionary as being distinct from a call with no keywords at all.
1319 # This did not result in an incorrect answer, but it did trigger
1320 # an unexpected cache miss.
1321
1322 @self.module.lru_cache()
1323 def f(x):
1324 pass
1325
1326 f(0)
1327 f(0, **{})
1328 self.assertEqual(f.cache_info().hits, 1)
1329
def test_lru_hash_only_once(self):
    # The LRU cache promises to call __hash__ exactly once per use of an
    # object as an argument to the cached function.  That protects
    # against weird reentrancy bugs and keeps slow __hash__ methods from
    # being called more often than necessary.

    @self.module.lru_cache(maxsize=1)
    def f(x, y):
        return x * 3 + y

    # Stand-in for the integer 5: multiplying by 3 gives 15, and the
    # hash is fixed so calls to it can be counted.
    mock_int = unittest.mock.Mock()
    mock_int.__mul__ = unittest.mock.Mock(return_value=15)
    mock_int.__hash__ = unittest.mock.Mock(return_value=999)

    # Each step: (call arguments, expected result,
    #             expected cumulative __hash__ calls on mock_int,
    #             expected cache_info() as (hits, misses, maxsize, currsize)).
    steps = [
        # Add to cache: one use as an argument gives one call.
        ((mock_int, 1), 16, 1, (0, 1, 1, 1)),
        # Cache hit: one use as an argument gives one additional call.
        ((mock_int, 1), 16, 2, (1, 1, 1, 1)),
        # Cache eviction: no use as an argument gives no additional call.
        ((6, 2), 20, 2, (1, 2, 1, 1)),
        # Cache miss: one use as an argument gives one additional call.
        ((mock_int, 1), 16, 3, (1, 3, 1, 1)),
    ]
    for args, result, hash_calls, info in steps:
        self.assertEqual(f(*args), result)
        self.assertEqual(mock_int.__hash__.call_count, hash_calls)
        self.assertEqual(f.cache_info(), info)
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001364
def test_lru_reentrancy_with_len(self):
    # Caching the built-in len() and installing it as builtins.len must
    # not confuse the LRU machinery.  Because len() itself can be cached,
    # the lru code must not rely on it internally.
    saved_len = builtins.len
    try:
        builtins.len = self.module.lru_cache(4)(saved_len)
        text = 'abcdefghijklmn'
        for size in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
            self.assertEqual(len(text[:size]), size)
    finally:
        # Always restore the real builtin, even if an assertion failed.
        builtins.len = saved_len
1376
def test_lru_star_arg_handling(self):
    # Test regression that arose in ea064ff3c10f: f(1, 2) and f((1, 2))
    # must be kept apart by the cache and return their own *args.
    # Use self.module (not the global functools) so that both the C and
    # the pure-Python implementations are exercised.
    @self.module.lru_cache()
    def f(*args):
        return args

    self.assertEqual(f(1, 2), (1, 2))
    self.assertEqual(f((1, 2)), ((1, 2),))
1385
def test_lru_type_error(self):
    # Regression test for issue #28653.
    # lru_cache was leaking when one of the arguments
    # wasn't cacheable.  Both the unbounded (maxsize=None) and the bounded
    # code paths must raise TypeError for an unhashable argument.  Use
    # self.module (not the global functools) so that the C and the
    # pure-Python implementations are each checked.

    @self.module.lru_cache(maxsize=None)
    def infinite_cache(o):
        pass

    @self.module.lru_cache(maxsize=10)
    def limited_cache(o):
        pass

    with self.assertRaises(TypeError):
        infinite_cache([])

    with self.assertRaises(TypeError):
        limited_cache([])
1404
def test_lru_with_maxsize_none(self):
    # With maxsize=None the cache is unbounded: computing fib(0)..fib(15)
    # gives 16 misses (one per distinct n) and 28 hits, and cache_clear()
    # resets all statistics.
    @self.module.lru_cache(maxsize=None)
    def fib(n):
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
    self.assertEqual(list(map(fib, range(16))), expected)

    CacheInfo = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1418
def test_lru_with_maxsize_negative(self):
    # A negative maxsize is reported as maxsize=0: nothing is stored, so
    # two passes over the same 150 arguments give 300 misses, no hits,
    # and currsize stays 0.
    @self.module.lru_cache(maxsize=-10)
    def eq(n):
        return n
    for _ in range(2):  # the pass index itself is not needed
        self.assertEqual([eq(n) for n in range(150)], list(range(150)))
    self.assertEqual(eq.cache_info(),
        self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001427
def test_lru_with_exceptions(self):
    # Exceptions raised by the user function must pass through without
    # an implicitly chained context, which would make tracebacks hard
    # to read (http://bugs.python.org/issue13177).  A failing call must
    # not be cached either.
    for maxsize in (None, 128):
        @self.module.lru_cache(maxsize)
        def letter(i):
            return 'abc'[i]

        self.assertEqual(letter(0), 'a')
        with self.assertRaises(IndexError) as ctx:
            letter(15)
        self.assertIsNone(ctx.exception.__context__)
        # Repeating the failing call must raise again, which shows that
        # no entry was stored for it.
        with self.assertRaises(IndexError):
            letter(15)
1443
def test_lru_with_types(self):
    # With typed=True, 3 and 3.0 are cached separately, as are their
    # keyword-argument spellings.  Each cached result keeps the type
    # of its own argument.
    specs = [
        # (positional args, keyword args, expected result)
        ((3,), {}, 9),
        ((3.0,), {}, 9.0),
        ((), {'x': 3}, 9),
        ((), {'x': 3.0}, 9.0),
    ]
    for maxsize in (None, 128):
        @self.module.lru_cache(maxsize=maxsize, typed=True)
        def square(x):
            return x * x

        for args, kwargs, expected in specs:
            # The first call misses and the second call hits.
            self.assertEqual(square(*args, **kwargs), expected)
            self.assertEqual(type(square(*args, **kwargs)), type(expected))
        self.assertEqual(square.cache_info().hits, 4)
        self.assertEqual(square.cache_info().misses, 4)
1459
def test_lru_with_keyword_args(self):
    # Calls made purely with keyword arguments are cached as well.  The
    # decorator is used without arguments here, and cache_info()
    # reports maxsize=128.
    @self.module.lru_cache()
    def fib(n):
        return n if n < 2 else fib(n=n-1) + fib(n=n-2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
    self.assertEqual([fib(n=k) for k in range(16)], expected)

    CacheInfo = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001475
def test_lru_with_keyword_args_maxsize_none(self):
    # Keyword-only calls on an unbounded cache: 16 misses and 28 hits
    # for fib(0)..fib(15), with maxsize reported as None.
    @self.module.lru_cache(maxsize=None)
    def fib(n):
        return n if n < 2 else fib(n=n-1) + fib(n=n-2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
    self.assertEqual([fib(n=k) for k in range(16)], expected)

    CacheInfo = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1489
def test_kwargs_order(self):
    # PEP 468: Preserving Keyword Argument Order.  Calls that differ only
    # in keyword order are separate cache entries (two misses), and each
    # one sees its own ordering.
    @self.module.lru_cache(maxsize=10)
    def f(**kwargs):
        return [*kwargs.items()]

    self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
    self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
    self.assertEqual(
        f.cache_info(),
        self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2),
    )
1499
def test_lru_cache_decoration(self):
    # The cache wrapper carries over every attribute named in
    # WRAPPER_ASSIGNMENTS from the function it wraps.
    def f(zomg: 'zomg_annotation'):
        """f doc string"""
        return 42

    wrapper = self.module.lru_cache()(f)
    for name in self.module.WRAPPER_ASSIGNMENTS:
        self.assertEqual(getattr(wrapper, name), getattr(f, name))
1507
def test_lru_cache_threaded(self):
    # Stress one cache from several threads with a tiny switch interval.
    # First every thread fills the cache with its own key.  Then this is
    # repeated while another thread keeps calling cache_clear().
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n*m)(orig)
    _hits, _misses, _maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    start = threading.Event()

    def full(k):
        # Wait for the common start signal, then request key k m times.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # Phase 1: n threads, each filling the cache with one key.
        fillers = [threading.Thread(target=full, args=[k]) for k in range(n)]
        with threading_helper.start_threads(fillers):
            start.set()

        hits, misses, _maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: why might these counts not be exactly equal for the
            # pure-Python implementation?
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: one thread clearing the cache plus n filler threads.
        racers = [threading.Thread(target=clear)]
        racers.extend(threading.Thread(target=full, args=[k])
                      for k in range(n))
        start.clear()
        with threading_helper.start_threads(racers):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001555
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments: in each round, n threads
    # call f(i) at the same moment.  Three barriers, each also joined by
    # this thread, keep the rounds in lock-step.
    n, m = 5, 7
    start, pause, stop = (threading.Barrier(n+1) for _ in range(3))

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(n)]
    with threading_helper.start_threads(threads):
        for rnd in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # All n concurrent calls in a round count as misses, yet only
            # one entry per round ends up stored.
            self.assertEqual(f.cache_info(), (0, (rnd+1)*n, m*n, rnd+1))
1582
def test_lru_cache_threaded3(self):
    # Slow concurrent calls with overlapping keys on a tiny cache
    # (maxsize=2) must each still return the correct value.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(i, x):
        with self.subTest(thread=i):
            self.assertEqual(f(x), 3 * x, i)

    thread_args = list(enumerate([1, 2, 2, 3, 2]))
    threads = [threading.Thread(target=check, args=pair)
               for pair in thread_args]
    with threading_helper.start_threads(threads):
        pass
1595
def test_need_for_rlock(self):
    # This will deadlock on an LRU cache that uses a regular lock:
    # comparing keys during a lookup calls back into the same cache
    # from the same thread.

    @self.module.lru_cache(maxsize=10)
    def test_func(x):
        'Used to demonstrate a reentrant lru_cache call within a single thread'
        return x

    class DoubleEq:
        'Demonstrate a reentrant lru_cache call within a single thread'
        def __init__(self, x):
            self.x = x
        def __hash__(self):
            return self.x
        def __eq__(self, other):
            # Comparing an instance whose x is 2 re-enters test_func.
            if self.x == 2:
                test_func(DoubleEq(1))
            return self.x == other.x

    for value in (1, 2):
        test_func(DoubleEq(value))  # Load the cache
    # Looking up DoubleEq(2) again triggers a re-entrant __eq__ call, and
    # the cached result must still equal a fresh DoubleEq(2).
    self.assertEqual(test_func(DoubleEq(2)), DoubleEq(2))
1619
def test_lru_method(self):
    # A cached method has one cache stored on the class: its statistics
    # are the same whether read through the class or through any
    # instance.  f_cnt counts the real (uncached) invocations per instance.
    class X(int):
        f_cnt = 0
        @self.module.lru_cache(2)
        def f(self, x):
            self.f_cnt += 1
            return x*10+self

    a = X(5)
    b = X(5)
    c = X(7)
    self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

    # (instance, its int value, arguments to call f with in order,
    #  expected (a, b, c) call counts afterwards,
    #  expected cache_info afterwards)
    rounds = [
        (a, 5, (1, 2, 2, 3, 1, 1, 1, 2, 3, 3), (6, 0, 0), (4, 6, 2, 2)),
        (b, 5, (1, 2, 1, 1, 1, 1, 3, 2, 2, 2), (6, 4, 0), (10, 10, 2, 2)),
        (c, 7, (2, 1, 1, 1, 1, 2, 1, 3, 2, 1), (6, 4, 5), (15, 15, 2, 2)),
    ]
    for inst, base, calls, counts, info in rounds:
        for x in calls:
            self.assertEqual(inst.f(x), x*10 + base)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), counts)
        self.assertEqual(X.f.cache_info(), info)

    for inst in (a, b, c):
        self.assertEqual(inst.f.cache_info(), X.f.cache_info())
1650
def test_pickle(self):
    # A cached module-level function, method and static method each
    # unpickle to the very same object, with every pickle protocol.
    cls = type(self)
    candidates = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth)
    for f in candidates:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto, func=f):
                self.assertIs(pickle.loads(pickle.dumps(f, proto)), f)
1658
def test_copy(self):
    # copy.copy() of a cached callable returns the callable itself.  This
    # covers a cached wrapper around a partial object as well.
    cls = type(self)

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
             cached_partial)
    for f in funcs:
        with self.subTest(func=f):
            self.assertIs(copy.copy(f), f)
1670
def test_deepcopy(self):
    # copy.deepcopy() of a cached callable returns the callable itself.
    # This covers a cached wrapper around a partial object as well.
    cls = type(self)

    def orig(x, y):
        return 3 * x + y

    cached_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
             cached_partial)
    for f in funcs:
        with self.subTest(func=f):
            self.assertIs(copy.deepcopy(f), f)
1682
def test_lru_cache_parameters(self):
    # cache_parameters() reports the maxsize and typed settings the
    # decorator was created with; typed defaults to False.
    cases = (
        (2, False, {'maxsize': 2}),
        (1000, True, {'maxsize': 1000, 'typed': True}),
    )
    for maxsize, typed, decorator_kwargs in cases:
        @self.module.lru_cache(**decorator_kwargs)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(),
                         {'maxsize': maxsize, "typed": typed})
1693
def test_lru_cache_weakrefable(self):
    # lru_cache wrappers must support weak references.  Once the last
    # strong reference (the local name, or the owning class) is gone,
    # they must become collectable.
    @self.module.lru_cache
    def test_function(x):
        return x

    class A:
        @self.module.lru_cache
        def test_method(self, x):
            return (self, x)

        @staticmethod
        @self.module.lru_cache
        def test_staticmethod(x):
            # A staticmethod has no self.  Do not close over the enclosing
            # TestCase instance here.
            return x

    refs = [weakref.ref(test_function),
            weakref.ref(A.test_method),
            weakref.ref(A.test_staticmethod)]

    for ref in refs:
        self.assertIsNotNone(ref())

    del A
    del test_function
    # support.gc_collect() runs several collections, which is more
    # reliable than a single gc.collect() on implementations without
    # immediate refcount-based deallocation.
    support.gc_collect()

    for ref in refs:
        self.assertIsNone(ref())
1722
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001723
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level function cached with the pure-Python lru_cache.
    # TestLRUPy exposes it as cached_func for the pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1727
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level function cached with the C-accelerated lru_cache.
    # TestLRUC exposes it as cached_func for the pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1731
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001732
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the pure-Python functools
    # (imported with _functools blocked).
    module = py_functools
    # Held in a 1-tuple, presumably so the function is never bound as a
    # method of the test case; the tests read cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1745
1746
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the C-accelerated functools
    # (imported with _functools fresh).
    module = c_functools
    # Held in a 1-tuple, presumably so the function is never bound as a
    # method of the test case; the tests read cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001759
Raymond Hettinger03923422013-03-04 02:52:50 -05001760
Łukasz Langa6f692512013-06-05 12:20:24 +02001761class TestSingleDispatch(unittest.TestCase):
def test_simple_overloads(self):
    # An implementation registered for int is chosen only for ints;
    # every other argument falls back to the generic base function.
    @functools.singledispatch
    def g(obj):
        return "base"

    def g_int(i):
        return "integer"

    g.register(int, g_int)
    for arg, expected in (("str", "base"), (1, "integer"), ([1, 2, 3], "base")):
        self.assertEqual(g(arg), expected)
1772
def test_mro(self):
    # Dispatch follows the argument's MRO.  D(C, B) has MRO D, C, B, A,
    # so it reaches the implementation for B before the one for A.
    @functools.singledispatch
    def g(obj):
        return "base"

    class A:
        pass

    class C(A):
        pass

    class B(A):
        pass

    class D(C, B):
        pass

    @g.register(A)
    def g_A(a):
        return "A"

    @g.register(B)
    def g_B(b):
        return "B"

    for cls, expected in ((A, "A"), (B, "B"), (C, "A"), (D, "B")):
        self.assertEqual(g(cls()), expected)
Łukasz Langa6f692512013-06-05 12:20:24 +02001795
def test_register_decorator(self):
    # g.register(cls) works as a decorator; dispatch() hands back exactly
    # the object the decorator returned.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(int)
    def g_int(i):
        return "int %s" % (i,)

    self.assertEqual(g(""), "base")
    self.assertEqual(g(12), "int 12")
    self.assertIs(g.dispatch(int), g_int)
    # For unregistered types dispatch() returns the original base
    # function, which is not g itself: @singledispatch returns a wrapper.
    self.assertIs(g.dispatch(object), g.dispatch(str))
1809
def test_wrapping_attributes(self):
    # The singledispatch wrapper keeps the wrapped function's name and
    # docstring.
    @functools.singledispatch
    def g(obj):
        "Simple test"
        return "Test"

    self.assertEqual(g.__name__, "g")
    if sys.flags.optimize >= 2:
        return  # docstrings are stripped under -OO
    self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001818
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes from the C _decimal module.  A later
    # registration for a subclass (Subnormal) overrides the
    # DecimalException implementation only for that subclass.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _args(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def _too_small(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1837
def test_compose_mro(self):
    # None of the examples in this test depend on haystack ordering, so
    # every permutation of the candidate types must produce the same MRO.
    # Each loop therefore passes `haystack` (not a fixed list) to _compose_mro.
    c = collections.abc
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
    for haystack in permutations(bases):
        m = mro(collections.ChainMap, haystack)
        self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        m = mro(collections.defaultdict, haystack)
        self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                             c.Container, object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(collections.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                             collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(collections.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001891
def test_register_abc(self):
    # Implementations for ABCs and concrete types are registered one at a
    # time.  After each registration, every sample must dispatch to the
    # most specific implementation available for it so far.
    c = collections.abc
    d = {"a": "b"}
    lst = [1, 2, 3]
    s = {object(), None}
    f = frozenset(s)
    t = (1, 2, 3)
    samples = (d, lst, s, f, t)

    @functools.singledispatch
    def g(obj):
        return "base"

    def check(expected):
        # expected holds one label per sample, in the order of `samples`.
        for obj, label in zip(samples, expected):
            self.assertEqual(g(obj), label)

    check(("base",) * 5)

    # (type to register, label its implementation returns,
    #  expected labels for (d, lst, s, f, t) afterwards)
    steps = [
        (c.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # An irrelevant ABC is registered; nothing changes.
        (collections.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        # Mapping is not specific enough to beat MutableMapping for d.
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set", "sequence")),
        # Each concrete type takes precedence over the ABCs registered so far.
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    ]
    for cls, label, expected in steps:
        # The default argument binds the current label (no late binding).
        g.register(cls, lambda obj, _label=label: _label)
        check(expected)
1985
def test_c3_abc(self):
    # _c3_mro() places the ABCs a class implements into its MRO.  Here
    # they come from registration (C) or are implied by special methods
    # (B's __len__, X's __call__), and the result must not depend on the
    # order the ABCs are supplied in.
    c = collections.abc
    mro = functools._c3_mro

    class A:
        pass

    class B(A):
        def __len__(self):
            return 0  # implies Sized

    @c.Container.register
    class C:
        pass

    class D:
        pass  # unrelated

    class X(D, C, B):
        def __call__(self):
            pass  # implies Callable

    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    for abcs in permutations([c.Sized, c.Callable, c.Container]):
        self.assertEqual(mro(X, abcs=abcs), expected)
    # ABCs that X does not implement (Mapping, Iterable) are left out.
    many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
    self.assertEqual(mro(X, abcs=many_abcs), expected)
2008
def test_false_meta(self):
    # See issue23572.  A metaclass that defines __len__ must not disturb
    # ordinary dispatch: an AA instance still reaches the implementation
    # registered for its base class A.
    class MetaA(type):
        def __len__(self):
            return 0

    class A(metaclass=MetaA):
        pass

    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    @fun.register(A)
    def _(a):
        return 'fun A'

    self.assertEqual(fun(AA()), 'fun A')
2026
def test_mro_conflicts(self):
    c = collections.abc

    def assert_ambiguous(cm, first, second):
        # The two conflicting ABCs may be reported in either order.
        def message(a, b):
            return (f"Ambiguous dispatch: <class 'collections.abc.{a}'> "
                    f"or <class 'collections.abc.{b}'>")
        self.assertIn(str(cm.exception),
                      (message(first, second), message(second, first)))

    @functools.singledispatch
    def g(arg):
        return "base"

    class O(c.Sized):
        def __len__(self):
            return 0

    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")  # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")  # see above: Sized is in __mro__
    c.Set.register(O)
    # c.Set is a subclass of both c.Sized and c.Container.
    self.assertEqual(g(o), "set")

    class P:
        pass

    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    with self.assertRaises(RuntimeError) as cm:
        g(p)
    assert_ambiguous(cm, "Container", "Iterable")

    class Q(c.Sized):
        def __len__(self):
            return 0

    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")  # because it's explicitly in __mro__
    c.Set.register(Q)
    # c.Set is a subclass of both c.Sized and c.Iterable.
    self.assertEqual(g(q), "set")

    @functools.singledispatch
    def h(arg):
        return "base"

    @h.register(c.Sized)
    def _(arg):
        return "sized"

    @h.register(c.Container)
    def _(arg):
        return "container"

    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as cm:
        h(collections.defaultdict(lambda: 0))
    assert_ambiguous(cm, "Container", "Sized")

    class R(collections.defaultdict):
        pass

    c.MutableSequence.register(R)

    @functools.singledispatch
    def i(arg):
        return "base"

    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"

    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"

    r = R()
    self.assertEqual(i(r), "sequence")

    class S:
        pass

    class T(S, c.Sized):
        def __len__(self):
            return 0

    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")  # because it's explicitly in the MRO

    class U:
        def __len__(self):
            return 0

    u = U()
    # Sized is inferred for U from the existence of __len__().
    self.assertEqual(h(u), "sized")
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as cm:
        h(u)
    assert_ambiguous(cm, "Container", "Sized")

    class V(c.Sized, S):
        def __len__(self):
            return 0

    @functools.singledispatch
    def j(arg):
        return "base"

    @j.register(S)
    def _(arg):
        return "s"

    @j.register(c.Container)
    def _(arg):
        return "container"

    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    # Container now ends up right after Sized in V's MRO, ahead of S.
    self.assertEqual(j(v), "container")
Łukasz Langa6f692512013-06-05 12:20:24 +02002154
def test_cache_invalidation(self):
    """Observe how singledispatch fills and invalidates its dispatch cache.

    ``weakref.WeakKeyDictionary`` is swapped for a factory returning a
    ``TracingDict`` while the generic function is created.  This makes
    the dispatch cache observable: successful reads are recorded in
    ``get_ops`` and writes in ``set_ops``.
    """
    from collections import UserDict

    class TracingDict(UserDict):
        """A UserDict that records the keys of item reads and writes."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []

        def __getitem__(self, key):
            # A missing key raises KeyError before it is recorded, so only
            # successful lookups show up in get_ops.
            result = self.data[key]
            self.get_ops.append(key)
            return result

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            # Clearing is not traced; the tests observe it through len().
            self.data.clear()

    td = TracingDict()
    # The module-level ``weakref`` is patched, so the cache that
    # singledispatch builds through weakref.WeakKeyDictionary() is ``td``.
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First dispatch for each type fills the cache (writes, no reads).
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated dispatch is served from the cache (reads, no writes).
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # Registering an implementation empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)  # Will not invalidate the cache,
                                      # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        # Once an ABC implementation exists, ABC registrations matter.
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # g.dispatch() goes through the same cache as calling g.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)  # Will invalidate the cache.
        # Invalidation is lazy: entries stay until the next dispatch.
        self.assertEqual(len(td), 2)  # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002256
def test_annotations(self):
    # Plain ``@register`` takes the dispatch type from the annotation of
    # the first parameter, whether given as a class or as a string.
    @functools.singledispatch
    def kind_of(arg):
        return "base"

    @kind_of.register
    def _(arg: collections.abc.Mapping):
        return "mapping"

    @kind_of.register
    def _(arg: "collections.abc.Sequence"):
        return "sequence"

    cases = [
        (None, "base"),
        ({"a": 1}, "mapping"),
        ([1, 2, 3], "sequence"),
        ((1, 2, 3), "sequence"),
        ("str", "sequence"),
    ]
    for value, expected in cases:
        self.assertEqual(kind_of(value), expected)

    # A class used as the implementation is not inferred from annotations;
    # the dispatch type must be passed to register() explicitly.
    @kind_of.register(str)
    class _:
        def __init__(self, arg):
            self.arg = arg

        def __eq__(self, other):
            return self.arg == other

    self.assertEqual(kind_of("str"), "str")
2283
def test_method_register(self):
    # singledispatchmethod dispatches on the first argument after ``self``.
    # Each call must only touch the instance it was made on.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            self.arg = "base"

        @t.register(int)
        def _(self, arg):
            self.arg = "int"

        @t.register(str)
        def _(self, arg):
            self.arg = "str"

    a = A()
    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        a.t(value)
        self.assertEqual(a.arg, expected)
        # A freshly built instance is unaffected by calls made on ``a``.
        fresh = A()
        self.assertFalse(hasattr(fresh, 'arg'))
2309
def test_staticmethod_register(self):
    """singledispatchmethod wrapping staticmethod objects.

    Dispatch is checked through the class and through an instance.  In
    both cases no implicit first argument may be passed to the
    implementations.
    """
    class A:
        @functools.singledispatchmethod
        @staticmethod
        def t(arg):
            return arg
        @t.register(int)
        @staticmethod
        def _(arg):
            return isinstance(arg, int)
        @t.register(str)
        @staticmethod
        def _(arg):
            return isinstance(arg, str)
    a = A()

    # Access through the class.
    self.assertTrue(A.t(0))
    self.assertTrue(A.t(''))
    self.assertEqual(A.t(0.0), 0.0)
    # Access through an instance must behave the same way.
    self.assertTrue(a.t(0))
    self.assertTrue(a.t(''))
    self.assertEqual(a.t(0.0), 0.0)
2329
def test_classmethod_register(self):
    # Implementations wrapped in classmethod receive the class itself and
    # use it to build the returned instance.
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

        @t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")

        @t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

    # A list of pairs, not a dict: 0 and 0.0 would collide as dict keys.
    for value, expected in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(A.t(value).arg, expected)
2351
def test_callable_register(self):
    # Implementations may be added after the class body has finished,
    # through the ``register`` attribute reached via the class.
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

    @A.t.register(int)
    @classmethod
    def _(cls, arg):
        return cls("int")

    @A.t.register(str)
    @classmethod
    def _(cls, arg):
        return cls("str")

    for value, expected in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(A.t(value).arg, expected)
2374
def test_abstractmethod_register(self):
    """singledispatchmethod must propagate abstractness of the wrapped function.

    ``abc.ABCMeta`` is used as the metaclass (not as a base class), so
    ``Abstract`` is a genuine abstract class.  Its abstract ``add`` must
    therefore prevent instantiation.
    """
    class Abstract(metaclass=abc.ABCMeta):

        @functools.singledispatchmethod
        @abc.abstractmethod
        def add(self, x, y):
            pass

    self.assertTrue(Abstract.add.__isabstractmethod__)
    # The descriptor stored in the class namespace must carry the flag too;
    # ABCMeta reads __isabstractmethod__ from namespace values.
    self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)

    with self.assertRaises(TypeError):
        Abstract()
2384
def test_type_ann_register(self):
    # Within a class, plain ``@t.register`` reads the dispatch type from
    # the annotation of the first parameter after ``self``.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            return "base"

        @t.register
        def _(self, arg: int):
            return "int"

        @t.register
        def _(self, arg: str):
            return "str"

    instance = A()
    for value, expected in [(0, "int"), ('', "str"), (0.0, "base")]:
        self.assertEqual(instance.t(value), expected)
2401
def test_invalid_registrations(self):
    # register() must reject non-class arguments, unannotated functions
    # and annotations that are not plain classes.
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    @functools.singledispatch
    def i(arg):
        return "base"

    # A non-class object passed as the dispatch type.
    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    message = str(exc.exception)
    self.assertTrue(message.startswith(msg_prefix + "42"))
    self.assertTrue(message.endswith(msg_suffix))

    # Plain @register on a function that has no annotation at all.
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    message = str(exc.exception)
    self.assertTrue(message.startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(message.endswith(msg_suffix))

    # At runtime, dispatching on generics is impossible.  When registering
    # implementations with singledispatch, avoid types from `typing`;
    # annotate with regular types or ABCs instead.
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            return "I annotated with a generic collection"
    message = str(exc.exception)
    self.assertTrue(message.startswith(
        "Invalid annotation for 'arg'."
    ))
    self.assertTrue(message.endswith(
        'typing.Iterable[str] is not a class.'
    ))
Łukasz Langae5697532017-12-11 13:56:31 -08002440
def test_invalid_positional_argument(self):
    # With no positional argument there is nothing to dispatch on.  The
    # call must raise TypeError and the message must name the function.
    @functools.singledispatch
    def f(*args):
        pass

    expected_message = 'f requires at least 1 positional argument'
    with self.assertRaisesRegex(TypeError, expected_message):
        f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002448
Carl Meyerd658dea2018-08-28 01:11:56 -06002449
class CachedCostItem:
    """Item whose ``cost`` increments when computed; cached per instance."""

    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2462
2463
class OptionallyCachedCostItem:
    """Offers one computation both uncached and cached.

    ``get_cost`` recomputes on every call.  ``cached_cost`` wraps the same
    function but stores its first result under a different attribute name.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2473
2474
class CachedCostItemWait:
    """Cached-cost item whose computation first waits on *event*.

    This lets several threads race on the first access to ``cost``.
    """

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Wait (timeout argument 1) before computing; presumably a
        # threading.Event, as in the threaded test.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2488
2489
class CachedCostItemWithSlots:
    """Slotted class with no instance ``__dict__``.

    ``cached_property`` has nowhere to store its value here, so accessing
    ``cost`` is expected to fail before the getter ever runs.
    """

    # Note the trailing comma: without it the parentheses only wrap a
    # string, not a one-element tuple.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2499
2500
class TestCachedProperty(unittest.TestCase):
    """Behavioural tests for the pure-Python ``cached_property``."""

    def test_cached(self):
        item = CachedCostItem()
        # The first access computes 2; the second must hit the cache (not 3).
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        obj = OptionallyCachedCostItem()
        # Interleave uncached and cached access: only get_cost() recomputes.
        self.assertEqual(obj.get_cost(), 2)
        self.assertEqual(obj.cached_cost, 3)
        self.assertEqual(obj.get_cost(), 4)
        self.assertEqual(obj.cached_cost, 3)

    def test_threaded(self):
        start = threading.Event()
        item = CachedCostItemWait(start)

        # A tiny switch interval encourages the threads to interleave.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(3)]
            with threading_helper.start_threads(workers):
                start.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Despite the race, the cached value reflects a single computation.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
            TypeError,
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # The "instance" here is a class, whose __dict__ rejects assignment.
        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError from __set_name__ is chained as the RuntimeError's context.
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        first, second = A(), B()

        # Each instance caches its own value; re-reading first is cached.
        self.assertEqual(first.cp, 1)
        self.assertEqual(second.cp, 2)
        self.assertEqual(first.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Assigning after class creation skips __set_name__.
        Foo.cp = cp

        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
2613
2614
if __name__ == '__main__':
    # Executed only when this file is run directly as a script.
    unittest.main()