blob: 11e8aa356d333179ca195eedc3157d2f77aefabd [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Miss Islington (bot)0a674632020-05-31 17:01:37 -07006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Pablo Galindo99e6c262020-01-23 15:29:52 +000022from test.support.script_helper import assert_python_ok
23
Antoine Pitroub5b37142012-11-13 21:35:40 +010024import functools
25
# Pure-Python functools: the _functools C accelerator is blocked, so every
# name comes from the Python source implementation.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
# functools backed by a freshly imported _functools C module.  This can be
# falsy when the C module is unavailable -- the C-specific test classes below
# are guarded with skipUnless(c_functools, ...) / `if c_functools:`.
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal with a fresh _decimal C module; not referenced in the part of the
# file shown here -- presumably used by later tests (NOTE(review): verify).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
30
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    Whatever module was registered under *name* is put back when the
    ``with`` block exits, whether it exits normally or via an exception.
    A KeyError is raised on entry if *name* is not in ``sys.modules``.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def capture(*args, **kw):
    """Echo the call back to the caller as ``(args, kw)``.

    *args* is the tuple of positional arguments and *kw* the dict of
    keyword arguments, exactly as received, which makes it easy to see
    what a partial object forwarded.
    """
    received = (args, kw)
    return received
43
Łukasz Langa6f692512013-06-05 12:20:24 +020044
def signature(part):
    """Summarize a partial object as ``(func, args, keywords, __dict__)``.

    Two partial objects with equal summaries wrap the same callable with
    the same frozen arguments and carry the same extra attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000048
class MyTuple(tuple):
    """Plain tuple subclass used to check that partial normalizes args to tuple."""
51
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        # Concatenate into a list instead of a tuple on purpose.
        combined = list(self)
        combined.extend(other)
        return combined
55
class MyDict(dict):
    """Plain dict subclass used to check that partial normalizes keywords to dict."""
58
Łukasz Langa6f692512013-06-05 12:20:24 +020059
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin rather than a TestCase.  Concrete subclasses combine it
    with ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the tests
      that pickle partial objects.
    """

    def _repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``."""
        # Both stdlib implementations present themselves as
        # 'functools.partial'; subclasses use their own class name.  The C
        # module may be unavailable (c_functools is then falsy), so it is
        # only consulted when present instead of dereferenced blindly.
        stdlib_partials = [py_functools.partial]
        if c_functools:
            stdlib_partials.append(c_functools.partial)
        if self.partial in stdlib_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected with TypeError.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a, 'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0,1), {'a':1, 'b':2}))
        self.assertEqual(p(), ((0,), {'a':1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down; accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # A partial that is its own func.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            # Break the reference cycle.
            f.__setstate__((capture, (), {}, {}))

        # A partial that is one of its own positional args.
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        # A partial that is one of its own keyword values.
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # A tuple subclass with a hostile __add__ must not leak into calls.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # Self-referencing func: pickling cannot terminate.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-referencing positional arg: the cycle must round-trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-referencing keyword value: the cycle must round-trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000383
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C ``functools.partial``."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: no module replacement is done for C partial."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ must not be deletable.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
433 self.assertIn("['sth']", r)
434
435
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200436class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000437 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000438
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000439 class AllowPickle:
440 def __init__(self):
441 self._cm = replaced_module("functools", py_functools)
442 def __enter__(self):
443 return self._cm.__enter__()
444 def __exit__(self, type, value, tb):
445 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200446
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200447if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000448 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200449 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100450
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000451class PyPartialSubclass(py_functools.partial):
452 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200453
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200454@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200455class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200456 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000457 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000458
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300459 # partial subclasses are not optimized for nested calls
460 test_nested_optimization = None
461
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000462class TestPartialPySubclass(TestPartialPy):
463 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200464
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000465class TestPartialMethod(unittest.TestCase):
466
467 class A(object):
468 nothing = functools.partialmethod(capture)
469 positional = functools.partialmethod(capture, 1)
470 keywords = functools.partialmethod(capture, a=2)
471 both = functools.partialmethod(capture, 3, b=4)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300472 spec_keywords = functools.partialmethod(capture, self=1, func=2)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000473
474 nested = functools.partialmethod(positional, 5)
475
476 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
477
478 static = functools.partialmethod(staticmethod(capture), 8)
479 cls = functools.partialmethod(classmethod(capture), d=9)
480
481 a = A()
482
483 def test_arg_combinations(self):
484 self.assertEqual(self.a.nothing(), ((self.a,), {}))
485 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
486 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
487 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
488
489 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
490 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
491 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
492 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
493
494 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
495 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
496 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
497 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
498
499 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
500 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
501 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
502 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
503
504 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
505
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300506 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))
507
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000508 def test_nested(self):
509 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
510 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
511 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
512 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
513
514 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
515
516 def test_over_partial(self):
517 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
518 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
519 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
520 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
521
522 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
523
524 def test_bound_method_introspection(self):
525 obj = self.a
526 self.assertIs(obj.both.__self__, obj)
527 self.assertIs(obj.nested.__self__, obj)
528 self.assertIs(obj.over_partial.__self__, obj)
529 self.assertIs(obj.cls.__self__, self.A)
530 self.assertIs(self.A.cls.__self__, self.A)
531
532 def test_unbound_method_retrieval(self):
533 obj = self.A
534 self.assertFalse(hasattr(obj.both, "__self__"))
535 self.assertFalse(hasattr(obj.nested, "__self__"))
536 self.assertFalse(hasattr(obj.over_partial, "__self__"))
537 self.assertFalse(hasattr(obj.static, "__self__"))
538 self.assertFalse(hasattr(self.a.static, "__self__"))
539
540 def test_descriptors(self):
541 for obj in [self.A, self.a]:
542 with self.subTest(obj=obj):
543 self.assertEqual(obj.static(), ((8,), {}))
544 self.assertEqual(obj.static(5), ((8, 5), {}))
545 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
546 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
547
548 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
549 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
550 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
551 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
552
553 def test_overriding_keywords(self):
554 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
555 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
556
557 def test_invalid_args(self):
558 with self.assertRaises(TypeError):
559 class B(object):
560 method = functools.partialmethod(None, 1)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300561 with self.assertRaises(TypeError):
562 class B:
563 method = functools.partialmethod()
Serhiy Storchaka142566c2019-06-05 18:22:31 +0300564 with self.assertRaises(TypeError):
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300565 class B:
566 method = functools.partialmethod(func=capture, a=1)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000567
568 def test_repr(self):
569 self.assertEqual(repr(vars(self.A)['both']),
570 'functools.partialmethod({}, 3, b=4)'.format(capture))
571
572 def test_abstract(self):
573 class Abstract(abc.ABCMeta):
574
575 @abc.abstractmethod
576 def add(self, x, y):
577 pass
578
579 add5 = functools.partialmethod(add, 5)
580
581 self.assertTrue(Abstract.add.__isabstractmethod__)
582 self.assertTrue(Abstract.add5.__isabstractmethod__)
583
584 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
585 self.assertFalse(getattr(func, '__isabstractmethod__', False))
586
Pablo Galindo8c77b8c2019-04-29 13:36:57 +0100587 def test_positional_only(self):
588 def f(a, b, /):
589 return a + b
590
591 p = functools.partial(f, 1)
592 self.assertEqual(p(2), f(1, 2))
593
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000594
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000595class TestUpdateWrapper(unittest.TestCase):
596
597 def check_wrapper(self, wrapper, wrapped,
598 assigned=functools.WRAPPER_ASSIGNMENTS,
599 updated=functools.WRAPPER_UPDATES):
600 # Check attributes were assigned
601 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000602 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000603 # Check attributes were updated
604 for name in updated:
605 wrapper_attr = getattr(wrapper, name)
606 wrapped_attr = getattr(wrapped, name)
607 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000608 if name == "__dict__" and key == "__wrapped__":
609 # __wrapped__ is overwritten by the update code
610 continue
611 self.assertIs(wrapped_attr[key], wrapper_attr[key])
612 # Check __wrapped__
613 self.assertIs(wrapper.__wrapped__, wrapped)
614
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000615
R. David Murray378c0cf2010-02-24 01:46:21 +0000616 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000617 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618 """This is a test"""
619 pass
620 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000621 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000622 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000623 pass
624 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000625 return wrapper, f
626
627 def test_default_update(self):
628 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000629 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000630 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000631 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600632 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000633 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000634 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
635 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636
R. David Murray378c0cf2010-02-24 01:46:21 +0000637 @unittest.skipIf(sys.flags.optimize >= 2,
638 "Docstrings are omitted with -O2 and above")
639 def test_default_update_doc(self):
640 wrapper, f = self._default_update()
641 self.assertEqual(wrapper.__doc__, 'This is a test')
642
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000643 def test_no_update(self):
644 def f():
645 """This is a test"""
646 pass
647 f.attr = 'This is also a test'
648 def wrapper():
649 pass
650 functools.update_wrapper(wrapper, f, (), ())
651 self.check_wrapper(wrapper, f, (), ())
652 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600653 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000654 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000655 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000656 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657
658 def test_selective_update(self):
659 def f():
660 pass
661 f.attr = 'This is a different test'
662 f.dict_attr = dict(a=1, b=2, c=3)
663 def wrapper():
664 pass
665 wrapper.dict_attr = {}
666 assign = ('attr',)
667 update = ('dict_attr',)
668 functools.update_wrapper(wrapper, f, assign, update)
669 self.check_wrapper(wrapper, f, assign, update)
670 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600671 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000672 self.assertEqual(wrapper.__doc__, None)
673 self.assertEqual(wrapper.attr, 'This is a different test')
674 self.assertEqual(wrapper.dict_attr, f.dict_attr)
675
Nick Coghlan98876832010-08-17 06:17:18 +0000676 def test_missing_attributes(self):
677 def f():
678 pass
679 def wrapper():
680 pass
681 wrapper.dict_attr = {}
682 assign = ('attr',)
683 update = ('dict_attr',)
684 # Missing attributes on wrapped object are ignored
685 functools.update_wrapper(wrapper, f, assign, update)
686 self.assertNotIn('attr', wrapper.__dict__)
687 self.assertEqual(wrapper.dict_attr, {})
688 # Wrapper must have expected attributes for updating
689 del wrapper.dict_attr
690 with self.assertRaises(AttributeError):
691 functools.update_wrapper(wrapper, f, assign, update)
692 wrapper.dict_attr = 1
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200696 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000697 @unittest.skipIf(sys.flags.optimize >= 2,
698 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000699 def test_builtin_update(self):
700 # Test for bug #1576241
701 def wrapper():
702 pass
703 functools.update_wrapper(wrapper, max)
704 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000705 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000706 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000707
Łukasz Langa6f692512013-06-05 12:20:24 +0200708
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper() checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A deliberately bogus pre-existing __wrapped__ on the source function.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expected = {
            '__name__': 'f',
            '__qualname__': f.__qualname__,
            'attr': 'This is also a test',
        }
        for attr_name, value in expected.items():
            self.assertEqual(getattr(wrapper, attr_name), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: wraps() copies nothing.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # Give the wrapper the attribute wraps() is asked to update.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
769
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000770
madman-bobe25d5fc2018-10-25 15:02:10 +0100771class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000772 def test_reduce(self):
773 class Squares:
774 def __init__(self, max):
775 self.max = max
776 self.sofar = []
777
778 def __len__(self):
779 return len(self.sofar)
780
781 def __getitem__(self, i):
782 if not 0 <= i < self.max: raise IndexError
783 n = len(self.sofar)
784 while n <= i:
785 self.sofar.append(n*n)
786 n += 1
787 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000788 def add(x, y):
789 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100790 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000791 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100792 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000793 ['a','c','d','w']
794 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100797 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000798 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.assertEqual(self.reduce(add, Squares(10)), 285)
801 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
802 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
803 self.assertRaises(TypeError, self.reduce)
804 self.assertRaises(TypeError, self.reduce, 42, 42)
805 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
806 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
807 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
808 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
809 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
810 self.assertRaises(TypeError, self.reduce, add, "")
811 self.assertRaises(TypeError, self.reduce, add, ())
812 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000813
814 class TestFailingIter:
815 def __iter__(self):
816 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100817 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000818
madman-bobe25d5fc2018-10-25 15:02:10 +0100819 self.assertEqual(self.reduce(add, [], None), None)
820 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000821
822 class BadSeq:
823 def __getitem__(self, index):
824 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100825 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000826
827 # Test reduce()'s use of iterators.
828 def test_iterator_usage(self):
829 class SequenceClass:
830 def __init__(self, n):
831 self.n = n
832 def __getitem__(self, i):
833 if 0 <= i < self.n:
834 return i
835 else:
836 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000837
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000838 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100839 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
840 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
841 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
842 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
843 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
844 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000845
846 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100847 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
848
849
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""

    # Only define the attribute when _functools was importable (the class
    # is skipped otherwise).  Builtin functions do not bind to instances,
    # so no staticmethod wrapper is needed here.
    if c_functools is not None:
        reduce = c_functools.reduce
854
855
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python implementation."""

    # staticmethod keeps the Python function from being bound to self.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000858
Łukasz Langa6f692512013-06-05 12:20:24 +0200859
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200860class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000862 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 def cmp1(x, y):
864 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 self.assertEqual(key(3), key(3))
867 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 self.assertGreaterEqual(key(3), key(3))
869
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700870 def cmp2(x, y):
871 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 self.assertEqual(key(4.0), key('4'))
874 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 self.assertLessEqual(key(2), key('35'))
876 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877
878 def test_cmp_to_key_arguments(self):
879 def cmp1(x, y):
880 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100881 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700882 self.assertEqual(key(obj=3), key(obj=3))
883 self.assertGreater(key(obj=3), key(obj=1))
884 with self.assertRaises((TypeError, AttributeError)):
885 key(3) > 1 # rhs is not a K object
886 with self.assertRaises((TypeError, AttributeError)):
887 1 < key(3) # lhs is not a K object
888 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100889 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700890 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200891 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
894 key() # too few args
895 with self.assertRaises(TypeError):
896 key(None, None) # too many args
897
898 def test_bad_cmp(self):
899 def cmp1(x, y):
900 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100901 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700902 with self.assertRaises(ZeroDivisionError):
903 key(3) > key(1)
904
905 class BadCmp:
906 def __lt__(self, other):
907 raise ZeroDivisionError
908 def cmp1(x, y):
909 return BadCmp()
910 with self.assertRaises(ZeroDivisionError):
911 key(3) > key(1)
912
913 def test_obj_field(self):
914 def cmp1(x, y):
915 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100916 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700917 self.assertEqual(key(50).obj, 50)
918
919 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000920 def mycmp(x, y):
921 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100922 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000924
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700925 def test_sort_int_str(self):
926 def mycmp(x, y):
927 x, y = int(x), int(y)
928 return (x > y) - (x < y)
929 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100930 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700931 self.assertEqual([int(value) for value in values],
932 [0, 1, 1, 2, 3, 4, 5, 7, 10])
933
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000934 def test_hash(self):
935 def mycmp(x, y):
936 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100937 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000938 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700939 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300940 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941
Łukasz Langa6f692512013-06-05 12:20:24 +0200942
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    # Defined only when _functools was importable; builtin functions do
    # not bind to instances, so no staticmethod wrapper is needed.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100947
Łukasz Langa6f692512013-06-05 12:20:24 +0200948
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""

    # staticmethod keeps the Python function from being bound to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
951
Łukasz Langa6f692512013-06-05 12:20:24 +0200952
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000953class TestTotalOrdering(unittest.TestCase):
954
955 def test_total_ordering_lt(self):
956 @functools.total_ordering
957 class A:
958 def __init__(self, value):
959 self.value = value
960 def __lt__(self, other):
961 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000962 def __eq__(self, other):
963 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000964 self.assertTrue(A(1) < A(2))
965 self.assertTrue(A(2) > A(1))
966 self.assertTrue(A(1) <= A(2))
967 self.assertTrue(A(2) >= A(1))
968 self.assertTrue(A(2) <= A(2))
969 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000970 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000971
972 def test_total_ordering_le(self):
973 @functools.total_ordering
974 class A:
975 def __init__(self, value):
976 self.value = value
977 def __le__(self, other):
978 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000979 def __eq__(self, other):
980 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000981 self.assertTrue(A(1) < A(2))
982 self.assertTrue(A(2) > A(1))
983 self.assertTrue(A(1) <= A(2))
984 self.assertTrue(A(2) >= A(1))
985 self.assertTrue(A(2) <= A(2))
986 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000987 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000988
989 def test_total_ordering_gt(self):
990 @functools.total_ordering
991 class A:
992 def __init__(self, value):
993 self.value = value
994 def __gt__(self, other):
995 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000996 def __eq__(self, other):
997 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000998 self.assertTrue(A(1) < A(2))
999 self.assertTrue(A(2) > A(1))
1000 self.assertTrue(A(1) <= A(2))
1001 self.assertTrue(A(2) >= A(1))
1002 self.assertTrue(A(2) <= A(2))
1003 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001004 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001005
1006 def test_total_ordering_ge(self):
1007 @functools.total_ordering
1008 class A:
1009 def __init__(self, value):
1010 self.value = value
1011 def __ge__(self, other):
1012 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001013 def __eq__(self, other):
1014 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001015 self.assertTrue(A(1) < A(2))
1016 self.assertTrue(A(2) > A(1))
1017 self.assertTrue(A(1) <= A(2))
1018 self.assertTrue(A(2) >= A(1))
1019 self.assertTrue(A(2) <= A(2))
1020 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001021 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001022
1023 def test_total_ordering_no_overwrite(self):
1024 # new methods should not overwrite existing
1025 @functools.total_ordering
1026 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001027 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001028 self.assertTrue(A(1) < A(2))
1029 self.assertTrue(A(2) > A(1))
1030 self.assertTrue(A(1) <= A(2))
1031 self.assertTrue(A(2) >= A(1))
1032 self.assertTrue(A(2) <= A(2))
1033 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001034
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001035 def test_no_operations_defined(self):
1036 with self.assertRaises(ValueError):
1037 @functools.total_ordering
1038 class A:
1039 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001040
Nick Coghlanf05d9812013-10-02 00:02:03 +10001041 def test_type_error_when_not_implemented(self):
1042 # bug 10042; ensure stack overflow does not occur
1043 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001044 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001045 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001046 def __init__(self, value):
1047 self.value = value
1048 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001049 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001050 return self.value == other.value
1051 return False
1052 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001053 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001054 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001055 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001056
Nick Coghlanf05d9812013-10-02 00:02:03 +10001057 @functools.total_ordering
1058 class ImplementsGreaterThan:
1059 def __init__(self, value):
1060 self.value = value
1061 def __eq__(self, other):
1062 if isinstance(other, ImplementsGreaterThan):
1063 return self.value == other.value
1064 return False
1065 def __gt__(self, other):
1066 if isinstance(other, ImplementsGreaterThan):
1067 return self.value > other.value
1068 return NotImplemented
1069
1070 @functools.total_ordering
1071 class ImplementsLessThanEqualTo:
1072 def __init__(self, value):
1073 self.value = value
1074 def __eq__(self, other):
1075 if isinstance(other, ImplementsLessThanEqualTo):
1076 return self.value == other.value
1077 return False
1078 def __le__(self, other):
1079 if isinstance(other, ImplementsLessThanEqualTo):
1080 return self.value <= other.value
1081 return NotImplemented
1082
1083 @functools.total_ordering
1084 class ImplementsGreaterThanEqualTo:
1085 def __init__(self, value):
1086 self.value = value
1087 def __eq__(self, other):
1088 if isinstance(other, ImplementsGreaterThanEqualTo):
1089 return self.value == other.value
1090 return False
1091 def __ge__(self, other):
1092 if isinstance(other, ImplementsGreaterThanEqualTo):
1093 return self.value >= other.value
1094 return NotImplemented
1095
1096 @functools.total_ordering
1097 class ComparatorNotImplemented:
1098 def __init__(self, value):
1099 self.value = value
1100 def __eq__(self, other):
1101 if isinstance(other, ComparatorNotImplemented):
1102 return self.value == other.value
1103 return False
1104 def __lt__(self, other):
1105 return NotImplemented
1106
1107 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1108 ImplementsLessThan(-1) < 1
1109
1110 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1111 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1112
1113 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1114 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1115
1116 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1117 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1118
1119 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1120 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1121
1122 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1123 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1124
1125 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1126 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1127
1128 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1129 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1130
1131 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1132 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1133
1134 with self.subTest("GE when equal"):
1135 a = ComparatorNotImplemented(8)
1136 b = ComparatorNotImplemented(8)
1137 self.assertEqual(a, b)
1138 with self.assertRaises(TypeError):
1139 a >= b
1140
1141 with self.subTest("LE when equal"):
1142 a = ComparatorNotImplemented(9)
1143 b = ComparatorNotImplemented(9)
1144 self.assertEqual(a, b)
1145 with self.assertRaises(TypeError):
1146 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001147
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001148 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001149 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001150 for name in '__lt__', '__gt__', '__le__', '__ge__':
1151 with self.subTest(method=name, proto=proto):
1152 method = getattr(Orderable_LT, name)
1153 method_copy = pickle.loads(pickle.dumps(method, proto))
1154 self.assertIs(method_copy, method)
1155
@functools.total_ordering
class Orderable_LT:
    # Module-level helper for TestTotalOrdering.test_pickle: total_ordering
    # derives the remaining comparisons from __lt__ and __eq__.

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return self.value < other.value

    def __eq__(self, other):
        return self.value == other.value
1164
1165
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001166class TestCache:
1167 # This tests that the pass-through is working as designed.
1168 # The underlying functionality is tested in TestLRU.
1169
1170 def test_cache(self):
1171 @self.module.cache
1172 def fib(n):
1173 if n < 2:
1174 return n
1175 return fib(n-1) + fib(n-2)
1176 self.assertEqual([fib(n) for n in range(16)],
1177 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1178 self.assertEqual(fib.cache_info(),
1179 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1180 fib.cache_clear()
1181 self.assertEqual(fib.cache_info(),
1182 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1183
1184
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001185class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001186
1187 def test_lru(self):
1188 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001189 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001190 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001191 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001192 self.assertEqual(maxsize, 20)
1193 self.assertEqual(currsize, 0)
1194 self.assertEqual(hits, 0)
1195 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001196
1197 domain = range(5)
1198 for i in range(1000):
1199 x, y = choice(domain), choice(domain)
1200 actual = f(x, y)
1201 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001202 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001203 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001204 self.assertTrue(hits > misses)
1205 self.assertEqual(hits + misses, 1000)
1206 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001207
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001208 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001209 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001210 self.assertEqual(hits, 0)
1211 self.assertEqual(misses, 0)
1212 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001213 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001214 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001215 self.assertEqual(hits, 0)
1216 self.assertEqual(misses, 1)
1217 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001218
Nick Coghlan98876832010-08-17 06:17:18 +00001219 # Test bypassing the cache
1220 self.assertIs(f.__wrapped__, orig)
1221 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001222 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001223 self.assertEqual(hits, 0)
1224 self.assertEqual(misses, 1)
1225 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001226
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001227 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001228 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001229 def f():
1230 nonlocal f_cnt
1231 f_cnt += 1
1232 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001233 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001234 f_cnt = 0
1235 for i in range(5):
1236 self.assertEqual(f(), 20)
1237 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001238 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001239 self.assertEqual(hits, 0)
1240 self.assertEqual(misses, 5)
1241 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001242
1243 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001244 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001245 def f():
1246 nonlocal f_cnt
1247 f_cnt += 1
1248 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001249 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001250 f_cnt = 0
1251 for i in range(5):
1252 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001253 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001254 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001255 self.assertEqual(hits, 4)
1256 self.assertEqual(misses, 1)
1257 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001258
Raymond Hettingerf3098282010-08-15 03:30:45 +00001259 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001260 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001261 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001262 nonlocal f_cnt
1263 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001264 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001265 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001266 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001267 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1268 # * * * *
1269 self.assertEqual(f(x), x*10)
1270 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001271 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001272 self.assertEqual(hits, 12)
1273 self.assertEqual(misses, 4)
1274 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001275
Raymond Hettingerb8218682019-05-26 11:27:35 -07001276 def test_lru_no_args(self):
1277 @self.module.lru_cache
1278 def square(x):
1279 return x ** 2
1280
1281 self.assertEqual(list(map(square, [10, 20, 10])),
1282 [100, 400, 100])
1283 self.assertEqual(square.cache_info().hits, 1)
1284 self.assertEqual(square.cache_info().misses, 2)
1285 self.assertEqual(square.cache_info().maxsize, 128)
1286 self.assertEqual(square.cache_info().currsize, 2)
1287
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001288 def test_lru_bug_35780(self):
1289 # C version of the lru_cache was not checking to see if
1290 # the user function call has already modified the cache
1291 # (this arises in recursive calls and in multi-threading).
1292 # This cause the cache to have orphan links not referenced
1293 # by the cache dictionary.
1294
1295 once = True # Modified by f(x) below
1296
1297 @self.module.lru_cache(maxsize=10)
1298 def f(x):
1299 nonlocal once
1300 rv = f'.{x}.'
1301 if x == 20 and once:
1302 once = False
1303 rv = f(x)
1304 return rv
1305
1306 # Fill the cache
1307 for x in range(15):
1308 self.assertEqual(f(x), f'.{x}.')
1309 self.assertEqual(f.cache_info().currsize, 10)
1310
1311 # Make a recursive call and make sure the cache remains full
1312 self.assertEqual(f(20), '.20.')
1313 self.assertEqual(f.cache_info().currsize, 10)
1314
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001315 def test_lru_bug_36650(self):
1316 # C version of lru_cache was treating a call with an empty **kwargs
1317 # dictionary as being distinct from a call with no keywords at all.
1318 # This did not result in an incorrect answer, but it did trigger
1319 # an unexpected cache miss.
1320
1321 @self.module.lru_cache()
1322 def f(x):
1323 pass
1324
1325 f(0)
1326 f(0, **{})
1327 self.assertEqual(f.cache_info().hits, 1)
1328
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001329 def test_lru_hash_only_once(self):
1330 # To protect against weird reentrancy bugs and to improve
1331 # efficiency when faced with slow __hash__ methods, the
1332 # LRU cache guarantees that it will only call __hash__
1333 # only once per use as an argument to the cached function.
1334
1335 @self.module.lru_cache(maxsize=1)
1336 def f(x, y):
1337 return x * 3 + y
1338
1339 # Simulate the integer 5
1340 mock_int = unittest.mock.Mock()
1341 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1342 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1343
1344 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001345 self.assertEqual(f(mock_int, 1), 16)
1346 self.assertEqual(mock_int.__hash__.call_count, 1)
1347 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001348
1349 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001350 self.assertEqual(f(mock_int, 1), 16)
1351 self.assertEqual(mock_int.__hash__.call_count, 2)
1352 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001353
Ville Skyttä49b27342017-08-03 09:00:59 +03001354 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001355 self.assertEqual(f(6, 2), 20)
1356 self.assertEqual(mock_int.__hash__.call_count, 2)
1357 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001358
1359 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001360 self.assertEqual(f(mock_int, 1), 16)
1361 self.assertEqual(mock_int.__hash__.call_count, 3)
1362 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001363
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001364 def test_lru_reentrancy_with_len(self):
1365 # Test to make sure the LRU cache code isn't thrown-off by
1366 # caching the built-in len() function. Since len() can be
1367 # cached, we shouldn't use it inside the lru code itself.
1368 old_len = builtins.len
1369 try:
1370 builtins.len = self.module.lru_cache(4)(len)
1371 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1372 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1373 finally:
1374 builtins.len = old_len
1375
Raymond Hettinger605a4472017-01-09 07:50:19 -08001376 def test_lru_star_arg_handling(self):
1377 # Test regression that arose in ea064ff3c10f
1378 @functools.lru_cache()
1379 def f(*args):
1380 return args
1381
1382 self.assertEqual(f(1, 2), (1, 2))
1383 self.assertEqual(f((1, 2)), ((1, 2),))
1384
Yury Selivanov46a02db2016-11-09 18:55:45 -05001385 def test_lru_type_error(self):
1386 # Regression test for issue #28653.
1387 # lru_cache was leaking when one of the arguments
1388 # wasn't cacheable.
1389
1390 @functools.lru_cache(maxsize=None)
1391 def infinite_cache(o):
1392 pass
1393
1394 @functools.lru_cache(maxsize=10)
1395 def limited_cache(o):
1396 pass
1397
1398 with self.assertRaises(TypeError):
1399 infinite_cache([])
1400
1401 with self.assertRaises(TypeError):
1402 limited_cache([])
1403
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001404 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001405 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001406 def fib(n):
1407 if n < 2:
1408 return n
1409 return fib(n-1) + fib(n-2)
1410 self.assertEqual([fib(n) for n in range(16)],
1411 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1412 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001413 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001414 fib.cache_clear()
1415 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001416 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1417
1418 def test_lru_with_maxsize_negative(self):
1419 @self.module.lru_cache(maxsize=-10)
1420 def eq(n):
1421 return n
1422 for i in (0, 1):
1423 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1424 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001425 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001426
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001427 def test_lru_with_exceptions(self):
1428 # Verify that user_function exceptions get passed through without
1429 # creating a hard-to-read chained exception.
1430 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001431 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001432 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001433 def func(i):
1434 return 'abc'[i]
1435 self.assertEqual(func(0), 'a')
1436 with self.assertRaises(IndexError) as cm:
1437 func(15)
1438 self.assertIsNone(cm.exception.__context__)
1439 # Verify that the previous exception did not result in a cached entry
1440 with self.assertRaises(IndexError):
1441 func(15)
1442
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001443 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001444 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001445 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001446 def square(x):
1447 return x * x
1448 self.assertEqual(square(3), 9)
1449 self.assertEqual(type(square(3)), type(9))
1450 self.assertEqual(square(3.0), 9.0)
1451 self.assertEqual(type(square(3.0)), type(9.0))
1452 self.assertEqual(square(x=3), 9)
1453 self.assertEqual(type(square(x=3)), type(9))
1454 self.assertEqual(square(x=3.0), 9.0)
1455 self.assertEqual(type(square(x=3.0)), type(9.0))
1456 self.assertEqual(square.cache_info().hits, 4)
1457 self.assertEqual(square.cache_info().misses, 4)
1458
Antoine Pitroub5b37142012-11-13 21:35:40 +01001459 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001460 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001461 def fib(n):
1462 if n < 2:
1463 return n
1464 return fib(n=n-1) + fib(n=n-2)
1465 self.assertEqual(
1466 [fib(n=number) for number in range(16)],
1467 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1468 )
1469 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001470 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001471 fib.cache_clear()
1472 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001473 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001474
1475 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001476 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001477 def fib(n):
1478 if n < 2:
1479 return n
1480 return fib(n=n-1) + fib(n=n-2)
1481 self.assertEqual([fib(n=number) for number in range(16)],
1482 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1483 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001484 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001485 fib.cache_clear()
1486 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001487 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1488
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001489 def test_kwargs_order(self):
1490 # PEP 468: Preserving Keyword Argument Order
1491 @self.module.lru_cache(maxsize=10)
1492 def f(**kwargs):
1493 return list(kwargs.items())
1494 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1495 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1496 self.assertEqual(f.cache_info(),
1497 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1498
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001499 def test_lru_cache_decoration(self):
1500 def f(zomg: 'zomg_annotation'):
1501 """f doc string"""
1502 return 42
1503 g = self.module.lru_cache()(f)
1504 for attr in self.module.WRAPPER_ASSIGNMENTS:
1505 self.assertEqual(getattr(g, attr), getattr(f, attr))
1506
    def test_lru_cache_threaded(self):
        # Call a bounded cache from several threads at once, with a tiny
        # thread switch interval so that thread switches happen very often.
        # Check the statistics, then repeat while another thread keeps
        # clearing the cache.
        n, m = 5, 11  # n: number of filler threads, m: calls per thread
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Workers block on this event so that they all begin calling f
        # together once it is set.
        start = threading.Event()
        def full(k):
            # Thread k always uses its own key (k, 0), m times in a row.
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure-Python implementation the counts are not
                # always exactly equal; the reason is not understood, so
                # only upper bounds are checked.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # One miss per distinct key, every other call a hit.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            # Restore the interpreter-wide switch interval.
            sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001554
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments.
        # In round i all n worker threads call f(i) together. Every barrier
        # has n+1 parties (n workers plus the main thread), and `pause`
        # holds each worker inside f() until all of them have entered it,
        # so every one of the n calls is a cache miss.
        n, m = 5, 7
        start = threading.Barrier(n+1)  # begins a round
        pause = threading.Barrier(n+1)  # holds workers inside f()
        stop = threading.Barrier(n+1)   # all workers finished the round
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                # Re-arm each barrier while no thread can be waiting on it.
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # (hits, misses, maxsize, currsize): no hits, n new misses
                # and exactly one new entry per round.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1581
Serhiy Storchaka67796522017-01-12 18:34:33 +02001582 def test_lru_cache_threaded3(self):
1583 @self.module.lru_cache(maxsize=2)
1584 def f(x):
1585 time.sleep(.01)
1586 return 3 * x
1587 def test(i, x):
1588 with self.subTest(thread=i):
1589 self.assertEqual(f(x), 3 * x, i)
1590 threads = [threading.Thread(target=test, args=(i, v))
1591 for i, v in enumerate([1, 2, 2, 3, 2])]
1592 with support.start_threads(threads):
1593 pass
1594
Raymond Hettinger03923422013-03-04 02:52:50 -05001595 def test_need_for_rlock(self):
1596 # This will deadlock on an LRU cache that uses a regular lock
1597
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001598 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001599 def test_func(x):
1600 'Used to demonstrate a reentrant lru_cache call within a single thread'
1601 return x
1602
1603 class DoubleEq:
1604 'Demonstrate a reentrant lru_cache call within a single thread'
1605 def __init__(self, x):
1606 self.x = x
1607 def __hash__(self):
1608 return self.x
1609 def __eq__(self, other):
1610 if self.x == 2:
1611 test_func(DoubleEq(1))
1612 return self.x == other.x
1613
1614 test_func(DoubleEq(1)) # Load the cache
1615 test_func(DoubleEq(2)) # Load the cache
1616 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1617 DoubleEq(2)) # Verify the correct return value
1618
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001619 def test_lru_method(self):
1620 class X(int):
1621 f_cnt = 0
1622 @self.module.lru_cache(2)
1623 def f(self, x):
1624 self.f_cnt += 1
1625 return x*10+self
1626 a = X(5)
1627 b = X(5)
1628 c = X(7)
1629 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1630
1631 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1632 self.assertEqual(a.f(x), x*10 + 5)
1633 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1634 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1635
1636 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1637 self.assertEqual(b.f(x), x*10 + 5)
1638 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1639 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1640
1641 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1642 self.assertEqual(c.f(x), x*10 + 7)
1643 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1644 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1645
1646 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1647 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1648 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1649
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001650 def test_pickle(self):
1651 cls = self.__class__
1652 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1653 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1654 with self.subTest(proto=proto, func=f):
1655 f_copy = pickle.loads(pickle.dumps(f, proto))
1656 self.assertIs(f_copy, f)
1657
1658 def test_copy(self):
1659 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001660 def orig(x, y):
1661 return 3 * x + y
1662 part = self.module.partial(orig, 2)
1663 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1664 self.module.lru_cache(2)(part))
1665 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001666 with self.subTest(func=f):
1667 f_copy = copy.copy(f)
1668 self.assertIs(f_copy, f)
1669
1670 def test_deepcopy(self):
1671 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001672 def orig(x, y):
1673 return 3 * x + y
1674 part = self.module.partial(orig, 2)
1675 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1676 self.module.lru_cache(2)(part))
1677 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001678 with self.subTest(func=f):
1679 f_copy = copy.deepcopy(f)
1680 self.assertIs(f_copy, f)
1681
Manjusaka051ff522019-11-12 15:30:18 +08001682 def test_lru_cache_parameters(self):
1683 @self.module.lru_cache(maxsize=2)
1684 def f():
1685 return 1
1686 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1687
1688 @self.module.lru_cache(maxsize=1000, typed=True)
1689 def f():
1690 return 1
1691 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1692
    def test_lru_cache_weakrefable(self):
        # Cached functions, methods and static methods must support weak
        # references and must be freed once their last strong reference is
        # gone. lru_cache is applied here without parentheses.
        @self.module.lru_cache
        def test_function(x):
            return x

        class A:
            @self.module.lru_cache
            def test_method(self, x):
                return (self, x)

            @staticmethod
            @self.module.lru_cache
            def test_staticmethod(x):
                # NOTE(review): `self` here is the enclosing TestCase,
                # captured by closure; the static method has no self.
                return (self, x)

        refs = [weakref.ref(test_function),
                weakref.ref(A.test_method),
                weakref.ref(A.test_staticmethod)]

        for ref in refs:
            self.assertIsNotNone(ref())

        # Drop the remaining strong references (the method wrappers live in
        # A's namespace) and run the cycle collector, presumably needed
        # because A itself is only reclaimed by gc.
        del A
        del test_function
        gc.collect()

        for ref in refs:
            self.assertIsNone(ref())
1721
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001722
# Module-level function cached with the pure-Python lru_cache; TestLRUPy
# exposes it as cached_func (used by the pickle/copy tests).
@py_functools.lru_cache()
def py_cached_func(x, y):
    scaled = 3 * x
    return scaled + y
1726
# Module-level function cached with the C-accelerated lru_cache; TestLRUC
# exposes it as cached_func (used by the pickle/copy tests).
@c_functools.lru_cache()
def c_cached_func(x, y):
    scaled = 3 * x
    return scaled + y
1730
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001731
class TestLRUPy(TestLRU, unittest.TestCase):
    # TestLRU run against the pure-Python functools (self.module).
    module = py_functools
    # Stored inside a 1-tuple; presumably so that accessing it through the
    # class or an instance never binds it as a method.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
1744
1745
class TestLRUC(TestLRU, unittest.TestCase):
    # TestLRU run against the C-accelerated functools (self.module).
    module = c_functools
    # Stored inside a 1-tuple; presumably so that accessing it through the
    # class or an instance never binds it as a method.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001758
Raymond Hettinger03923422013-03-04 02:52:50 -05001759
Łukasz Langa6f692512013-06-05 12:20:24 +02001760class TestSingleDispatch(unittest.TestCase):
1761 def test_simple_overloads(self):
1762 @functools.singledispatch
1763 def g(obj):
1764 return "base"
1765 def g_int(i):
1766 return "integer"
1767 g.register(int, g_int)
1768 self.assertEqual(g("str"), "base")
1769 self.assertEqual(g(1), "integer")
1770 self.assertEqual(g([1,2,3]), "base")
1771
1772 def test_mro(self):
1773 @functools.singledispatch
1774 def g(obj):
1775 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001776 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001777 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001778 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001779 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001780 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001781 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001782 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001783 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001784 def g_A(a):
1785 return "A"
1786 def g_B(b):
1787 return "B"
1788 g.register(A, g_A)
1789 g.register(B, g_B)
1790 self.assertEqual(g(A()), "A")
1791 self.assertEqual(g(B()), "B")
1792 self.assertEqual(g(C()), "A")
1793 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001794
1795 def test_register_decorator(self):
1796 @functools.singledispatch
1797 def g(obj):
1798 return "base"
1799 @g.register(int)
1800 def g_int(i):
1801 return "int %s" % (i,)
1802 self.assertEqual(g(""), "base")
1803 self.assertEqual(g(12), "int 12")
1804 self.assertIs(g.dispatch(int), g_int)
1805 self.assertIs(g.dispatch(object), g.dispatch(str))
1806 # Note: in the assert above this is not g.
1807 # @singledispatch returns the wrapper.
1808
1809 def test_wrapping_attributes(self):
1810 @functools.singledispatch
1811 def g(obj):
1812 "Simple test"
1813 return "Test"
1814 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001815 if sys.flags.optimize < 2:
1816 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001817
1818 @unittest.skipUnless(decimal, 'requires _decimal')
1819 @support.cpython_only
1820 def test_c_classes(self):
1821 @functools.singledispatch
1822 def g(obj):
1823 return "base"
1824 @g.register(decimal.DecimalException)
1825 def _(obj):
1826 return obj.args
1827 subn = decimal.Subnormal("Exponent < Emin")
1828 rnd = decimal.Rounded("Number got rounded")
1829 self.assertEqual(g(subn), ("Exponent < Emin",))
1830 self.assertEqual(g(rnd), ("Number got rounded",))
1831 @g.register(decimal.Subnormal)
1832 def _(obj):
1833 return "Too small to care."
1834 self.assertEqual(g(subn), "Too small to care.")
1835 self.assertEqual(g(rnd), ("Number got rounded",))
1836
1837 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001838 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001839 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001840 mro = functools._compose_mro
1841 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1842 for haystack in permutations(bases):
1843 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001844 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1845 c.Collection, c.Sized, c.Iterable,
1846 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001847 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001848 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001849 m = mro(collections.ChainMap, haystack)
1850 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001851 c.Collection, c.Sized, c.Iterable,
1852 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001853
1854 # If there's a generic function with implementations registered for
1855 # both Sized and Container, passing a defaultdict to it results in an
1856 # ambiguous dispatch which will cause a RuntimeError (see
1857 # test_mro_conflicts).
1858 bases = [c.Container, c.Sized, str]
1859 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001860 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1861 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1862 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001863
1864 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001865 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001866 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001867 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001868 pass
1869 c.MutableSequence.register(D)
1870 bases = [c.MutableSequence, c.MutableMapping]
1871 for haystack in permutations(bases):
1872 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001873 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001874 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001875 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001876 object])
1877
1878 # Container and Callable are registered on different base classes and
1879 # a generic function supporting both should always pick the Callable
1880 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001881 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001882 def __call__(self):
1883 pass
1884 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1885 for haystack in permutations(bases):
1886 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001887 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001888 c.Collection, c.Sized, c.Iterable,
1889 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001890
1891 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001892 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001893 d = {"a": "b"}
1894 l = [1, 2, 3]
1895 s = {object(), None}
1896 f = frozenset(s)
1897 t = (1, 2, 3)
1898 @functools.singledispatch
1899 def g(obj):
1900 return "base"
1901 self.assertEqual(g(d), "base")
1902 self.assertEqual(g(l), "base")
1903 self.assertEqual(g(s), "base")
1904 self.assertEqual(g(f), "base")
1905 self.assertEqual(g(t), "base")
1906 g.register(c.Sized, lambda obj: "sized")
1907 self.assertEqual(g(d), "sized")
1908 self.assertEqual(g(l), "sized")
1909 self.assertEqual(g(s), "sized")
1910 self.assertEqual(g(f), "sized")
1911 self.assertEqual(g(t), "sized")
1912 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1913 self.assertEqual(g(d), "mutablemapping")
1914 self.assertEqual(g(l), "sized")
1915 self.assertEqual(g(s), "sized")
1916 self.assertEqual(g(f), "sized")
1917 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001918 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001919 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1920 self.assertEqual(g(l), "sized")
1921 self.assertEqual(g(s), "sized")
1922 self.assertEqual(g(f), "sized")
1923 self.assertEqual(g(t), "sized")
1924 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1925 self.assertEqual(g(d), "mutablemapping")
1926 self.assertEqual(g(l), "mutablesequence")
1927 self.assertEqual(g(s), "sized")
1928 self.assertEqual(g(f), "sized")
1929 self.assertEqual(g(t), "sized")
1930 g.register(c.MutableSet, lambda obj: "mutableset")
1931 self.assertEqual(g(d), "mutablemapping")
1932 self.assertEqual(g(l), "mutablesequence")
1933 self.assertEqual(g(s), "mutableset")
1934 self.assertEqual(g(f), "sized")
1935 self.assertEqual(g(t), "sized")
1936 g.register(c.Mapping, lambda obj: "mapping")
1937 self.assertEqual(g(d), "mutablemapping") # not specific enough
1938 self.assertEqual(g(l), "mutablesequence")
1939 self.assertEqual(g(s), "mutableset")
1940 self.assertEqual(g(f), "sized")
1941 self.assertEqual(g(t), "sized")
1942 g.register(c.Sequence, lambda obj: "sequence")
1943 self.assertEqual(g(d), "mutablemapping")
1944 self.assertEqual(g(l), "mutablesequence")
1945 self.assertEqual(g(s), "mutableset")
1946 self.assertEqual(g(f), "sized")
1947 self.assertEqual(g(t), "sequence")
1948 g.register(c.Set, lambda obj: "set")
1949 self.assertEqual(g(d), "mutablemapping")
1950 self.assertEqual(g(l), "mutablesequence")
1951 self.assertEqual(g(s), "mutableset")
1952 self.assertEqual(g(f), "set")
1953 self.assertEqual(g(t), "sequence")
1954 g.register(dict, lambda obj: "dict")
1955 self.assertEqual(g(d), "dict")
1956 self.assertEqual(g(l), "mutablesequence")
1957 self.assertEqual(g(s), "mutableset")
1958 self.assertEqual(g(f), "set")
1959 self.assertEqual(g(t), "sequence")
1960 g.register(list, lambda obj: "list")
1961 self.assertEqual(g(d), "dict")
1962 self.assertEqual(g(l), "list")
1963 self.assertEqual(g(s), "mutableset")
1964 self.assertEqual(g(f), "set")
1965 self.assertEqual(g(t), "sequence")
1966 g.register(set, lambda obj: "concrete-set")
1967 self.assertEqual(g(d), "dict")
1968 self.assertEqual(g(l), "list")
1969 self.assertEqual(g(s), "concrete-set")
1970 self.assertEqual(g(f), "set")
1971 self.assertEqual(g(t), "sequence")
1972 g.register(frozenset, lambda obj: "frozen-set")
1973 self.assertEqual(g(d), "dict")
1974 self.assertEqual(g(l), "list")
1975 self.assertEqual(g(s), "concrete-set")
1976 self.assertEqual(g(f), "frozen-set")
1977 self.assertEqual(g(t), "sequence")
1978 g.register(tuple, lambda obj: "tuple")
1979 self.assertEqual(g(d), "dict")
1980 self.assertEqual(g(l), "list")
1981 self.assertEqual(g(s), "concrete-set")
1982 self.assertEqual(g(f), "frozen-set")
1983 self.assertEqual(g(t), "tuple")
1984
Łukasz Langa3720c772013-07-01 16:00:38 +02001985 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001986 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001987 mro = functools._c3_mro
1988 class A(object):
1989 pass
1990 class B(A):
1991 def __len__(self):
1992 return 0 # implies Sized
1993 @c.Container.register
1994 class C(object):
1995 pass
1996 class D(object):
1997 pass # unrelated
1998 class X(D, C, B):
1999 def __call__(self):
2000 pass # implies Callable
2001 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2002 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2003 self.assertEqual(mro(X, abcs=abcs), expected)
2004 # unrelated ABCs don't appear in the resulting MRO
2005 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2006 self.assertEqual(mro(X, abcs=many_abcs), expected)
2007
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002008 def test_false_meta(self):
2009 # see issue23572
2010 class MetaA(type):
2011 def __len__(self):
2012 return 0
2013 class A(metaclass=MetaA):
2014 pass
2015 class AA(A):
2016 pass
2017 @functools.singledispatch
2018 def fun(a):
2019 return 'base A'
2020 @fun.register(A)
2021 def _(a):
2022 return 'fun A'
2023 aa = AA()
2024 self.assertEqual(fun(aa), 'fun A')
2025
Łukasz Langa6f692512013-06-05 12:20:24 +02002026 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002027 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002028 @functools.singledispatch
2029 def g(arg):
2030 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02002031 class O(c.Sized):
2032 def __len__(self):
2033 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002034 o = O()
2035 self.assertEqual(g(o), "base")
2036 g.register(c.Iterable, lambda arg: "iterable")
2037 g.register(c.Container, lambda arg: "container")
2038 g.register(c.Sized, lambda arg: "sized")
2039 g.register(c.Set, lambda arg: "set")
2040 self.assertEqual(g(o), "sized")
2041 c.Iterable.register(O)
2042 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
2043 c.Container.register(O)
2044 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02002045 c.Set.register(O)
2046 self.assertEqual(g(o), "set") # because c.Set is a subclass of
2047 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02002048 class P:
2049 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02002050 p = P()
2051 self.assertEqual(g(p), "base")
2052 c.Iterable.register(P)
2053 self.assertEqual(g(p), "iterable")
2054 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02002055 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02002056 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002057 self.assertIn(
2058 str(re_one.exception),
2059 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2060 "or <class 'collections.abc.Iterable'>"),
2061 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
2062 "or <class 'collections.abc.Container'>")),
2063 )
Łukasz Langa6f692512013-06-05 12:20:24 +02002064 class Q(c.Sized):
2065 def __len__(self):
2066 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002067 q = Q()
2068 self.assertEqual(g(q), "sized")
2069 c.Iterable.register(Q)
2070 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2071 c.Set.register(Q)
2072 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002073 # c.Sized and c.Iterable
2074 @functools.singledispatch
2075 def h(arg):
2076 return "base"
2077 @h.register(c.Sized)
2078 def _(arg):
2079 return "sized"
2080 @h.register(c.Container)
2081 def _(arg):
2082 return "container"
2083 # Even though Sized and Container are explicit bases of MutableMapping,
2084 # this ABC is implicitly registered on defaultdict which makes all of
2085 # MutableMapping's bases implicit as well from defaultdict's
2086 # perspective.
2087 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002088 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002089 self.assertIn(
2090 str(re_two.exception),
2091 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2092 "or <class 'collections.abc.Sized'>"),
2093 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2094 "or <class 'collections.abc.Container'>")),
2095 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002096 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002097 pass
2098 c.MutableSequence.register(R)
2099 @functools.singledispatch
2100 def i(arg):
2101 return "base"
2102 @i.register(c.MutableMapping)
2103 def _(arg):
2104 return "mapping"
2105 @i.register(c.MutableSequence)
2106 def _(arg):
2107 return "sequence"
2108 r = R()
2109 self.assertEqual(i(r), "sequence")
2110 class S:
2111 pass
2112 class T(S, c.Sized):
2113 def __len__(self):
2114 return 0
2115 t = T()
2116 self.assertEqual(h(t), "sized")
2117 c.Container.register(T)
2118 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2119 class U:
2120 def __len__(self):
2121 return 0
2122 u = U()
2123 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2124 # from the existence of __len__()
2125 c.Container.register(U)
2126 # There is no preference for registered versus inferred ABCs.
2127 with self.assertRaises(RuntimeError) as re_three:
2128 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002129 self.assertIn(
2130 str(re_three.exception),
2131 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2132 "or <class 'collections.abc.Sized'>"),
2133 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2134 "or <class 'collections.abc.Container'>")),
2135 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002136 class V(c.Sized, S):
2137 def __len__(self):
2138 return 0
2139 @functools.singledispatch
2140 def j(arg):
2141 return "base"
2142 @j.register(S)
2143 def _(arg):
2144 return "s"
2145 @j.register(c.Container)
2146 def _(arg):
2147 return "container"
2148 v = V()
2149 self.assertEqual(j(v), "s")
2150 c.Container.register(V)
2151 self.assertEqual(j(v), "container") # because it ends up right after
2152 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002153
    def test_cache_invalidation(self):
        """Trace when a singledispatch function fills, hits and clears its cache.

        ``weakref.WeakKeyDictionary`` is temporarily replaced by a factory
        that returns ``td``, so the ``g`` created inside the ``with`` block
        ends up using ``td`` as its dispatch cache.  ``td`` logs the key of
        every successful read (``get_ops``) and every write (``set_ops``).
        """
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # UserDict that records the keys of its reads and writes.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                # Logged only after a successful lookup: a missing key raises
                # KeyError above and is therefore not recorded as a read.
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Empties the stored entries but keeps both op logs, so the
                # assertions below can follow history across cache clears.
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First dispatch per type is a miss: result is written, not read.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeat dispatches are cache hits: only reads are logged.
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering a new implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            # g has no ABC implementations yet, so an ABC registration does
            # not invalidate the cache: the next calls are plain hits.
            c.MutableMapping.register(X)
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # Registering an ABC implementation clears the cache as well.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() consults the cache exactly like a call does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # Now that g uses an ABC (c.Sized), an ABC registration does
            # invalidate the cache.  The stale entries are still present
            # until the next dispatch drops them and re-caches list alone.
            c.MutableSet.register(X)
            self.assertEqual(len(td), 2)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            # A concrete-class implementation beats the ABC one for dict.
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # The private _clear_cache() hook empties the cache outright.
            g._clear_cache()
            self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002255
Łukasz Langae5697532017-12-11 13:56:31 -08002256 def test_annotations(self):
2257 @functools.singledispatch
2258 def i(arg):
2259 return "base"
2260 @i.register
2261 def _(arg: collections.abc.Mapping):
2262 return "mapping"
2263 @i.register
2264 def _(arg: "collections.abc.Sequence"):
2265 return "sequence"
2266 self.assertEqual(i(None), "base")
2267 self.assertEqual(i({"a": 1}), "mapping")
2268 self.assertEqual(i([1, 2, 3]), "sequence")
2269 self.assertEqual(i((1, 2, 3)), "sequence")
2270 self.assertEqual(i("str"), "sequence")
2271
2272 # Registering classes as callables doesn't work with annotations,
2273 # you need to pass the type explicitly.
2274 @i.register(str)
2275 class _:
2276 def __init__(self, arg):
2277 self.arg = arg
2278
2279 def __eq__(self, other):
2280 return self.arg == other
2281 self.assertEqual(i("str"), "str")
2282
Ethan Smithc6512752018-05-26 16:38:33 -04002283 def test_method_register(self):
2284 class A:
2285 @functools.singledispatchmethod
2286 def t(self, arg):
2287 self.arg = "base"
2288 @t.register(int)
2289 def _(self, arg):
2290 self.arg = "int"
2291 @t.register(str)
2292 def _(self, arg):
2293 self.arg = "str"
2294 a = A()
2295
2296 a.t(0)
2297 self.assertEqual(a.arg, "int")
2298 aa = A()
2299 self.assertFalse(hasattr(aa, 'arg'))
2300 a.t('')
2301 self.assertEqual(a.arg, "str")
2302 aa = A()
2303 self.assertFalse(hasattr(aa, 'arg'))
2304 a.t(0.0)
2305 self.assertEqual(a.arg, "base")
2306 aa = A()
2307 self.assertFalse(hasattr(aa, 'arg'))
2308
2309 def test_staticmethod_register(self):
2310 class A:
2311 @functools.singledispatchmethod
2312 @staticmethod
2313 def t(arg):
2314 return arg
2315 @t.register(int)
2316 @staticmethod
2317 def _(arg):
2318 return isinstance(arg, int)
2319 @t.register(str)
2320 @staticmethod
2321 def _(arg):
2322 return isinstance(arg, str)
2323 a = A()
2324
2325 self.assertTrue(A.t(0))
2326 self.assertTrue(A.t(''))
2327 self.assertEqual(A.t(0.0), 0.0)
2328
2329 def test_classmethod_register(self):
2330 class A:
2331 def __init__(self, arg):
2332 self.arg = arg
2333
2334 @functools.singledispatchmethod
2335 @classmethod
2336 def t(cls, arg):
2337 return cls("base")
2338 @t.register(int)
2339 @classmethod
2340 def _(cls, arg):
2341 return cls("int")
2342 @t.register(str)
2343 @classmethod
2344 def _(cls, arg):
2345 return cls("str")
2346
2347 self.assertEqual(A.t(0).arg, "int")
2348 self.assertEqual(A.t('').arg, "str")
2349 self.assertEqual(A.t(0.0).arg, "base")
2350
2351 def test_callable_register(self):
2352 class A:
2353 def __init__(self, arg):
2354 self.arg = arg
2355
2356 @functools.singledispatchmethod
2357 @classmethod
2358 def t(cls, arg):
2359 return cls("base")
2360
2361 @A.t.register(int)
2362 @classmethod
2363 def _(cls, arg):
2364 return cls("int")
2365 @A.t.register(str)
2366 @classmethod
2367 def _(cls, arg):
2368 return cls("str")
2369
2370 self.assertEqual(A.t(0).arg, "int")
2371 self.assertEqual(A.t('').arg, "str")
2372 self.assertEqual(A.t(0.0).arg, "base")
2373
2374 def test_abstractmethod_register(self):
2375 class Abstract(abc.ABCMeta):
2376
2377 @functools.singledispatchmethod
2378 @abc.abstractmethod
2379 def add(self, x, y):
2380 pass
2381
2382 self.assertTrue(Abstract.add.__isabstractmethod__)
2383
2384 def test_type_ann_register(self):
2385 class A:
2386 @functools.singledispatchmethod
2387 def t(self, arg):
2388 return "base"
2389 @t.register
2390 def _(self, arg: int):
2391 return "int"
2392 @t.register
2393 def _(self, arg: str):
2394 return "str"
2395 a = A()
2396
2397 self.assertEqual(a.t(0), "int")
2398 self.assertEqual(a.t(''), "str")
2399 self.assertEqual(a.t(0.0), "base")
2400
Łukasz Langae5697532017-12-11 13:56:31 -08002401 def test_invalid_registrations(self):
2402 msg_prefix = "Invalid first argument to `register()`: "
2403 msg_suffix = (
2404 ". Use either `@register(some_class)` or plain `@register` on an "
2405 "annotated function."
2406 )
2407 @functools.singledispatch
2408 def i(arg):
2409 return "base"
2410 with self.assertRaises(TypeError) as exc:
2411 @i.register(42)
2412 def _(arg):
2413 return "I annotated with a non-type"
2414 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2415 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2416 with self.assertRaises(TypeError) as exc:
2417 @i.register
2418 def _(arg):
2419 return "I forgot to annotate"
2420 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2421 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2422 ))
2423 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2424
Łukasz Langae5697532017-12-11 13:56:31 -08002425 with self.assertRaises(TypeError) as exc:
2426 @i.register
2427 def _(arg: typing.Iterable[str]):
2428 # At runtime, dispatching on generics is impossible.
2429 # When registering implementations with singledispatch, avoid
2430 # types from `typing`. Instead, annotate with regular types
2431 # or ABCs.
2432 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002433 self.assertTrue(str(exc.exception).startswith(
2434 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002435 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002436 self.assertTrue(str(exc.exception).endswith(
2437 'typing.Iterable[str] is not a class.'
2438 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002439
Dong-hee Na445f1b32018-07-10 16:26:36 +09002440 def test_invalid_positional_argument(self):
2441 @functools.singledispatch
2442 def f(*args):
2443 pass
2444 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002445 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002446 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002447
Carl Meyerd658dea2018-08-28 01:11:56 -06002448
class CachedCostItem:
    """Fixture whose cached ``cost`` bumps a counter when it is computed.

    ``_cost`` starts at the class default of 1, so the computed (and then
    cached) value of ``cost`` is 2.
    """
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2461
2462
class OptionallyCachedCostItem:
    """Fixture exposing one counter-bumping getter both uncached and cached.

    ``get_cost`` increments and returns the counter on every call.
    ``cached_cost`` wraps the very same function in a cached_property stored
    under a name different from the function's own name.
    """
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2472
2473
class CachedCostItemWait:
    """Fixture whose ``cost`` first waits on *event* (1-second timeout).

    Lets several threads pile up on the first read of ``cost`` before the
    value is computed.
    """

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Hold the computation until the test releases the event (or the
        # timeout expires).
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2487
2488
class CachedCostItemWithSlots:
    """Fixture whose ``__slots__`` leave instances without a ``__dict__``.

    cached_property has nowhere to store its value, so reading ``cost`` is
    expected to raise TypeError without ever calling the wrapped function.
    """
    # Note the trailing comma.  Without it this is a plain string, which
    # Python happens to accept as a single slot name, but it reads like a
    # tuple and breaks as soon as a second name is added.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2498
2499
class TestCachedProperty(unittest.TestCase):
    """Tests for cached_property as provided by ``py_functools``."""

    def test_cached(self):
        item = CachedCostItem()
        # Both reads must see the first computed value (2), never 3.
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # Direct calls keep counting; the cached attribute keeps the value it
        # computed on first access.
        for read, expected in (
            (item.get_cost, 2),
            (lambda: item.cached_cost, 3),
            (item.get_cost, 4),
            (lambda: item.cached_cost, 3),
        ):
            self.assertEqual(read(), expected)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        # Switch threads as often as possible to maximise contention on the
        # first read of item.cost; always restore the previous interval.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # The cached value must come from a single computation.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # MyClass's own __dict__ does not support item assignment, so the
        # value cannot be cached on it.
        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The underlying TypeError is reachable as the RuntimeError's context.
        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        # Each instance computes once, in first-access order, then caches.
        for obj, expected in ((a, 1), (b, 2), (a, 1)):
            self.assertEqual(obj.cp, expected)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Assigning after class creation leaves the property without a name.
        Foo.cp = cp

        expected = "Cannot use cached_property instance without calling __set_name__ on it."
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access returns the descriptor itself.
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        # The wrapped getter's docstring is exposed on the descriptor.
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
2612
2613
if __name__ == '__main__':
    # Run this module's test cases when it is executed as a script.
    unittest.main()