blob: 8fee1c6afdd450e48c613f5cafbc4e2853695359 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Two separately imported copies of functools: py_functools is imported with
# the _functools accelerator blocked, c_functools re-imports _functools
# afresh.  c_functools may be None when that import is not possible (the
# file checks it before use).
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# A fresh copy of decimal with the _decimal module re-imported as well.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The previously registered module is put back when the ``with`` block
    exits, whether it finishes normally or raises.  *name* must already be
    present in ``sys.modules``; otherwise ``KeyError`` is raised on entry
    and nothing is swapped.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing body cannot leak the swap.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Return ``(args, kw)``: exactly what the call received.

    Serves as the target callable in these tests so the positional and
    keyword arguments forwarded by a partial object can be compared
    directly.
    """
    received = (args, kw)
    return received
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for the partial *part*.

    Two partial objects with equal signatures wrap the same callable with
    the same frozen arguments and carry the same instance attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass; lets tests check that partial stores its
    positional arguments as an exact ``tuple``."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple.

    Used to check that partial does not rely on ``args + more`` yielding
    a tuple when handed a tuple subclass.
    """

    def __add__(self, other):
        return [*self, *other]
50
class MyDict(dict):
    """Plain dict subclass; lets tests check that partial stores its
    keyword arguments as an exact ``dict``."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin rather than a TestCase: concrete subclasses combine it
    with ``unittest.TestCase`` and supply two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager factory entered around every
      test that pickles a partial object.
    """

    def _repr_name(self):
        # Name expected at the front of repr(): the stock C and pure-Python
        # types both report "functools.partial", subclasses use __name__.
        # c_functools may be None (the file guards on it elsewhere), so it
        # is skipped instead of failing with AttributeError on None.partial.
        stock_types = [module.partial
                       for module in (c_functools, py_functools)
                       if module is not None]
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must end in TypeError.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a': a, 'x': None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Collect explicitly: dropping the last reference frees the object
        # immediately only on refcounting implementations.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # partial(partial(f, ...), ...) is flattened into a single partial
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not fixed, so accept either order
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes f refer to itself; the finally clause resets the
        # state to break the reference cycle.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # a shallow copy shares args, keywords and attribute values
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # a deep copy duplicates args, keywords and attribute values
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # a None namespace yields an empty __dict__
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check is disabled; the call checks
        # below still cover the keywords=None case.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # state must be an exact 4-tuple of (callable, tuple, dict|None, dict|None)
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are normalised to exact types
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # self-referential func cannot be pickled
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # self-reference through args survives a round trip
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # self-reference through keywords survives a round trip
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence that looks like a valid state but is not a tuple must
        # be rejected with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C implementation, plus checks
    that only apply to it."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No setup is needed around pickling for this type: a no-op
        # context manager.
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, traceback):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # ... and the instance __dict__ must not be deletable
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr() still works; calling rejects the non-string keyword
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict:
            # Formatting the key swaps the value stored under it.
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        # While active, sys.modules['functools'] is the py_functools copy.
        # Presumably this lets pickle resolve ``functools.partial`` to the
        # pure-Python class under test -- TODO confirm.
        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc_value, traceback):
            return self._swap.__exit__(exc_type, exc_value, traceback)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Empty subclass of the C partial type (defined only when the C
        module could be imported)."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
class PyPartialSubclass(py_functools.partial):
    """Empty subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun the C partial tests against a trivial subclass of the type."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the
    # inherited check is disabled (a None attribute is not collected as a
    # test method).
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Rerun the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod, exercised through the descriptors
    defined on the nested class ``A``."""

    class A:
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be accepted as ordinary keyword arguments
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # non-callable target
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        # missing target
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        # passing the target as func= is deprecated but still works
        with self.assertWarns(DeprecationWarning):
            class B:
                method = functools.partialmethod(func=capture, a=1)
        b = B()
        self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the metaclass (not as a base class) so Abstract is
        # a real abstract class and ABCMeta collects its abstract methods.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # partialmethod exposes its target's __isabstractmethod__, so
        # ABCMeta counts add5 as abstract too.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
590
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000591
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper().

    The main fixture is built by _default_update() so a subclass can
    rebuild it through a different API and rerun the same assertions.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata copied from *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both functions.  For each mapping attribute named in *updated*,
        every entry of the wrapped mapping must appear (by identity) in the
        wrapper's.  Finally, ``wrapper.__wrapped__`` must be *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            merged = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # __wrapped__ is overwritten by the update code, so that one
                # __dict__ entry is expected to differ.
                if (attr, key) == ("__dict__", "__wrapped__"):
                    continue
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Fixture: ``f`` has an annotation, a docstring, an extra attribute
        # and a decoy __wrapped__; ``wrapper`` has its own annotation.  The
        # two are joined with update_wrapper()'s default settings and
        # returned as (wrapper, wrapped).
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned = ('attr',)
        updated = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned = ('attr',)
        updated = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        self.assertRaises(AttributeError, functools.update_wrapper,
                          wrapper, f, assigned, updated)
        # An int in place of the mapping is rejected as well
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError, functools.update_wrapper,
                          wrapper, f, assigned, updated)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000704
Łukasz Langa6f692512013-06-05 12:20:24 +0200705
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks using the functools.wraps() decorator.

    Tests not overridden here are inherited from TestUpdateWrapper and use
    the wrapper produced by this class's _default_update().
    """

    def _default_update(self):
        # Decorate a fresh function with wraps(f) and return (wrapper, f).
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus __wrapped__ on f: the wrapper's own __wrapped__
        # should refer to f itself, not to this string from f.__dict__.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        # The default assignments copy the wrapped function's identity...
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        # ...and the default update merges its instance dictionary.
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps() with empty assigned/updated tuples copies nothing over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is merged.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def with_empty_dict_attr(func):
            # Give the wrapper an empty dict for wraps() to update in place.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
766
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000767
class TestReduce:
    """Mixin of reduce() tests.

    Concrete subclasses supply unittest.TestCase and a ``reduce`` attribute
    holding the implementation under test.
    """

    def test_reduce(self):
        class Squares:
            # Lazily computed sequence of n*n for 0 <= n < limit. __len__
            # reports only the squares computed so far, presumably to check
            # that reduce() iterates instead of trusting len().
            def __init__(self, limit):
                self.limit = limit
                self.computed = []

            def __len__(self):
                return len(self.computed)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                start = len(self.computed)
                self.computed.extend(n * n for n in range(start, i + 1))
                return self.computed[i]

        def add(x, y):
            return x + y

        def multiply(x, y):
            return x * y

        self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w'])
        self.assertEqual(self.reduce(multiply, range(2, 8), 1), 5040)
        self.assertEqual(self.reduce(multiply, range(2, 21), 1),
                         2432902008176640000)  # 20!
        # Same total with and without an explicit initial value.
        for initial in ((), (0,)):
            self.assertEqual(self.reduce(add, Squares(10), *initial), 285)
        self.assertEqual(self.reduce(add, Squares(0), 0), 0)

        # Wrong number or kind of arguments.
        for bad_args in ((), (42, 42), (42, 42, 42)):
            self.assertRaises(TypeError, self.reduce, *bad_args)
        # A single item is returned as-is; func is never called with it.
        self.assertEqual(self.reduce(42, "1"), "1")
        self.assertEqual(self.reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, self.reduce, 42, (42, 42))
        # An empty sequence requires an initial value; a non-iterable fails.
        for bad_seq in ([], "", (), object()):
            self.assertRaises(TypeError, self.reduce, add, bad_seq)

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.reduce, add, FailingIter())

        # With an empty sequence the initial value itself is returned.
        for initial in (None, 42):
            self.assertEqual(self.reduce(add, [], initial), initial)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.reduce, 42, BadSeq())

    def test_iterator_usage(self):
        # reduce() must go through the iteration protocol: legacy
        # __getitem__-only sequences and dicts (which yield their keys).
        class SequenceClass:
            # Yields 0, 1, ..., n-1 via the old sequence protocol.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        cases = [(5, (), 10), (5, (42,), 52), (0, (42,), 42),
                 (1, (), 0), (1, (42,), 42)]
        for length, initial, expected in cases:
            self.assertEqual(self.reduce(add, SequenceClass(length), *initial),
                             expected)
        # Empty sequence without an initial value.
        self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.reduce(add, d), "".join(d))
845
846
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the TestReduce checks against the C accelerated reduce()."""
    # c_functools is None when _functools cannot be imported. The class is
    # skipped then, so the attribute is bound only when the module exists.
    # A builtin function does not bind as a method, so no staticmethod().
    if c_functools is not None:
        reduce = c_functools.reduce
851
852
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the TestReduce checks against the pure-Python reduce()."""
    # A plain Python function would bind as a method; staticmethod() makes
    # self.reduce(...) pass exactly the arguments given.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000855
Łukasz Langa6f692512013-06-05 12:20:24 +0200856
class TestCmpToKey:
    """Mixin of cmp_to_key() tests.

    Concrete subclasses supply unittest.TestCase and a ``cmp_to_key``
    attribute holding the implementation under test.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        # The comparison function, not the raw objects, decides ordering:
        # 4.0 and '4' compare equal once converted with int().
        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1  # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)  # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()  # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)  # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()  # too few args
        with self.assertRaises(TypeError):
            key(None, None)  # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against zero must propagate too. The key has to be
        # rebuilt from the new function; otherwise the first cmp1 would
        # be exercised again. '<' is used because the result is compared
        # as `result < 0`, which is what reaches BadCmp.__lt__ (`result > 0`
        # would raise TypeError instead).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # Key objects define ordering but are deliberately unhashable.
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000938
Łukasz Langa6f692512013-06-05 12:20:24 +0200939
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey checks against the C cmp_to_key()."""
    # c_functools is None when _functools cannot be imported. The class is
    # skipped then, so the attribute is bound only when the module exists.
    # A builtin function does not bind as a method, so no staticmethod().
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey checks against the pure-Python cmp_to_key()."""
    # staticmethod() stops the Python function from binding self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
948
Łukasz Langa6f692512013-06-05 12:20:24 +0200949
class TestTotalOrdering(unittest.TestCase):

    def _assert_orders_one_before_two(self, cls):
        # Shared checks: for instances that order like the integers passed
        # to the constructor, every rich comparison (defined, derived by
        # total_ordering, or inherited) must agree.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_orders_one_before_two(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_orders_one_before_two(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_orders_one_before_two(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_orders_one_before_two(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # total_ordering must not replace comparison methods the class
        # already has (here, all inherited from int).
        @functools.total_ordering
        class A(int):
            pass
        self._assert_orders_one_before_two(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Comparisons between unrelated types must end in TypeError once
        # both sides have returned NotImplemented.
        mismatched = [
            ("LT < 1", lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE", lambda: (ImplementsLessThan(0)
                                 < ImplementsLessThanEqualTo(0))),
            ("LT < GT", lambda: (ImplementsLessThan(1)
                                 < ImplementsGreaterThan(1))),
            ("LE <= LT", lambda: (ImplementsLessThanEqualTo(2)
                                  <= ImplementsLessThan(2))),
            ("LE <= GE", lambda: (ImplementsLessThanEqualTo(3)
                                  <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE", lambda: (ImplementsGreaterThan(4)
                                 > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT", lambda: (ImplementsGreaterThan(5)
                                 > ImplementsLessThan(5))),
            ("GE >= GT", lambda: (ImplementsGreaterThanEqualTo(6)
                                  >= ImplementsGreaterThan(6))),
            ("GE >= LE", lambda: (ImplementsGreaterThanEqualTo(7)
                                  >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even for operands that compare equal, >= and <= must not be
        # satisfied through __eq__ alone when __lt__ returns NotImplemented.
        equal_cases = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # The user-defined __lt__ and the comparisons derived from it must
        # round-trip through pickle to the very same function objects.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(restored, method)
1152
@functools.total_ordering
class Orderable_LT:
    """Minimal total_ordering class that defines only __lt__ and __eq__.

    Kept at module level because TestTotalOrdering.test_pickle pickles its
    comparison methods (presumably so pickle can locate them by name).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, that):
        return self.value == that.value

    def __lt__(self, that):
        return self.value < that.value
1161
1162
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001163class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001164
1165 def test_lru(self):
1166 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001167 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001168 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001169 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001170 self.assertEqual(maxsize, 20)
1171 self.assertEqual(currsize, 0)
1172 self.assertEqual(hits, 0)
1173 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001174
1175 domain = range(5)
1176 for i in range(1000):
1177 x, y = choice(domain), choice(domain)
1178 actual = f(x, y)
1179 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001180 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001181 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001182 self.assertTrue(hits > misses)
1183 self.assertEqual(hits + misses, 1000)
1184 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001185
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001186 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001187 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001188 self.assertEqual(hits, 0)
1189 self.assertEqual(misses, 0)
1190 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001191 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001192 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(hits, 0)
1194 self.assertEqual(misses, 1)
1195 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001196
Nick Coghlan98876832010-08-17 06:17:18 +00001197 # Test bypassing the cache
1198 self.assertIs(f.__wrapped__, orig)
1199 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001200 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001201 self.assertEqual(hits, 0)
1202 self.assertEqual(misses, 1)
1203 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001204
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001206 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001207 def f():
1208 nonlocal f_cnt
1209 f_cnt += 1
1210 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001211 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001212 f_cnt = 0
1213 for i in range(5):
1214 self.assertEqual(f(), 20)
1215 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001216 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001217 self.assertEqual(hits, 0)
1218 self.assertEqual(misses, 5)
1219 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001220
1221 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001222 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001223 def f():
1224 nonlocal f_cnt
1225 f_cnt += 1
1226 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001227 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001228 f_cnt = 0
1229 for i in range(5):
1230 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001231 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001232 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001233 self.assertEqual(hits, 4)
1234 self.assertEqual(misses, 1)
1235 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001236
Raymond Hettingerf3098282010-08-15 03:30:45 +00001237 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001238 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001239 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001240 nonlocal f_cnt
1241 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001242 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001243 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001244 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001245 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1246 # * * * *
1247 self.assertEqual(f(x), x*10)
1248 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001249 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001250 self.assertEqual(hits, 12)
1251 self.assertEqual(misses, 4)
1252 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001253
Raymond Hettingerb8218682019-05-26 11:27:35 -07001254 def test_lru_no_args(self):
1255 @self.module.lru_cache
1256 def square(x):
1257 return x ** 2
1258
1259 self.assertEqual(list(map(square, [10, 20, 10])),
1260 [100, 400, 100])
1261 self.assertEqual(square.cache_info().hits, 1)
1262 self.assertEqual(square.cache_info().misses, 2)
1263 self.assertEqual(square.cache_info().maxsize, 128)
1264 self.assertEqual(square.cache_info().currsize, 2)
1265
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001266 def test_lru_bug_35780(self):
1267 # C version of the lru_cache was not checking to see if
1268 # the user function call has already modified the cache
1269 # (this arises in recursive calls and in multi-threading).
1270 # This cause the cache to have orphan links not referenced
1271 # by the cache dictionary.
1272
1273 once = True # Modified by f(x) below
1274
1275 @self.module.lru_cache(maxsize=10)
1276 def f(x):
1277 nonlocal once
1278 rv = f'.{x}.'
1279 if x == 20 and once:
1280 once = False
1281 rv = f(x)
1282 return rv
1283
1284 # Fill the cache
1285 for x in range(15):
1286 self.assertEqual(f(x), f'.{x}.')
1287 self.assertEqual(f.cache_info().currsize, 10)
1288
1289 # Make a recursive call and make sure the cache remains full
1290 self.assertEqual(f(20), '.20.')
1291 self.assertEqual(f.cache_info().currsize, 10)
1292
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001293 def test_lru_bug_36650(self):
1294 # C version of lru_cache was treating a call with an empty **kwargs
1295 # dictionary as being distinct from a call with no keywords at all.
1296 # This did not result in an incorrect answer, but it did trigger
1297 # an unexpected cache miss.
1298
1299 @self.module.lru_cache()
1300 def f(x):
1301 pass
1302
1303 f(0)
1304 f(0, **{})
1305 self.assertEqual(f.cache_info().hits, 1)
1306
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001307 def test_lru_hash_only_once(self):
1308 # To protect against weird reentrancy bugs and to improve
1309 # efficiency when faced with slow __hash__ methods, the
1310 # LRU cache guarantees that it will only call __hash__
1311 # only once per use as an argument to the cached function.
1312
1313 @self.module.lru_cache(maxsize=1)
1314 def f(x, y):
1315 return x * 3 + y
1316
1317 # Simulate the integer 5
1318 mock_int = unittest.mock.Mock()
1319 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1320 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1321
1322 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001323 self.assertEqual(f(mock_int, 1), 16)
1324 self.assertEqual(mock_int.__hash__.call_count, 1)
1325 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001326
1327 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001328 self.assertEqual(f(mock_int, 1), 16)
1329 self.assertEqual(mock_int.__hash__.call_count, 2)
1330 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001331
Ville Skyttä49b27342017-08-03 09:00:59 +03001332 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001333 self.assertEqual(f(6, 2), 20)
1334 self.assertEqual(mock_int.__hash__.call_count, 2)
1335 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001336
1337 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001338 self.assertEqual(f(mock_int, 1), 16)
1339 self.assertEqual(mock_int.__hash__.call_count, 3)
1340 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001341
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001342 def test_lru_reentrancy_with_len(self):
1343 # Test to make sure the LRU cache code isn't thrown-off by
1344 # caching the built-in len() function. Since len() can be
1345 # cached, we shouldn't use it inside the lru code itself.
1346 old_len = builtins.len
1347 try:
1348 builtins.len = self.module.lru_cache(4)(len)
1349 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1350 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1351 finally:
1352 builtins.len = old_len
1353
Raymond Hettinger605a4472017-01-09 07:50:19 -08001354 def test_lru_star_arg_handling(self):
1355 # Test regression that arose in ea064ff3c10f
1356 @functools.lru_cache()
1357 def f(*args):
1358 return args
1359
1360 self.assertEqual(f(1, 2), (1, 2))
1361 self.assertEqual(f((1, 2)), ((1, 2),))
1362
Yury Selivanov46a02db2016-11-09 18:55:45 -05001363 def test_lru_type_error(self):
1364 # Regression test for issue #28653.
1365 # lru_cache was leaking when one of the arguments
1366 # wasn't cacheable.
1367
1368 @functools.lru_cache(maxsize=None)
1369 def infinite_cache(o):
1370 pass
1371
1372 @functools.lru_cache(maxsize=10)
1373 def limited_cache(o):
1374 pass
1375
1376 with self.assertRaises(TypeError):
1377 infinite_cache([])
1378
1379 with self.assertRaises(TypeError):
1380 limited_cache([])
1381
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001382 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001383 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001384 def fib(n):
1385 if n < 2:
1386 return n
1387 return fib(n-1) + fib(n-2)
1388 self.assertEqual([fib(n) for n in range(16)],
1389 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1390 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001391 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001392 fib.cache_clear()
1393 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001394 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1395
def test_lru_with_maxsize_negative(self):
    # A negative maxsize is reported as 0 and nothing is ever cached:
    # two passes over 150 keys give 300 misses and no hits.
    @self.module.lru_cache(maxsize=-10)
    def eq(n):
        return n

    numbers = list(range(150))
    for _ in range(2):
        self.assertEqual([eq(k) for k in numbers], numbers)
        self.assertEqual(
            eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001404
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001405 def test_lru_with_exceptions(self):
1406 # Verify that user_function exceptions get passed through without
1407 # creating a hard-to-read chained exception.
1408 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001409 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001410 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001411 def func(i):
1412 return 'abc'[i]
1413 self.assertEqual(func(0), 'a')
1414 with self.assertRaises(IndexError) as cm:
1415 func(15)
1416 self.assertIsNone(cm.exception.__context__)
1417 # Verify that the previous exception did not result in a cached entry
1418 with self.assertRaises(IndexError):
1419 func(15)
1420
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001421 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001422 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001423 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001424 def square(x):
1425 return x * x
1426 self.assertEqual(square(3), 9)
1427 self.assertEqual(type(square(3)), type(9))
1428 self.assertEqual(square(3.0), 9.0)
1429 self.assertEqual(type(square(3.0)), type(9.0))
1430 self.assertEqual(square(x=3), 9)
1431 self.assertEqual(type(square(x=3)), type(9))
1432 self.assertEqual(square(x=3.0), 9.0)
1433 self.assertEqual(type(square(x=3.0)), type(9.0))
1434 self.assertEqual(square.cache_info().hits, 4)
1435 self.assertEqual(square.cache_info().misses, 4)
1436
Antoine Pitroub5b37142012-11-13 21:35:40 +01001437 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001438 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001439 def fib(n):
1440 if n < 2:
1441 return n
1442 return fib(n=n-1) + fib(n=n-2)
1443 self.assertEqual(
1444 [fib(n=number) for number in range(16)],
1445 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1446 )
1447 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001448 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001449 fib.cache_clear()
1450 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001451 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001452
1453 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001454 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001455 def fib(n):
1456 if n < 2:
1457 return n
1458 return fib(n=n-1) + fib(n=n-2)
1459 self.assertEqual([fib(n=number) for number in range(16)],
1460 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1461 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001462 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001463 fib.cache_clear()
1464 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001465 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1466
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001467 def test_kwargs_order(self):
1468 # PEP 468: Preserving Keyword Argument Order
1469 @self.module.lru_cache(maxsize=10)
1470 def f(**kwargs):
1471 return list(kwargs.items())
1472 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1473 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1474 self.assertEqual(f.cache_info(),
1475 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1476
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001477 def test_lru_cache_decoration(self):
1478 def f(zomg: 'zomg_annotation'):
1479 """f doc string"""
1480 return 42
1481 g = self.module.lru_cache()(f)
1482 for attr in self.module.WRAPPER_ASSIGNMENTS:
1483 self.assertEqual(getattr(g, attr), getattr(f, attr))
1484
def test_lru_cache_threaded(self):
    # Hammer one cache from several threads.
    # Phase 1: n threads each fill their own key; the hit/miss counters
    # are then checked.
    # Phase 2: the same fill runs concurrently with a thread that
    # repeatedly clears the cache.
    # n threads, one distinct key per thread; m calls per thread.
    n, m = 5, 11
    def orig(x, y):
        return 3 * x + y
    # maxsize n*m is well above the n distinct keys, so nothing is evicted.
    f = self.module.lru_cache(maxsize=n*m)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # All workers block on this event so they start together.
    start = threading.Event()
    def full(k):
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    # A tiny switch interval forces very frequent thread switches,
    # which maximises interleaving inside the cache code.
    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with support.start_threads(threads):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: Why can these counts be unequal? For the pure-Python
            # implementation only upper bounds are asserted; the reason
            # was never documented.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        # (no counters are checked here; only the per-call results
        # asserted inside full()).
        # NOTE(review): assertion failures raised inside worker threads
        # may not propagate to the test result -- confirm how
        # support.start_threads reports them.
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with support.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001532
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments.
    # Every round, all n workers call f(i) at once. f blocks on `pause`
    # until the main thread joins, so all n calls are inside the user
    # function concurrently. The expected cache_info shows that each of
    # them counts as a miss (no hits) while only one entry per round is
    # kept.
    n, m = 5, 7
    # Each barrier is shared by the n workers plus the main thread.
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with support.start_threads(threads):
        for i in range(m):
            # Each barrier is reset at a point where no worker can be
            # waiting on it. Workers can only reach `stop` after `pause`
            # trips, and `start` again only after `stop` trips.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # hits, misses, maxsize, currsize after round i
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1559
def test_lru_cache_threaded3(self):
    # A very small cache (maxsize=2) shared by several threads, some of
    # which request the same key, must still return correct values.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        # Presumably keeps each call in flight so the threads overlap.
        time.sleep(.01)
        return 3 * x

    def worker(index, value):
        with self.subTest(thread=index):
            self.assertEqual(f(value), 3 * value, index)

    workload = [1, 2, 2, 3, 2]
    threads = [threading.Thread(target=worker, args=(index, value))
               for index, value in enumerate(workload)]
    with support.start_threads(threads):
        pass
1572
Raymond Hettinger03923422013-03-04 02:52:50 -05001573 def test_need_for_rlock(self):
1574 # This will deadlock on an LRU cache that uses a regular lock
1575
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001576 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001577 def test_func(x):
1578 'Used to demonstrate a reentrant lru_cache call within a single thread'
1579 return x
1580
1581 class DoubleEq:
1582 'Demonstrate a reentrant lru_cache call within a single thread'
1583 def __init__(self, x):
1584 self.x = x
1585 def __hash__(self):
1586 return self.x
1587 def __eq__(self, other):
1588 if self.x == 2:
1589 test_func(DoubleEq(1))
1590 return self.x == other.x
1591
1592 test_func(DoubleEq(1)) # Load the cache
1593 test_func(DoubleEq(2)) # Load the cache
1594 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1595 DoubleEq(2)) # Verify the correct return value
1596
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001597 def test_lru_method(self):
1598 class X(int):
1599 f_cnt = 0
1600 @self.module.lru_cache(2)
1601 def f(self, x):
1602 self.f_cnt += 1
1603 return x*10+self
1604 a = X(5)
1605 b = X(5)
1606 c = X(7)
1607 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1608
1609 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1610 self.assertEqual(a.f(x), x*10 + 5)
1611 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1612 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1613
1614 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1615 self.assertEqual(b.f(x), x*10 + 5)
1616 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1617 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1618
1619 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1620 self.assertEqual(c.f(x), x*10 + 7)
1621 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1622 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1623
1624 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1625 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1626 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1627
def test_pickle(self):
    # Cached functions, methods and static methods pickle by reference:
    # for every protocol, unpickling gives back the identical object.
    owner = type(self)
    targets = (owner.cached_func[0], owner.cached_meth,
               owner.cached_staticmeth)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for func in targets:
        for proto in protocols:
            with self.subTest(proto=proto, func=func):
                restored = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(restored, func)
1635
1636 def test_copy(self):
1637 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001638 def orig(x, y):
1639 return 3 * x + y
1640 part = self.module.partial(orig, 2)
1641 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1642 self.module.lru_cache(2)(part))
1643 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001644 with self.subTest(func=f):
1645 f_copy = copy.copy(f)
1646 self.assertIs(f_copy, f)
1647
1648 def test_deepcopy(self):
1649 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001650 def orig(x, y):
1651 return 3 * x + y
1652 part = self.module.partial(orig, 2)
1653 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1654 self.module.lru_cache(2)(part))
1655 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001656 with self.subTest(func=f):
1657 f_copy = copy.deepcopy(f)
1658 self.assertIs(f_copy, f)
1659
1660
# Module-level function cached with the pure-Python lru_cache; the
# pickle/copy tests expect it to round-trip to the very same object.
@py_functools.lru_cache()
def py_cached_func(x, y):
    result = 3 * x + y
    return result
1664
# Module-level function cached with the C-accelerated lru_cache; the
# pickle/copy tests expect it to round-trip to the very same object.
@c_functools.lru_cache()
def c_cached_func(x, y):
    result = 3 * x + y
    return result
1668
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001669
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the TestLRU suite against functools imported without the
    # _functools accelerator (the pure-Python implementation).
    module = py_functools
    # Kept in a 1-tuple and read back as cached_func[0], presumably so the
    # cached function is never bound as a method via the descriptor
    # protocol.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1682
1683
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the TestLRU suite against functools freshly imported with the
    # _functools C accelerator.
    module = c_functools
    # Kept in a 1-tuple and read back as cached_func[0], presumably so the
    # cached function is never bound as a method via the descriptor
    # protocol.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001696
Raymond Hettinger03923422013-03-04 02:52:50 -05001697
Łukasz Langa6f692512013-06-05 12:20:24 +02001698class TestSingleDispatch(unittest.TestCase):
1699 def test_simple_overloads(self):
1700 @functools.singledispatch
1701 def g(obj):
1702 return "base"
1703 def g_int(i):
1704 return "integer"
1705 g.register(int, g_int)
1706 self.assertEqual(g("str"), "base")
1707 self.assertEqual(g(1), "integer")
1708 self.assertEqual(g([1,2,3]), "base")
1709
1710 def test_mro(self):
1711 @functools.singledispatch
1712 def g(obj):
1713 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001714 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001715 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001716 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001717 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001718 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001719 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001720 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001721 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001722 def g_A(a):
1723 return "A"
1724 def g_B(b):
1725 return "B"
1726 g.register(A, g_A)
1727 g.register(B, g_B)
1728 self.assertEqual(g(A()), "A")
1729 self.assertEqual(g(B()), "B")
1730 self.assertEqual(g(C()), "A")
1731 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001732
1733 def test_register_decorator(self):
1734 @functools.singledispatch
1735 def g(obj):
1736 return "base"
1737 @g.register(int)
1738 def g_int(i):
1739 return "int %s" % (i,)
1740 self.assertEqual(g(""), "base")
1741 self.assertEqual(g(12), "int 12")
1742 self.assertIs(g.dispatch(int), g_int)
1743 self.assertIs(g.dispatch(object), g.dispatch(str))
1744 # Note: in the assert above this is not g.
1745 # @singledispatch returns the wrapper.
1746
1747 def test_wrapping_attributes(self):
1748 @functools.singledispatch
1749 def g(obj):
1750 "Simple test"
1751 return "Test"
1752 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001753 if sys.flags.optimize < 2:
1754 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001755
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes from the _decimal module: a later,
    # more specific registration for Subnormal only affects Subnormal.
    @functools.singledispatch
    def g(obj):
        return "base"

    g.register(decimal.DecimalException, lambda obj: obj.args)
    subn = decimal.Subnormal("Exponent < Emin")
    rnd = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subn), ("Exponent < Emin",))
    self.assertEqual(g(rnd), ("Number got rounded",))

    g.register(decimal.Subnormal, lambda obj: "Too small to care.")
    self.assertEqual(g(subn), "Too small to care.")
    self.assertEqual(g(rnd), ("Number got rounded",))
1774
1775 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001776 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001777 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001778 mro = functools._compose_mro
1779 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1780 for haystack in permutations(bases):
1781 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001782 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1783 c.Collection, c.Sized, c.Iterable,
1784 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001785 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001786 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001787 m = mro(collections.ChainMap, haystack)
1788 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001789 c.Collection, c.Sized, c.Iterable,
1790 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001791
1792 # If there's a generic function with implementations registered for
1793 # both Sized and Container, passing a defaultdict to it results in an
1794 # ambiguous dispatch which will cause a RuntimeError (see
1795 # test_mro_conflicts).
1796 bases = [c.Container, c.Sized, str]
1797 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001798 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1799 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1800 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001801
1802 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001803 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001804 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001805 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001806 pass
1807 c.MutableSequence.register(D)
1808 bases = [c.MutableSequence, c.MutableMapping]
1809 for haystack in permutations(bases):
1810 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001811 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001812 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001813 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001814 object])
1815
1816 # Container and Callable are registered on different base classes and
1817 # a generic function supporting both should always pick the Callable
1818 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001819 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001820 def __call__(self):
1821 pass
1822 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1823 for haystack in permutations(bases):
1824 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001825 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001826 c.Collection, c.Sized, c.Iterable,
1827 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001828
1829 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001830 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001831 d = {"a": "b"}
1832 l = [1, 2, 3]
1833 s = {object(), None}
1834 f = frozenset(s)
1835 t = (1, 2, 3)
1836 @functools.singledispatch
1837 def g(obj):
1838 return "base"
1839 self.assertEqual(g(d), "base")
1840 self.assertEqual(g(l), "base")
1841 self.assertEqual(g(s), "base")
1842 self.assertEqual(g(f), "base")
1843 self.assertEqual(g(t), "base")
1844 g.register(c.Sized, lambda obj: "sized")
1845 self.assertEqual(g(d), "sized")
1846 self.assertEqual(g(l), "sized")
1847 self.assertEqual(g(s), "sized")
1848 self.assertEqual(g(f), "sized")
1849 self.assertEqual(g(t), "sized")
1850 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1851 self.assertEqual(g(d), "mutablemapping")
1852 self.assertEqual(g(l), "sized")
1853 self.assertEqual(g(s), "sized")
1854 self.assertEqual(g(f), "sized")
1855 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001856 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001857 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1858 self.assertEqual(g(l), "sized")
1859 self.assertEqual(g(s), "sized")
1860 self.assertEqual(g(f), "sized")
1861 self.assertEqual(g(t), "sized")
1862 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1863 self.assertEqual(g(d), "mutablemapping")
1864 self.assertEqual(g(l), "mutablesequence")
1865 self.assertEqual(g(s), "sized")
1866 self.assertEqual(g(f), "sized")
1867 self.assertEqual(g(t), "sized")
1868 g.register(c.MutableSet, lambda obj: "mutableset")
1869 self.assertEqual(g(d), "mutablemapping")
1870 self.assertEqual(g(l), "mutablesequence")
1871 self.assertEqual(g(s), "mutableset")
1872 self.assertEqual(g(f), "sized")
1873 self.assertEqual(g(t), "sized")
1874 g.register(c.Mapping, lambda obj: "mapping")
1875 self.assertEqual(g(d), "mutablemapping") # not specific enough
1876 self.assertEqual(g(l), "mutablesequence")
1877 self.assertEqual(g(s), "mutableset")
1878 self.assertEqual(g(f), "sized")
1879 self.assertEqual(g(t), "sized")
1880 g.register(c.Sequence, lambda obj: "sequence")
1881 self.assertEqual(g(d), "mutablemapping")
1882 self.assertEqual(g(l), "mutablesequence")
1883 self.assertEqual(g(s), "mutableset")
1884 self.assertEqual(g(f), "sized")
1885 self.assertEqual(g(t), "sequence")
1886 g.register(c.Set, lambda obj: "set")
1887 self.assertEqual(g(d), "mutablemapping")
1888 self.assertEqual(g(l), "mutablesequence")
1889 self.assertEqual(g(s), "mutableset")
1890 self.assertEqual(g(f), "set")
1891 self.assertEqual(g(t), "sequence")
1892 g.register(dict, lambda obj: "dict")
1893 self.assertEqual(g(d), "dict")
1894 self.assertEqual(g(l), "mutablesequence")
1895 self.assertEqual(g(s), "mutableset")
1896 self.assertEqual(g(f), "set")
1897 self.assertEqual(g(t), "sequence")
1898 g.register(list, lambda obj: "list")
1899 self.assertEqual(g(d), "dict")
1900 self.assertEqual(g(l), "list")
1901 self.assertEqual(g(s), "mutableset")
1902 self.assertEqual(g(f), "set")
1903 self.assertEqual(g(t), "sequence")
1904 g.register(set, lambda obj: "concrete-set")
1905 self.assertEqual(g(d), "dict")
1906 self.assertEqual(g(l), "list")
1907 self.assertEqual(g(s), "concrete-set")
1908 self.assertEqual(g(f), "set")
1909 self.assertEqual(g(t), "sequence")
1910 g.register(frozenset, lambda obj: "frozen-set")
1911 self.assertEqual(g(d), "dict")
1912 self.assertEqual(g(l), "list")
1913 self.assertEqual(g(s), "concrete-set")
1914 self.assertEqual(g(f), "frozen-set")
1915 self.assertEqual(g(t), "sequence")
1916 g.register(tuple, lambda obj: "tuple")
1917 self.assertEqual(g(d), "dict")
1918 self.assertEqual(g(l), "list")
1919 self.assertEqual(g(s), "concrete-set")
1920 self.assertEqual(g(f), "frozen-set")
1921 self.assertEqual(g(t), "tuple")
1922
Łukasz Langa3720c772013-07-01 16:00:38 +02001923 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001924 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001925 mro = functools._c3_mro
1926 class A(object):
1927 pass
1928 class B(A):
1929 def __len__(self):
1930 return 0 # implies Sized
1931 @c.Container.register
1932 class C(object):
1933 pass
1934 class D(object):
1935 pass # unrelated
1936 class X(D, C, B):
1937 def __call__(self):
1938 pass # implies Callable
1939 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1940 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1941 self.assertEqual(mro(X, abcs=abcs), expected)
1942 # unrelated ABCs don't appear in the resulting MRO
1943 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1944 self.assertEqual(mro(X, abcs=many_abcs), expected)
1945
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001946 def test_false_meta(self):
1947 # see issue23572
1948 class MetaA(type):
1949 def __len__(self):
1950 return 0
1951 class A(metaclass=MetaA):
1952 pass
1953 class AA(A):
1954 pass
1955 @functools.singledispatch
1956 def fun(a):
1957 return 'base A'
1958 @fun.register(A)
1959 def _(a):
1960 return 'fun A'
1961 aa = AA()
1962 self.assertEqual(fun(aa), 'fun A')
1963
def test_mro_conflicts(self):
    # How singledispatch resolves ABCs that apply to a class explicitly
    # (in __mro__), by registration, or implicitly (e.g. via __len__).
    # When no candidate is more specific it raises RuntimeError
    # ("Ambiguous dispatch").
    # NOTE: the c.*.register() calls below mutate the global
    # collections.abc registries; they only target classes local to this
    # test, and the assertions depend on their exact order.
    c = collections.abc
    @functools.singledispatch
    def g(arg):
        return "base"
    class O(c.Sized):
        def __len__(self):
            return 0
    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")  # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")  # see above: Sized is in __mro__
    c.Set.register(O)
    self.assertEqual(g(o), "set")  # because c.Set is a subclass of
                                   # c.Sized and c.Container
    # P gets Iterable and Container only by registration; neither is more
    # specific than the other, so dispatch becomes ambiguous.
    class P:
        pass
    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    # Either ordering of the two ABCs in the message is accepted.
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )
    class Q(c.Sized):
        def __len__(self):
            return 0
    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")  # because it's explicitly in __mro__
    c.Set.register(Q)
    self.assertEqual(g(q), "set")  # because c.Set is a subclass of
                                   # c.Sized and c.Iterable
    @functools.singledispatch
    def h(arg):
        return "base"
    @h.register(c.Sized)
    def _(arg):
        return "sized"
    @h.register(c.Container)
    def _(arg):
        return "container"
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(collections.defaultdict(lambda: 0))
    self.assertIn(
        str(re_two.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # An ABC registered directly on R (MutableSequence) wins over one R
    # only gets through its base defaultdict (MutableMapping).
    class R(collections.defaultdict):
        pass
    c.MutableSequence.register(R)
    @functools.singledispatch
    def i(arg):
        return "base"
    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"
    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"
    r = R()
    self.assertEqual(i(r), "sequence")
    class S:
        pass
    class T(S, c.Sized):
        def __len__(self):
            return 0
    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")  # because it's explicitly in the MRO
    class U:
        def __len__(self):
            return 0
    u = U()
    self.assertEqual(h(u), "sized")  # implicit Sized subclass inferred
                                     # from the existence of __len__()
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(
        str(re_three.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # Registering Container on V places it right after Sized in V's
    # composed MRO, i.e. ahead of the concrete base S.
    class V(c.Sized, S):
        def __len__(self):
            return 0
    @functools.singledispatch
    def j(arg):
        return "base"
    @j.register(S)
    def _(arg):
        return "s"
    @j.register(c.Container)
    def _(arg):
        return "container"
    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    self.assertEqual(j(v), "container")  # because it ends up right after
                                         # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002091
2092 def test_cache_invalidation(self):
2093 from collections import UserDict
INADA Naoki9811e802017-09-30 16:13:02 +09002094 import weakref
2095
Łukasz Langa6f692512013-06-05 12:20:24 +02002096 class TracingDict(UserDict):
2097 def __init__(self, *args, **kwargs):
2098 super(TracingDict, self).__init__(*args, **kwargs)
2099 self.set_ops = []
2100 self.get_ops = []
2101 def __getitem__(self, key):
2102 result = self.data[key]
2103 self.get_ops.append(key)
2104 return result
2105 def __setitem__(self, key, value):
2106 self.set_ops.append(key)
2107 self.data[key] = value
2108 def clear(self):
2109 self.data.clear()
INADA Naoki9811e802017-09-30 16:13:02 +09002110
Łukasz Langa6f692512013-06-05 12:20:24 +02002111 td = TracingDict()
INADA Naoki9811e802017-09-30 16:13:02 +09002112 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
2113 c = collections.abc
2114 @functools.singledispatch
2115 def g(arg):
2116 return "base"
2117 d = {}
2118 l = []
2119 self.assertEqual(len(td), 0)
2120 self.assertEqual(g(d), "base")
2121 self.assertEqual(len(td), 1)
2122 self.assertEqual(td.get_ops, [])
2123 self.assertEqual(td.set_ops, [dict])
2124 self.assertEqual(td.data[dict], g.registry[object])
2125 self.assertEqual(g(l), "base")
2126 self.assertEqual(len(td), 2)
2127 self.assertEqual(td.get_ops, [])
2128 self.assertEqual(td.set_ops, [dict, list])
2129 self.assertEqual(td.data[dict], g.registry[object])
2130 self.assertEqual(td.data[list], g.registry[object])
2131 self.assertEqual(td.data[dict], td.data[list])
2132 self.assertEqual(g(l), "base")
2133 self.assertEqual(g(d), "base")
2134 self.assertEqual(td.get_ops, [list, dict])
2135 self.assertEqual(td.set_ops, [dict, list])
2136 g.register(list, lambda arg: "list")
2137 self.assertEqual(td.get_ops, [list, dict])
2138 self.assertEqual(len(td), 0)
2139 self.assertEqual(g(d), "base")
2140 self.assertEqual(len(td), 1)
2141 self.assertEqual(td.get_ops, [list, dict])
2142 self.assertEqual(td.set_ops, [dict, list, dict])
2143 self.assertEqual(td.data[dict],
2144 functools._find_impl(dict, g.registry))
2145 self.assertEqual(g(l), "list")
2146 self.assertEqual(len(td), 2)
2147 self.assertEqual(td.get_ops, [list, dict])
2148 self.assertEqual(td.set_ops, [dict, list, dict, list])
2149 self.assertEqual(td.data[list],
2150 functools._find_impl(list, g.registry))
2151 class X:
2152 pass
2153 c.MutableMapping.register(X) # Will not invalidate the cache,
2154 # not using ABCs yet.
2155 self.assertEqual(g(d), "base")
2156 self.assertEqual(g(l), "list")
2157 self.assertEqual(td.get_ops, [list, dict, dict, list])
2158 self.assertEqual(td.set_ops, [dict, list, dict, list])
2159 g.register(c.Sized, lambda arg: "sized")
2160 self.assertEqual(len(td), 0)
2161 self.assertEqual(g(d), "sized")
2162 self.assertEqual(len(td), 1)
2163 self.assertEqual(td.get_ops, [list, dict, dict, list])
2164 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2165 self.assertEqual(g(l), "list")
2166 self.assertEqual(len(td), 2)
2167 self.assertEqual(td.get_ops, [list, dict, dict, list])
2168 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2169 self.assertEqual(g(l), "list")
2170 self.assertEqual(g(d), "sized")
2171 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2172 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2173 g.dispatch(list)
2174 g.dispatch(dict)
2175 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2176 list, dict])
2177 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2178 c.MutableSet.register(X) # Will invalidate the cache.
2179 self.assertEqual(len(td), 2) # Stale cache.
2180 self.assertEqual(g(l), "list")
2181 self.assertEqual(len(td), 1)
2182 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2183 self.assertEqual(len(td), 0)
2184 self.assertEqual(g(d), "mutablemapping")
2185 self.assertEqual(len(td), 1)
2186 self.assertEqual(g(l), "list")
2187 self.assertEqual(len(td), 2)
2188 g.register(dict, lambda arg: "dict")
2189 self.assertEqual(g(d), "dict")
2190 self.assertEqual(g(l), "list")
2191 g._clear_cache()
2192 self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002193
Łukasz Langae5697532017-12-11 13:56:31 -08002194 def test_annotations(self):
2195 @functools.singledispatch
2196 def i(arg):
2197 return "base"
2198 @i.register
2199 def _(arg: collections.abc.Mapping):
2200 return "mapping"
2201 @i.register
2202 def _(arg: "collections.abc.Sequence"):
2203 return "sequence"
2204 self.assertEqual(i(None), "base")
2205 self.assertEqual(i({"a": 1}), "mapping")
2206 self.assertEqual(i([1, 2, 3]), "sequence")
2207 self.assertEqual(i((1, 2, 3)), "sequence")
2208 self.assertEqual(i("str"), "sequence")
2209
2210 # Registering classes as callables doesn't work with annotations,
2211 # you need to pass the type explicitly.
2212 @i.register(str)
2213 class _:
2214 def __init__(self, arg):
2215 self.arg = arg
2216
2217 def __eq__(self, other):
2218 return self.arg == other
2219 self.assertEqual(i("str"), "str")
2220
Ethan Smithc6512752018-05-26 16:38:33 -04002221 def test_method_register(self):
2222 class A:
2223 @functools.singledispatchmethod
2224 def t(self, arg):
2225 self.arg = "base"
2226 @t.register(int)
2227 def _(self, arg):
2228 self.arg = "int"
2229 @t.register(str)
2230 def _(self, arg):
2231 self.arg = "str"
2232 a = A()
2233
2234 a.t(0)
2235 self.assertEqual(a.arg, "int")
2236 aa = A()
2237 self.assertFalse(hasattr(aa, 'arg'))
2238 a.t('')
2239 self.assertEqual(a.arg, "str")
2240 aa = A()
2241 self.assertFalse(hasattr(aa, 'arg'))
2242 a.t(0.0)
2243 self.assertEqual(a.arg, "base")
2244 aa = A()
2245 self.assertFalse(hasattr(aa, 'arg'))
2246
2247 def test_staticmethod_register(self):
2248 class A:
2249 @functools.singledispatchmethod
2250 @staticmethod
2251 def t(arg):
2252 return arg
2253 @t.register(int)
2254 @staticmethod
2255 def _(arg):
2256 return isinstance(arg, int)
2257 @t.register(str)
2258 @staticmethod
2259 def _(arg):
2260 return isinstance(arg, str)
2261 a = A()
2262
2263 self.assertTrue(A.t(0))
2264 self.assertTrue(A.t(''))
2265 self.assertEqual(A.t(0.0), 0.0)
2266
2267 def test_classmethod_register(self):
2268 class A:
2269 def __init__(self, arg):
2270 self.arg = arg
2271
2272 @functools.singledispatchmethod
2273 @classmethod
2274 def t(cls, arg):
2275 return cls("base")
2276 @t.register(int)
2277 @classmethod
2278 def _(cls, arg):
2279 return cls("int")
2280 @t.register(str)
2281 @classmethod
2282 def _(cls, arg):
2283 return cls("str")
2284
2285 self.assertEqual(A.t(0).arg, "int")
2286 self.assertEqual(A.t('').arg, "str")
2287 self.assertEqual(A.t(0.0).arg, "base")
2288
2289 def test_callable_register(self):
2290 class A:
2291 def __init__(self, arg):
2292 self.arg = arg
2293
2294 @functools.singledispatchmethod
2295 @classmethod
2296 def t(cls, arg):
2297 return cls("base")
2298
2299 @A.t.register(int)
2300 @classmethod
2301 def _(cls, arg):
2302 return cls("int")
2303 @A.t.register(str)
2304 @classmethod
2305 def _(cls, arg):
2306 return cls("str")
2307
2308 self.assertEqual(A.t(0).arg, "int")
2309 self.assertEqual(A.t('').arg, "str")
2310 self.assertEqual(A.t(0.0).arg, "base")
2311
2312 def test_abstractmethod_register(self):
2313 class Abstract(abc.ABCMeta):
2314
2315 @functools.singledispatchmethod
2316 @abc.abstractmethod
2317 def add(self, x, y):
2318 pass
2319
2320 self.assertTrue(Abstract.add.__isabstractmethod__)
2321
2322 def test_type_ann_register(self):
2323 class A:
2324 @functools.singledispatchmethod
2325 def t(self, arg):
2326 return "base"
2327 @t.register
2328 def _(self, arg: int):
2329 return "int"
2330 @t.register
2331 def _(self, arg: str):
2332 return "str"
2333 a = A()
2334
2335 self.assertEqual(a.t(0), "int")
2336 self.assertEqual(a.t(''), "str")
2337 self.assertEqual(a.t(0.0), "base")
2338
Łukasz Langae5697532017-12-11 13:56:31 -08002339 def test_invalid_registrations(self):
2340 msg_prefix = "Invalid first argument to `register()`: "
2341 msg_suffix = (
2342 ". Use either `@register(some_class)` or plain `@register` on an "
2343 "annotated function."
2344 )
2345 @functools.singledispatch
2346 def i(arg):
2347 return "base"
2348 with self.assertRaises(TypeError) as exc:
2349 @i.register(42)
2350 def _(arg):
2351 return "I annotated with a non-type"
2352 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2353 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2354 with self.assertRaises(TypeError) as exc:
2355 @i.register
2356 def _(arg):
2357 return "I forgot to annotate"
2358 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2359 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2360 ))
2361 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2362
Łukasz Langae5697532017-12-11 13:56:31 -08002363 with self.assertRaises(TypeError) as exc:
2364 @i.register
2365 def _(arg: typing.Iterable[str]):
2366 # At runtime, dispatching on generics is impossible.
2367 # When registering implementations with singledispatch, avoid
2368 # types from `typing`. Instead, annotate with regular types
2369 # or ABCs.
2370 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002371 self.assertTrue(str(exc.exception).startswith(
2372 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002373 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002374 self.assertTrue(str(exc.exception).endswith(
2375 'typing.Iterable[str] is not a class.'
2376 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002377
Dong-hee Na445f1b32018-07-10 16:26:36 +09002378 def test_invalid_positional_argument(self):
2379 @functools.singledispatch
2380 def f(*args):
2381 pass
2382 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002383 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002384 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002385
Carl Meyerd658dea2018-08-28 01:11:56 -06002386
class CachedCostItem:
    # Class-level starting value; the first computation stores an
    # incremented copy on the instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2398 return self._cost
2399
2400
class OptionallyCachedCostItem:
    # Each get_cost() call bumps this; cached_cost freezes the first
    # value it observes.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # Same function exposed under a different attribute name, this time
    # cached on first access.
    cached_cost = py_functools.cached_property(get_cost)
2409 cached_cost = py_functools.cached_property(get_cost)
2410
2411
class CachedCostItemWait:
    # Variant whose cached computation blocks on ``event`` (for up to one
    # second) before incrementing, so concurrent first accesses can pile up.

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2424 return self._cost
2425
2426
class CachedCostItemWithSlots:
    # Slotted class (no per-instance __dict__) used to check that
    # cached_property refuses to work without somewhere to cache.

    # A one-element tuple: ``('_cost')`` is merely the string '_cost'.
    # Both declare the same single slot, but the tuple is the unambiguous,
    # conventional spelling and stays correct if more slots are added.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Not expected to run: the accompanying test expects a TypeError
        # about the missing '__dict__' instead of this RuntimeError.
        raise RuntimeError('never called, slots not supported')
2435 raise RuntimeError('never called, slots not supported')
2436
2437
class TestCachedProperty(unittest.TestCase):
    def test_cached(self):
        obj = CachedCostItem()
        self.assertEqual(obj.cost, 2)
        self.assertEqual(obj.cost, 2)  # served from the cache, not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        obj = OptionallyCachedCostItem()
        # get_cost() recomputes every call; cached_cost keeps its first value.
        self.assertEqual(obj.get_cost(), 2)
        self.assertEqual(obj.cached_cost, 3)
        self.assertEqual(obj.get_cost(), 4)
        self.assertEqual(obj.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        # A tiny switch interval makes the threads interleave aggressively.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Only one thread may have run the computation.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = ("No '__dict__' attribute on 'CachedCostItemWithSlots' "
                    "instance to cache 'cost' property.")
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # A class's __dict__ is a read-only mappingproxy.
        expected = ("The '__dict__' attribute on 'MyMeta' instance does not "
                    "support item assignment for caching 'prop' property.")
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = str(TypeError(
            "Cannot assign the same cached_property to two different names "
            "('a' and 'b')."))
        self.assertEqual(str(ctx.exception.__context__), expected)

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attaching after class creation skips __set_name__.
        Foo.cp = cp

        expected = ("Cannot use cached_property instance without calling "
                    "__set_name__ on it.")
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
2549 self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2550
2551
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed directly.
    unittest.main()