blob: c300270d49e5ec4f2600e5684840b2eaecac704e [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Two independent copies of functools so that every test can run against both
# implementations: one with the _functools accelerator blocked (pure Python),
# and one imported with a freshly loaded _functools.  c_functools may be None
# when the accelerator is unavailable; the C test classes guard on that.
py_functools = support.import_fresh_module('functools', blocked=('_functools',))
c_functools = support.import_fresh_module('functools', fresh=('_functools',))

# Fresh decimal with its _decimal accelerator (not used in the visible part
# of this file -- presumably referenced by tests further down).
decimal = support.import_fresh_module('decimal', fresh=('_decimal',))
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previous entry is restored when the ``with`` block exits, whether it
    exits normally or by raising.  *name* must already be present in
    ``sys.modules`` (a KeyError is raised otherwise).
    """
    modules = sys.modules
    saved = modules[name]
    modules[name] = replacement
    try:
        yield
    finally:
        # Always put the original back so a failing test cannot leak it.
        modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Return the call's arguments unchanged as an ``(args, kw)`` pair.

    *args* is the tuple of positional arguments and *kw* the dict of keyword
    arguments, exactly as received.
    """
    received = (args, kw)
    return received
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Summarize a partial object as ``(func, args, keywords, __dict__)``.

    Handy for comparing two partial objects field by field, e.g. an original
    against its copy or its pickled round-trip.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass; tests check partial normalizes it to tuple."""
46
class BadTuple(tuple):
    """tuple subclass whose ``+`` operator returns a list, not a tuple.

    Used to check that partial's stored args stay a real tuple even when the
    state it is given overrides concatenation.
    """

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
50
class MyDict(dict):
    """Plain dict subclass; tests check partial normalizes it to dict."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Tests shared by every partial implementation.

    This is a mixin rather than a TestCase: concrete subclasses mix it with
    ``unittest.TestCase`` and provide ``partial`` (the implementation under
    test) and ``AllowPickle`` (a context-manager class wrapped around the
    pickling tests).
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The first argument must be callable.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertEqual calls so a failure shows the actual values.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case builds a self-referencing partial and breaks the cycle in
        # ``finally`` so the object can be collected.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # A None keywords entry is normalized to an empty dict, just as a None
        # namespace becomes an empty __dict__ (checked above).
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args must survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through keywords must survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000379@unittest.skipUnless(c_functools, 'requires the C _functools module')
380class TestPartialC(TestPartial, unittest.TestCase):
381 if c_functools:
382 partial = c_functools.partial
383
384 class AllowPickle:
385 def __enter__(self):
386 return self
387 def __exit__(self, type, value, tb):
388 return False
389
390 def test_attributes_unwritable(self):
391 # attributes should not be writable
392 p = self.partial(capture, 1, 2, a=10, b=20)
393 self.assertRaises(AttributeError, setattr, p, 'func', map)
394 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
395 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
396
397 p = self.partial(hex)
398 try:
399 del p.__dict__
400 except TypeError:
401 pass
402 else:
403 self.fail('partial object allowed __dict__ to be deleted')
Łukasz Langa6f692512013-06-05 12:20:24 +0200404
Michael Seifert6c3d5272017-03-15 06:26:33 +0100405 def test_manually_adding_non_string_keyword(self):
406 p = self.partial(capture)
407 # Adding a non-string/unicode keyword to partial kwargs
408 p.keywords[1234] = 'value'
409 r = repr(p)
410 self.assertIn('1234', r)
411 self.assertIn("'value'", r)
412 with self.assertRaises(TypeError):
413 p()
414
415 def test_keystr_replaces_value(self):
416 p = self.partial(capture)
417
418 class MutatesYourDict(object):
419 def __str__(self):
420 p.keywords[self] = ['sth2']
421 return 'astr'
422
Mike53f7a7c2017-12-14 14:04:53 +0300423 # Replacing the value during key formatting should keep the original
Michael Seifert6c3d5272017-03-15 06:26:33 +0100424 # value alive (at least long enough).
425 p.keywords[MutatesYourDict()] = ['sth']
426 r = repr(p)
427 self.assertIn('astr', r)
428 self.assertIn("['sth']", r)
429
430
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200431class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000432 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000433
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000434 class AllowPickle:
435 def __init__(self):
436 self._cm = replaced_module("functools", py_functools)
437 def __enter__(self):
438 return self._cm.__enter__()
439 def __exit__(self, type, value, tb):
440 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200442if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000443 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200444 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000446class PyPartialSubclass(py_functools.partial):
447 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200449@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200450class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200451 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000452 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000453
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300454 # partial subclasses are not optimized for nested calls
455 test_nested_optimization = None
456
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000457class TestPartialPySubclass(TestPartialPy):
458 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000460class TestPartialMethod(unittest.TestCase):
461
462 class A(object):
463 nothing = functools.partialmethod(capture)
464 positional = functools.partialmethod(capture, 1)
465 keywords = functools.partialmethod(capture, a=2)
466 both = functools.partialmethod(capture, 3, b=4)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300467 spec_keywords = functools.partialmethod(capture, self=1, func=2)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000468
469 nested = functools.partialmethod(positional, 5)
470
471 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
472
473 static = functools.partialmethod(staticmethod(capture), 8)
474 cls = functools.partialmethod(classmethod(capture), d=9)
475
476 a = A()
477
478 def test_arg_combinations(self):
479 self.assertEqual(self.a.nothing(), ((self.a,), {}))
480 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
481 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
482 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
483
484 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
485 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
486 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
487 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
488
489 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
490 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
491 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
492 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
493
494 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
495 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
496 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
497 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
498
499 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
500
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300501 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))
502
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000503 def test_nested(self):
504 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
505 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
506 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
507 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
508
509 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
510
511 def test_over_partial(self):
512 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
513 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
514 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
515 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
516
517 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
518
519 def test_bound_method_introspection(self):
520 obj = self.a
521 self.assertIs(obj.both.__self__, obj)
522 self.assertIs(obj.nested.__self__, obj)
523 self.assertIs(obj.over_partial.__self__, obj)
524 self.assertIs(obj.cls.__self__, self.A)
525 self.assertIs(self.A.cls.__self__, self.A)
526
527 def test_unbound_method_retrieval(self):
528 obj = self.A
529 self.assertFalse(hasattr(obj.both, "__self__"))
530 self.assertFalse(hasattr(obj.nested, "__self__"))
531 self.assertFalse(hasattr(obj.over_partial, "__self__"))
532 self.assertFalse(hasattr(obj.static, "__self__"))
533 self.assertFalse(hasattr(self.a.static, "__self__"))
534
535 def test_descriptors(self):
536 for obj in [self.A, self.a]:
537 with self.subTest(obj=obj):
538 self.assertEqual(obj.static(), ((8,), {}))
539 self.assertEqual(obj.static(5), ((8, 5), {}))
540 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
541 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
542
543 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
544 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
545 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
546 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
547
548 def test_overriding_keywords(self):
549 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
550 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
551
552 def test_invalid_args(self):
553 with self.assertRaises(TypeError):
554 class B(object):
555 method = functools.partialmethod(None, 1)
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300556 with self.assertRaises(TypeError):
557 class B:
558 method = functools.partialmethod()
Serhiy Storchaka142566c2019-06-05 18:22:31 +0300559 with self.assertRaises(TypeError):
Serhiy Storchaka42a139e2019-04-01 09:16:35 +0300560 class B:
561 method = functools.partialmethod(func=capture, a=1)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000562
563 def test_repr(self):
564 self.assertEqual(repr(vars(self.A)['both']),
565 'functools.partialmethod({}, 3, b=4)'.format(capture))
566
567 def test_abstract(self):
568 class Abstract(abc.ABCMeta):
569
570 @abc.abstractmethod
571 def add(self, x, y):
572 pass
573
574 add5 = functools.partialmethod(add, 5)
575
576 self.assertTrue(Abstract.add.__isabstractmethod__)
577 self.assertTrue(Abstract.add5.__isabstractmethod__)
578
579 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
580 self.assertFalse(getattr(func, '__isabstractmethod__', False))
581
Pablo Galindo8c77b8c2019-04-29 13:36:57 +0100582 def test_positional_only(self):
583 def f(a, b, /):
584 return a + b
585
586 p = functools.partial(f, 1)
587 self.assertEqual(p(2), f(1, 2))
588
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000589
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    TestWraps subclasses this and overrides ``_default_update``, presumably
    so the same checks also cover functools.wraps.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both.  Every mapping named in *updated* must hold all of *wrapped*'s
        entries; the ``__wrapped__`` key of ``__dict__`` is the exception,
        because update_wrapper overwrites it.  Finally, ``wrapper.__wrapped__``
        must be *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            wrapper_map = getattr(wrapper, attr_name)
            wrapped_map = getattr(wrapped, attr_name)
            for key in wrapped_map:
                overwritten = attr_name == "__dict__" and key == "__wrapped__"
                if not overwritten:
                    self.assertIs(wrapped_map[key], wrapper_map[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build ``(wrapper, f)`` with update_wrapper's default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # update_wrapper must replace this with a reference to f itself.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing_assigned = nothing_updated = ()
        functools.update_wrapper(wrapper, f, nothing_assigned, nothing_updated)
        self.check_wrapper(wrapper, f, nothing_assigned, nothing_updated)
        # With empty tuples, only __wrapped__ is set; everything else stays.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: missing ...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ... or present but lacking an ``update`` method.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: update_wrapper applied to the builtin max.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Antoine Pitrou560f7642010-08-04 18:28:02 +0000701 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000702
Łukasz Langa6f692512013-06-05 12:20:24 +0200703
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper() scenarios through the functools.wraps() decorator."""

    def _default_update(self):
        # A wrapped function with a docstring, a custom attribute and a
        # bogus __wrapped__ value, decorated with the default wraps().
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        decorated, original = self._default_update()
        self.check_wrapper(decorated, original)
        self.assertEqual(decorated.__name__, 'f')
        self.assertEqual(decorated.__qualname__, original.__qualname__)
        self.assertEqual(decorated.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        decorated = self._default_update()[0]
        self.assertEqual(decorated.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, assigned=(), updated=())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # wraps() merges into dict_attr, so the wrapper needs one first.
            func.dict_attr = {}
            return func

        assigned = ('attr',)
        updated = ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
764
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000765
madman-bobe25d5fc2018-10-25 15:02:10 +0100766class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000767 def test_reduce(self):
768 class Squares:
769 def __init__(self, max):
770 self.max = max
771 self.sofar = []
772
773 def __len__(self):
774 return len(self.sofar)
775
776 def __getitem__(self, i):
777 if not 0 <= i < self.max: raise IndexError
778 n = len(self.sofar)
779 while n <= i:
780 self.sofar.append(n*n)
781 n += 1
782 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000783 def add(x, y):
784 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100785 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000786 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100787 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000788 ['a','c','d','w']
789 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100790 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000791 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100792 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000793 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.assertEqual(self.reduce(add, Squares(10)), 285)
796 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
797 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
798 self.assertRaises(TypeError, self.reduce)
799 self.assertRaises(TypeError, self.reduce, 42, 42)
800 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
801 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
802 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
803 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
804 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
805 self.assertRaises(TypeError, self.reduce, add, "")
806 self.assertRaises(TypeError, self.reduce, add, ())
807 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000808
809 class TestFailingIter:
810 def __iter__(self):
811 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100812 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000813
madman-bobe25d5fc2018-10-25 15:02:10 +0100814 self.assertEqual(self.reduce(add, [], None), None)
815 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000816
817 class BadSeq:
818 def __getitem__(self, index):
819 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000821
822 # Test reduce()'s use of iterators.
823 def test_iterator_usage(self):
824 class SequenceClass:
825 def __init__(self, n):
826 self.n = n
827 def __getitem__(self, i):
828 if 0 <= i < self.n:
829 return i
830 else:
831 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000832
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000833 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100834 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
835 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
836 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
837 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
838 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
839 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000840
841 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100842 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
843
844
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Runs the shared TestReduce checks against the C accelerator's reduce().
    # The guard keeps the class body importable when c_functools is falsy;
    # the skipUnless decorator then skips the whole class.
    # NOTE(review): stored without staticmethod(), unlike TestReducePy --
    # presumably fine because a C builtin does not bind as a method.
    if c_functools:
        reduce = c_functools.reduce
849
850
class TestReducePy(TestReduce, unittest.TestCase):
    # Runs the shared TestReduce checks against the pure-Python reduce().
    # staticmethod() keeps the plain Python function from being bound to
    # the test instance when it is accessed as self.reduce.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000853
Łukasz Langa6f692512013-06-05 12:20:24 +0200854
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200855class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700856
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000857 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700858 def cmp1(x, y):
859 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100860 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861 self.assertEqual(key(3), key(3))
862 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100863 self.assertGreaterEqual(key(3), key(3))
864
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700865 def cmp2(x, y):
866 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100867 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700868 self.assertEqual(key(4.0), key('4'))
869 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100870 self.assertLessEqual(key(2), key('35'))
871 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700872
873 def test_cmp_to_key_arguments(self):
874 def cmp1(x, y):
875 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100876 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877 self.assertEqual(key(obj=3), key(obj=3))
878 self.assertGreater(key(obj=3), key(obj=1))
879 with self.assertRaises((TypeError, AttributeError)):
880 key(3) > 1 # rhs is not a K object
881 with self.assertRaises((TypeError, AttributeError)):
882 1 < key(3) # lhs is not a K object
883 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200886 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100887 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700888 with self.assertRaises(TypeError):
889 key() # too few args
890 with self.assertRaises(TypeError):
891 key(None, None) # too many args
892
893 def test_bad_cmp(self):
894 def cmp1(x, y):
895 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100896 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700897 with self.assertRaises(ZeroDivisionError):
898 key(3) > key(1)
899
900 class BadCmp:
901 def __lt__(self, other):
902 raise ZeroDivisionError
903 def cmp1(x, y):
904 return BadCmp()
905 with self.assertRaises(ZeroDivisionError):
906 key(3) > key(1)
907
908 def test_obj_field(self):
909 def cmp1(x, y):
910 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100911 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700912 self.assertEqual(key(50).obj, 50)
913
914 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000915 def mycmp(x, y):
916 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100917 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000918 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000919
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700920 def test_sort_int_str(self):
921 def mycmp(x, y):
922 x, y = int(x), int(y)
923 return (x > y) - (x < y)
924 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100925 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700926 self.assertEqual([int(value) for value in values],
927 [0, 1, 1, 2, 3, 4, 5, 7, 10])
928
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000929 def test_hash(self):
930 def mycmp(x, y):
931 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100932 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000933 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700934 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300935 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000936
Łukasz Langa6f692512013-06-05 12:20:24 +0200937
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared TestCmpToKey checks against the C cmp_to_key().
    # The guard keeps the class body importable when c_functools is falsy;
    # the skipUnless decorator then skips the whole class.
    # NOTE(review): no staticmethod() wrapper here, unlike TestCmpToKeyPy --
    # presumably because a C builtin does not bind as a method.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100942
Łukasz Langa6f692512013-06-05 12:20:24 +0200943
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the shared TestCmpToKey checks against the pure-Python version.
    # staticmethod() keeps the plain function from being bound to the test
    # instance when it is accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
946
Łukasz Langa6f692512013-06-05 12:20:24 +0200947
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000948class TestTotalOrdering(unittest.TestCase):
949
950 def test_total_ordering_lt(self):
951 @functools.total_ordering
952 class A:
953 def __init__(self, value):
954 self.value = value
955 def __lt__(self, other):
956 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000957 def __eq__(self, other):
958 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000959 self.assertTrue(A(1) < A(2))
960 self.assertTrue(A(2) > A(1))
961 self.assertTrue(A(1) <= A(2))
962 self.assertTrue(A(2) >= A(1))
963 self.assertTrue(A(2) <= A(2))
964 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000965 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000966
967 def test_total_ordering_le(self):
968 @functools.total_ordering
969 class A:
970 def __init__(self, value):
971 self.value = value
972 def __le__(self, other):
973 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000974 def __eq__(self, other):
975 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000976 self.assertTrue(A(1) < A(2))
977 self.assertTrue(A(2) > A(1))
978 self.assertTrue(A(1) <= A(2))
979 self.assertTrue(A(2) >= A(1))
980 self.assertTrue(A(2) <= A(2))
981 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000982 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000983
984 def test_total_ordering_gt(self):
985 @functools.total_ordering
986 class A:
987 def __init__(self, value):
988 self.value = value
989 def __gt__(self, other):
990 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000991 def __eq__(self, other):
992 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000993 self.assertTrue(A(1) < A(2))
994 self.assertTrue(A(2) > A(1))
995 self.assertTrue(A(1) <= A(2))
996 self.assertTrue(A(2) >= A(1))
997 self.assertTrue(A(2) <= A(2))
998 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000999 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001000
1001 def test_total_ordering_ge(self):
1002 @functools.total_ordering
1003 class A:
1004 def __init__(self, value):
1005 self.value = value
1006 def __ge__(self, other):
1007 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001008 def __eq__(self, other):
1009 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001010 self.assertTrue(A(1) < A(2))
1011 self.assertTrue(A(2) > A(1))
1012 self.assertTrue(A(1) <= A(2))
1013 self.assertTrue(A(2) >= A(1))
1014 self.assertTrue(A(2) <= A(2))
1015 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001016 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001017
1018 def test_total_ordering_no_overwrite(self):
1019 # new methods should not overwrite existing
1020 @functools.total_ordering
1021 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001022 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001023 self.assertTrue(A(1) < A(2))
1024 self.assertTrue(A(2) > A(1))
1025 self.assertTrue(A(1) <= A(2))
1026 self.assertTrue(A(2) >= A(1))
1027 self.assertTrue(A(2) <= A(2))
1028 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001029
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001030 def test_no_operations_defined(self):
1031 with self.assertRaises(ValueError):
1032 @functools.total_ordering
1033 class A:
1034 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001035
Nick Coghlanf05d9812013-10-02 00:02:03 +10001036 def test_type_error_when_not_implemented(self):
1037 # bug 10042; ensure stack overflow does not occur
1038 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001039 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001040 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001041 def __init__(self, value):
1042 self.value = value
1043 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001044 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001045 return self.value == other.value
1046 return False
1047 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001048 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001049 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001050 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001051
Nick Coghlanf05d9812013-10-02 00:02:03 +10001052 @functools.total_ordering
1053 class ImplementsGreaterThan:
1054 def __init__(self, value):
1055 self.value = value
1056 def __eq__(self, other):
1057 if isinstance(other, ImplementsGreaterThan):
1058 return self.value == other.value
1059 return False
1060 def __gt__(self, other):
1061 if isinstance(other, ImplementsGreaterThan):
1062 return self.value > other.value
1063 return NotImplemented
1064
1065 @functools.total_ordering
1066 class ImplementsLessThanEqualTo:
1067 def __init__(self, value):
1068 self.value = value
1069 def __eq__(self, other):
1070 if isinstance(other, ImplementsLessThanEqualTo):
1071 return self.value == other.value
1072 return False
1073 def __le__(self, other):
1074 if isinstance(other, ImplementsLessThanEqualTo):
1075 return self.value <= other.value
1076 return NotImplemented
1077
1078 @functools.total_ordering
1079 class ImplementsGreaterThanEqualTo:
1080 def __init__(self, value):
1081 self.value = value
1082 def __eq__(self, other):
1083 if isinstance(other, ImplementsGreaterThanEqualTo):
1084 return self.value == other.value
1085 return False
1086 def __ge__(self, other):
1087 if isinstance(other, ImplementsGreaterThanEqualTo):
1088 return self.value >= other.value
1089 return NotImplemented
1090
1091 @functools.total_ordering
1092 class ComparatorNotImplemented:
1093 def __init__(self, value):
1094 self.value = value
1095 def __eq__(self, other):
1096 if isinstance(other, ComparatorNotImplemented):
1097 return self.value == other.value
1098 return False
1099 def __lt__(self, other):
1100 return NotImplemented
1101
1102 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1103 ImplementsLessThan(-1) < 1
1104
1105 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1106 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1107
1108 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1109 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1110
1111 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1112 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1113
1114 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1115 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1116
1117 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1118 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1119
1120 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1121 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1122
1123 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1124 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1125
1126 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1127 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1128
1129 with self.subTest("GE when equal"):
1130 a = ComparatorNotImplemented(8)
1131 b = ComparatorNotImplemented(8)
1132 self.assertEqual(a, b)
1133 with self.assertRaises(TypeError):
1134 a >= b
1135
1136 with self.subTest("LE when equal"):
1137 a = ComparatorNotImplemented(9)
1138 b = ComparatorNotImplemented(9)
1139 self.assertEqual(a, b)
1140 with self.assertRaises(TypeError):
1141 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001142
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001143 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001144 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001145 for name in '__lt__', '__gt__', '__le__', '__ge__':
1146 with self.subTest(method=name, proto=proto):
1147 method = getattr(Orderable_LT, name)
1148 method_copy = pickle.loads(pickle.dumps(method, proto))
1149 self.assertIs(method_copy, method)
1150
@functools.total_ordering
class Orderable_LT:
    """Module-level total_ordering class whose root comparison is ``__lt__``.

    It lives at module scope so that TestTotalOrdering.test_pickle can
    round-trip its comparison methods through pickle.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1159
1160
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001161class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001162
1163 def test_lru(self):
1164 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001165 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001166 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001167 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001168 self.assertEqual(maxsize, 20)
1169 self.assertEqual(currsize, 0)
1170 self.assertEqual(hits, 0)
1171 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001172
1173 domain = range(5)
1174 for i in range(1000):
1175 x, y = choice(domain), choice(domain)
1176 actual = f(x, y)
1177 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001178 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001179 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001180 self.assertTrue(hits > misses)
1181 self.assertEqual(hits + misses, 1000)
1182 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001183
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001184 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001185 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001186 self.assertEqual(hits, 0)
1187 self.assertEqual(misses, 0)
1188 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001189 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001190 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001191 self.assertEqual(hits, 0)
1192 self.assertEqual(misses, 1)
1193 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001194
Nick Coghlan98876832010-08-17 06:17:18 +00001195 # Test bypassing the cache
1196 self.assertIs(f.__wrapped__, orig)
1197 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001198 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001199 self.assertEqual(hits, 0)
1200 self.assertEqual(misses, 1)
1201 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001202
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001203 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001204 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 def f():
1206 nonlocal f_cnt
1207 f_cnt += 1
1208 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001209 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001210 f_cnt = 0
1211 for i in range(5):
1212 self.assertEqual(f(), 20)
1213 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001214 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001215 self.assertEqual(hits, 0)
1216 self.assertEqual(misses, 5)
1217 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001218
1219 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001220 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001221 def f():
1222 nonlocal f_cnt
1223 f_cnt += 1
1224 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001226 f_cnt = 0
1227 for i in range(5):
1228 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001229 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001230 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001231 self.assertEqual(hits, 4)
1232 self.assertEqual(misses, 1)
1233 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001234
Raymond Hettingerf3098282010-08-15 03:30:45 +00001235 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001236 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001237 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001238 nonlocal f_cnt
1239 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001240 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001241 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001242 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001243 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1244 # * * * *
1245 self.assertEqual(f(x), x*10)
1246 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001247 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001248 self.assertEqual(hits, 12)
1249 self.assertEqual(misses, 4)
1250 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001251
Raymond Hettingerb8218682019-05-26 11:27:35 -07001252 def test_lru_no_args(self):
1253 @self.module.lru_cache
1254 def square(x):
1255 return x ** 2
1256
1257 self.assertEqual(list(map(square, [10, 20, 10])),
1258 [100, 400, 100])
1259 self.assertEqual(square.cache_info().hits, 1)
1260 self.assertEqual(square.cache_info().misses, 2)
1261 self.assertEqual(square.cache_info().maxsize, 128)
1262 self.assertEqual(square.cache_info().currsize, 2)
1263
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001264 def test_lru_bug_35780(self):
1265 # C version of the lru_cache was not checking to see if
1266 # the user function call has already modified the cache
1267 # (this arises in recursive calls and in multi-threading).
1268 # This cause the cache to have orphan links not referenced
1269 # by the cache dictionary.
1270
1271 once = True # Modified by f(x) below
1272
1273 @self.module.lru_cache(maxsize=10)
1274 def f(x):
1275 nonlocal once
1276 rv = f'.{x}.'
1277 if x == 20 and once:
1278 once = False
1279 rv = f(x)
1280 return rv
1281
1282 # Fill the cache
1283 for x in range(15):
1284 self.assertEqual(f(x), f'.{x}.')
1285 self.assertEqual(f.cache_info().currsize, 10)
1286
1287 # Make a recursive call and make sure the cache remains full
1288 self.assertEqual(f(20), '.20.')
1289 self.assertEqual(f.cache_info().currsize, 10)
1290
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001291 def test_lru_bug_36650(self):
1292 # C version of lru_cache was treating a call with an empty **kwargs
1293 # dictionary as being distinct from a call with no keywords at all.
1294 # This did not result in an incorrect answer, but it did trigger
1295 # an unexpected cache miss.
1296
1297 @self.module.lru_cache()
1298 def f(x):
1299 pass
1300
1301 f(0)
1302 f(0, **{})
1303 self.assertEqual(f.cache_info().hits, 1)
1304
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001305 def test_lru_hash_only_once(self):
1306 # To protect against weird reentrancy bugs and to improve
1307 # efficiency when faced with slow __hash__ methods, the
1308 # LRU cache guarantees that it will only call __hash__
1309 # only once per use as an argument to the cached function.
1310
1311 @self.module.lru_cache(maxsize=1)
1312 def f(x, y):
1313 return x * 3 + y
1314
1315 # Simulate the integer 5
1316 mock_int = unittest.mock.Mock()
1317 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1318 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1319
1320 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001321 self.assertEqual(f(mock_int, 1), 16)
1322 self.assertEqual(mock_int.__hash__.call_count, 1)
1323 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001324
1325 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001326 self.assertEqual(f(mock_int, 1), 16)
1327 self.assertEqual(mock_int.__hash__.call_count, 2)
1328 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001329
Ville Skyttä49b27342017-08-03 09:00:59 +03001330 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001331 self.assertEqual(f(6, 2), 20)
1332 self.assertEqual(mock_int.__hash__.call_count, 2)
1333 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001334
1335 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001336 self.assertEqual(f(mock_int, 1), 16)
1337 self.assertEqual(mock_int.__hash__.call_count, 3)
1338 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001339
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001340 def test_lru_reentrancy_with_len(self):
1341 # Test to make sure the LRU cache code isn't thrown-off by
1342 # caching the built-in len() function. Since len() can be
1343 # cached, we shouldn't use it inside the lru code itself.
1344 old_len = builtins.len
1345 try:
1346 builtins.len = self.module.lru_cache(4)(len)
1347 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1348 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1349 finally:
1350 builtins.len = old_len
1351
Raymond Hettinger605a4472017-01-09 07:50:19 -08001352 def test_lru_star_arg_handling(self):
1353 # Test regression that arose in ea064ff3c10f
1354 @functools.lru_cache()
1355 def f(*args):
1356 return args
1357
1358 self.assertEqual(f(1, 2), (1, 2))
1359 self.assertEqual(f((1, 2)), ((1, 2),))
1360
Yury Selivanov46a02db2016-11-09 18:55:45 -05001361 def test_lru_type_error(self):
1362 # Regression test for issue #28653.
1363 # lru_cache was leaking when one of the arguments
1364 # wasn't cacheable.
1365
1366 @functools.lru_cache(maxsize=None)
1367 def infinite_cache(o):
1368 pass
1369
1370 @functools.lru_cache(maxsize=10)
1371 def limited_cache(o):
1372 pass
1373
1374 with self.assertRaises(TypeError):
1375 infinite_cache([])
1376
1377 with self.assertRaises(TypeError):
1378 limited_cache([])
1379
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001380 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001381 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001382 def fib(n):
1383 if n < 2:
1384 return n
1385 return fib(n-1) + fib(n-2)
1386 self.assertEqual([fib(n) for n in range(16)],
1387 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1388 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001389 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001390 fib.cache_clear()
1391 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001392 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1393
1394 def test_lru_with_maxsize_negative(self):
1395 @self.module.lru_cache(maxsize=-10)
1396 def eq(n):
1397 return n
1398 for i in (0, 1):
1399 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1400 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001401 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001402
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001403 def test_lru_with_exceptions(self):
1404 # Verify that user_function exceptions get passed through without
1405 # creating a hard-to-read chained exception.
1406 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001407 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001408 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001409 def func(i):
1410 return 'abc'[i]
1411 self.assertEqual(func(0), 'a')
1412 with self.assertRaises(IndexError) as cm:
1413 func(15)
1414 self.assertIsNone(cm.exception.__context__)
1415 # Verify that the previous exception did not result in a cached entry
1416 with self.assertRaises(IndexError):
1417 func(15)
1418
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001419 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001420 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001421 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001422 def square(x):
1423 return x * x
1424 self.assertEqual(square(3), 9)
1425 self.assertEqual(type(square(3)), type(9))
1426 self.assertEqual(square(3.0), 9.0)
1427 self.assertEqual(type(square(3.0)), type(9.0))
1428 self.assertEqual(square(x=3), 9)
1429 self.assertEqual(type(square(x=3)), type(9))
1430 self.assertEqual(square(x=3.0), 9.0)
1431 self.assertEqual(type(square(x=3.0)), type(9.0))
1432 self.assertEqual(square.cache_info().hits, 4)
1433 self.assertEqual(square.cache_info().misses, 4)
1434
Antoine Pitroub5b37142012-11-13 21:35:40 +01001435 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001437 def fib(n):
1438 if n < 2:
1439 return n
1440 return fib(n=n-1) + fib(n=n-2)
1441 self.assertEqual(
1442 [fib(n=number) for number in range(16)],
1443 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1444 )
1445 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001446 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001447 fib.cache_clear()
1448 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001449 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001450
1451 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001452 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001453 def fib(n):
1454 if n < 2:
1455 return n
1456 return fib(n=n-1) + fib(n=n-2)
1457 self.assertEqual([fib(n=number) for number in range(16)],
1458 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1459 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001460 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001461 fib.cache_clear()
1462 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001463 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1464
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001465 def test_kwargs_order(self):
1466 # PEP 468: Preserving Keyword Argument Order
1467 @self.module.lru_cache(maxsize=10)
1468 def f(**kwargs):
1469 return list(kwargs.items())
1470 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1471 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1472 self.assertEqual(f.cache_info(),
1473 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1474
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001475 def test_lru_cache_decoration(self):
1476 def f(zomg: 'zomg_annotation'):
1477 """f doc string"""
1478 return 42
1479 g = self.module.lru_cache()(f)
1480 for attr in self.module.WRAPPER_ASSIGNMENTS:
1481 self.assertEqual(getattr(g, attr), getattr(f, attr))
1482
    def test_lru_cache_threaded(self):
        # Stress lru_cache from several threads: first n threads that only
        # read/fill the cache, then the same n threads racing against one
        # thread that repeatedly calls cache_clear().
        # n threads, each calling f m times with its own argument.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # Cache is large enough (n*m) that nothing is evicted by size.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All worker threads block on this event so they start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # Use a tiny switch interval so the interpreter switches threads
        # very often; the original value is restored in the finally clause.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure-Python implementation the counts are not
                # always exactly equal to the ideal values; the reason is
                # not known, so only upper bounds are checked here.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # Each thread misses once on its own key, then hits m-1 times.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001530
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # n worker threads plus the main thread synchronize on three
        # barriers; in each of m rounds all workers call f(i) at once.
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Every call blocks here until the main thread also reaches the
            # barrier, so no call can finish before all n have started.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n concurrent calls for round i were misses (none could
                # see a finished result), hence 0 hits, (i+1)*n misses, and
                # one cached entry per completed round.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1557
Serhiy Storchaka67796522017-01-12 18:34:33 +02001558 def test_lru_cache_threaded3(self):
1559 @self.module.lru_cache(maxsize=2)
1560 def f(x):
1561 time.sleep(.01)
1562 return 3 * x
1563 def test(i, x):
1564 with self.subTest(thread=i):
1565 self.assertEqual(f(x), 3 * x, i)
1566 threads = [threading.Thread(target=test, args=(i, v))
1567 for i, v in enumerate([1, 2, 2, 3, 2])]
1568 with support.start_threads(threads):
1569 pass
1570
Raymond Hettinger03923422013-03-04 02:52:50 -05001571 def test_need_for_rlock(self):
1572 # This will deadlock on an LRU cache that uses a regular lock
1573
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001574 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001575 def test_func(x):
1576 'Used to demonstrate a reentrant lru_cache call within a single thread'
1577 return x
1578
1579 class DoubleEq:
1580 'Demonstrate a reentrant lru_cache call within a single thread'
1581 def __init__(self, x):
1582 self.x = x
1583 def __hash__(self):
1584 return self.x
1585 def __eq__(self, other):
1586 if self.x == 2:
1587 test_func(DoubleEq(1))
1588 return self.x == other.x
1589
1590 test_func(DoubleEq(1)) # Load the cache
1591 test_func(DoubleEq(2)) # Load the cache
1592 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1593 DoubleEq(2)) # Verify the correct return value
1594
    def test_lru_method(self):
        # lru_cache applied to a method: all instances share one cache
        # (cache_info() is identical whether read via X.f or an instance),
        # and the instance is passed as the first argument, so it is part of
        # the cache key (c = X(7) never reuses results computed for X(5)).
        class X(int):
            # Class-level default; ``self.f_cnt += 1`` below creates a
            # per-instance counter of real (uncached) calls.
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                # NOTE: this ``self`` is the X instance, not the TestCase.
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        # The expected counts below depend on this exact call sequence and
        # on the 2-entry LRU eviction order.
        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        # Bound methods report the same shared statistics as the class.
        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1625
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001626 def test_pickle(self):
1627 cls = self.__class__
1628 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1629 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1630 with self.subTest(proto=proto, func=f):
1631 f_copy = pickle.loads(pickle.dumps(f, proto))
1632 self.assertIs(f_copy, f)
1633
1634 def test_copy(self):
1635 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001636 def orig(x, y):
1637 return 3 * x + y
1638 part = self.module.partial(orig, 2)
1639 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1640 self.module.lru_cache(2)(part))
1641 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001642 with self.subTest(func=f):
1643 f_copy = copy.copy(f)
1644 self.assertIs(f_copy, f)
1645
1646 def test_deepcopy(self):
1647 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001648 def orig(x, y):
1649 return 3 * x + y
1650 part = self.module.partial(orig, 2)
1651 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1652 self.module.lru_cache(2)(part))
1653 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001654 with self.subTest(func=f):
1655 f_copy = copy.deepcopy(f)
1656 self.assertIs(f_copy, f)
1657
1658
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level function cached by the pure-Python lru_cache; exposed to
    # TestLRUPy as cached_func for the pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1662
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level function cached by the C lru_cache; exposed to TestLRUC
    # as cached_func for the pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1666
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001667
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the pure-Python functools.
    module = py_functools
    # Held in a 1-tuple; the tests read it back as cached_func[0]
    # (presumably to keep it from being treated as a method - confirm).
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1680
1681
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the C-accelerated functools.
    module = c_functools
    # Held in a 1-tuple; the tests read it back as cached_func[0]
    # (presumably to keep it from being treated as a method - confirm).
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001694
Raymond Hettinger03923422013-03-04 02:52:50 -05001695
Łukasz Langa6f692512013-06-05 12:20:24 +02001696class TestSingleDispatch(unittest.TestCase):
1697 def test_simple_overloads(self):
1698 @functools.singledispatch
1699 def g(obj):
1700 return "base"
1701 def g_int(i):
1702 return "integer"
1703 g.register(int, g_int)
1704 self.assertEqual(g("str"), "base")
1705 self.assertEqual(g(1), "integer")
1706 self.assertEqual(g([1,2,3]), "base")
1707
1708 def test_mro(self):
1709 @functools.singledispatch
1710 def g(obj):
1711 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001712 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001713 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001714 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001715 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001716 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001717 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001718 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001719 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001720 def g_A(a):
1721 return "A"
1722 def g_B(b):
1723 return "B"
1724 g.register(A, g_A)
1725 g.register(B, g_B)
1726 self.assertEqual(g(A()), "A")
1727 self.assertEqual(g(B()), "B")
1728 self.assertEqual(g(C()), "A")
1729 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001730
1731 def test_register_decorator(self):
1732 @functools.singledispatch
1733 def g(obj):
1734 return "base"
1735 @g.register(int)
1736 def g_int(i):
1737 return "int %s" % (i,)
1738 self.assertEqual(g(""), "base")
1739 self.assertEqual(g(12), "int 12")
1740 self.assertIs(g.dispatch(int), g_int)
1741 self.assertIs(g.dispatch(object), g.dispatch(str))
1742 # Note: in the assert above this is not g.
1743 # @singledispatch returns the wrapper.
1744
1745 def test_wrapping_attributes(self):
1746 @functools.singledispatch
1747 def g(obj):
1748 "Simple test"
1749 return "Test"
1750 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001751 if sys.flags.optimize < 2:
1752 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001753
1754 @unittest.skipUnless(decimal, 'requires _decimal')
1755 @support.cpython_only
1756 def test_c_classes(self):
1757 @functools.singledispatch
1758 def g(obj):
1759 return "base"
1760 @g.register(decimal.DecimalException)
1761 def _(obj):
1762 return obj.args
1763 subn = decimal.Subnormal("Exponent < Emin")
1764 rnd = decimal.Rounded("Number got rounded")
1765 self.assertEqual(g(subn), ("Exponent < Emin",))
1766 self.assertEqual(g(rnd), ("Number got rounded",))
1767 @g.register(decimal.Subnormal)
1768 def _(obj):
1769 return "Too small to care."
1770 self.assertEqual(g(subn), "Too small to care.")
1771 self.assertEqual(g(rnd), ("Number got rounded",))
1772
1773 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001774 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001775 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001776 mro = functools._compose_mro
1777 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1778 for haystack in permutations(bases):
1779 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001780 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1781 c.Collection, c.Sized, c.Iterable,
1782 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001783 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001784 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001785 m = mro(collections.ChainMap, haystack)
1786 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001787 c.Collection, c.Sized, c.Iterable,
1788 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001789
1790 # If there's a generic function with implementations registered for
1791 # both Sized and Container, passing a defaultdict to it results in an
1792 # ambiguous dispatch which will cause a RuntimeError (see
1793 # test_mro_conflicts).
1794 bases = [c.Container, c.Sized, str]
1795 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001796 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1797 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1798 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001799
1800 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001801 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001802 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001803 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001804 pass
1805 c.MutableSequence.register(D)
1806 bases = [c.MutableSequence, c.MutableMapping]
1807 for haystack in permutations(bases):
1808 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001809 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001810 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001811 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001812 object])
1813
1814 # Container and Callable are registered on different base classes and
1815 # a generic function supporting both should always pick the Callable
1816 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001817 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001818 def __call__(self):
1819 pass
1820 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1821 for haystack in permutations(bases):
1822 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001823 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001824 c.Collection, c.Sized, c.Iterable,
1825 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001826
1827 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001828 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001829 d = {"a": "b"}
1830 l = [1, 2, 3]
1831 s = {object(), None}
1832 f = frozenset(s)
1833 t = (1, 2, 3)
1834 @functools.singledispatch
1835 def g(obj):
1836 return "base"
1837 self.assertEqual(g(d), "base")
1838 self.assertEqual(g(l), "base")
1839 self.assertEqual(g(s), "base")
1840 self.assertEqual(g(f), "base")
1841 self.assertEqual(g(t), "base")
1842 g.register(c.Sized, lambda obj: "sized")
1843 self.assertEqual(g(d), "sized")
1844 self.assertEqual(g(l), "sized")
1845 self.assertEqual(g(s), "sized")
1846 self.assertEqual(g(f), "sized")
1847 self.assertEqual(g(t), "sized")
1848 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1849 self.assertEqual(g(d), "mutablemapping")
1850 self.assertEqual(g(l), "sized")
1851 self.assertEqual(g(s), "sized")
1852 self.assertEqual(g(f), "sized")
1853 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001854 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001855 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1856 self.assertEqual(g(l), "sized")
1857 self.assertEqual(g(s), "sized")
1858 self.assertEqual(g(f), "sized")
1859 self.assertEqual(g(t), "sized")
1860 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1861 self.assertEqual(g(d), "mutablemapping")
1862 self.assertEqual(g(l), "mutablesequence")
1863 self.assertEqual(g(s), "sized")
1864 self.assertEqual(g(f), "sized")
1865 self.assertEqual(g(t), "sized")
1866 g.register(c.MutableSet, lambda obj: "mutableset")
1867 self.assertEqual(g(d), "mutablemapping")
1868 self.assertEqual(g(l), "mutablesequence")
1869 self.assertEqual(g(s), "mutableset")
1870 self.assertEqual(g(f), "sized")
1871 self.assertEqual(g(t), "sized")
1872 g.register(c.Mapping, lambda obj: "mapping")
1873 self.assertEqual(g(d), "mutablemapping") # not specific enough
1874 self.assertEqual(g(l), "mutablesequence")
1875 self.assertEqual(g(s), "mutableset")
1876 self.assertEqual(g(f), "sized")
1877 self.assertEqual(g(t), "sized")
1878 g.register(c.Sequence, lambda obj: "sequence")
1879 self.assertEqual(g(d), "mutablemapping")
1880 self.assertEqual(g(l), "mutablesequence")
1881 self.assertEqual(g(s), "mutableset")
1882 self.assertEqual(g(f), "sized")
1883 self.assertEqual(g(t), "sequence")
1884 g.register(c.Set, lambda obj: "set")
1885 self.assertEqual(g(d), "mutablemapping")
1886 self.assertEqual(g(l), "mutablesequence")
1887 self.assertEqual(g(s), "mutableset")
1888 self.assertEqual(g(f), "set")
1889 self.assertEqual(g(t), "sequence")
1890 g.register(dict, lambda obj: "dict")
1891 self.assertEqual(g(d), "dict")
1892 self.assertEqual(g(l), "mutablesequence")
1893 self.assertEqual(g(s), "mutableset")
1894 self.assertEqual(g(f), "set")
1895 self.assertEqual(g(t), "sequence")
1896 g.register(list, lambda obj: "list")
1897 self.assertEqual(g(d), "dict")
1898 self.assertEqual(g(l), "list")
1899 self.assertEqual(g(s), "mutableset")
1900 self.assertEqual(g(f), "set")
1901 self.assertEqual(g(t), "sequence")
1902 g.register(set, lambda obj: "concrete-set")
1903 self.assertEqual(g(d), "dict")
1904 self.assertEqual(g(l), "list")
1905 self.assertEqual(g(s), "concrete-set")
1906 self.assertEqual(g(f), "set")
1907 self.assertEqual(g(t), "sequence")
1908 g.register(frozenset, lambda obj: "frozen-set")
1909 self.assertEqual(g(d), "dict")
1910 self.assertEqual(g(l), "list")
1911 self.assertEqual(g(s), "concrete-set")
1912 self.assertEqual(g(f), "frozen-set")
1913 self.assertEqual(g(t), "sequence")
1914 g.register(tuple, lambda obj: "tuple")
1915 self.assertEqual(g(d), "dict")
1916 self.assertEqual(g(l), "list")
1917 self.assertEqual(g(s), "concrete-set")
1918 self.assertEqual(g(f), "frozen-set")
1919 self.assertEqual(g(t), "tuple")
1920
Łukasz Langa3720c772013-07-01 16:00:38 +02001921 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001922 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001923 mro = functools._c3_mro
1924 class A(object):
1925 pass
1926 class B(A):
1927 def __len__(self):
1928 return 0 # implies Sized
1929 @c.Container.register
1930 class C(object):
1931 pass
1932 class D(object):
1933 pass # unrelated
1934 class X(D, C, B):
1935 def __call__(self):
1936 pass # implies Callable
1937 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1938 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1939 self.assertEqual(mro(X, abcs=abcs), expected)
1940 # unrelated ABCs don't appear in the resulting MRO
1941 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1942 self.assertEqual(mro(X, abcs=many_abcs), expected)
1943
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001944 def test_false_meta(self):
1945 # see issue23572
1946 class MetaA(type):
1947 def __len__(self):
1948 return 0
1949 class A(metaclass=MetaA):
1950 pass
1951 class AA(A):
1952 pass
1953 @functools.singledispatch
1954 def fun(a):
1955 return 'base A'
1956 @fun.register(A)
1957 def _(a):
1958 return 'fun A'
1959 aa = AA()
1960 self.assertEqual(fun(aa), 'fun A')
1961
    def test_mro_conflicts(self):
        # Exercise how singledispatch resolves (or refuses to resolve) ABCs
        # that enter a class's MRO explicitly, by registration, or implicitly
        # via special methods. NOTE: each ``c.<ABC>.register(...)`` call below
        # mutates the global ABC registry, so statement order matters.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # Two unrelated registered ABCs with implementations: ambiguous.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # Either order of the two ABCs in the message is accepted.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        # MutableSequence is registered directly on R, so it wins over the
        # MutableMapping that R only gets through defaultdict.
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002089
def test_cache_invalidation(self):
    """Trace when singledispatch populates, reads and clears its dispatch cache.

    The cache mapping used by ``g`` is replaced by a ``TracingDict`` (by
    temporarily swapping ``weakref.WeakKeyDictionary`` for a factory that
    returns ``td``). Every cache write is then logged in ``td.set_ops`` and
    every successful cache read in ``td.get_ops``. The assertions pin down
    the exact sequence of cache misses, hits and invalidations.
    """
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        """UserDict that logs the keys of item reads and writes."""
        def __init__(self, *args, **kwargs):
            super(TracingDict, self).__init__(*args, **kwargs)
            self.set_ops = []  # keys passed to __setitem__, in call order
            self.get_ops = []  # keys successfully read via __getitem__
        def __getitem__(self, key):
            # The lookup happens before logging, so a miss (KeyError) is
            # never recorded in get_ops.
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            # Empties the underlying data without logging anything.
            self.data.clear()

    td = TracingDict()
    # NOTE(review): this presumes singledispatch builds its dispatch cache by
    # calling weakref.WeakKeyDictionary(); the len(td)/get_ops/set_ops
    # checks below confirm that td really is the cache g uses.
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First call for each type is a miss: the resolved implementation is
        # stored (set_ops grows) and no read is logged.
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeat calls with already-seen types are hits: they are logged as
        # reads and cause no new writes.
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # register() empties the cache without reading any entry.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        # The cache refills lazily, one type per call, with the
        # implementation chosen by _find_impl.
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        # Both lookups below are hits (two more reads, no writes).
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        # Registering an ABC implementation clears the cache again. dict now
        # resolves to the Sized implementation; list keeps its exact match.
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() consults the same cache, so its lookups show up as
        # reads too.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        # The stale entries are dropped on the next call; only the newly
        # resolved type is stored.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        # _clear_cache() empties the cache explicitly.
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002191
def test_annotations(self):
    """Implementations can be registered from parameter annotations."""
    @functools.singledispatch
    def dispatcher(arg):
        return "base"

    # Annotation given as an actual ABC object.
    @dispatcher.register
    def _(arg: collections.abc.Mapping):
        return "mapping"

    # Annotation given as a string; it is resolved at registration time.
    @dispatcher.register
    def _(arg: "collections.abc.Sequence"):
        return "sequence"

    expectations = (
        (None, "base"),
        ({"a": 1}, "mapping"),
        ([1, 2, 3], "sequence"),
        ((1, 2, 3), "sequence"),
        ("str", "sequence"),
    )
    for value, expected in expectations:
        self.assertEqual(dispatcher(value), expected)

    # A class can be registered as an implementation too, but annotation
    # lookup does not apply to classes: the type must be passed explicitly.
    @dispatcher.register(str)
    class _:
        def __init__(self, arg):
            self.arg = arg

        def __eq__(self, other):
            return self.arg == other

    self.assertEqual(dispatcher("str"), "str")
2218
def test_method_register(self):
    """singledispatchmethod dispatches on the argument after ``self``."""
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            self.arg = "base"
        @t.register(int)
        def _(self, arg):
            self.arg = "int"
        @t.register(str)
        def _(self, arg):
            self.arg = "str"

    instance = A()
    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        instance.t(value)
        self.assertEqual(instance.arg, expected)
        # A fresh instance must not pick up state set on another instance.
        self.assertFalse(hasattr(A(), 'arg'))
2244
def test_staticmethod_register(self):
    """singledispatchmethod works with staticmethod implementations.

    Dispatch is checked both through the class and through an instance, so
    that ``singledispatchmethod.__get__`` is exercised with and without a
    bound object.
    """
    class A:
        @functools.singledispatchmethod
        @staticmethod
        def t(arg):
            return arg
        @t.register(int)
        @staticmethod
        def _(arg):
            return isinstance(arg, int)
        @t.register(str)
        @staticmethod
        def _(arg):
            return isinstance(arg, str)

    # The instance used to be created and never used; check it as well.
    for target in (A, A()):
        self.assertTrue(target.t(0))
        self.assertTrue(target.t(''))
        self.assertEqual(target.t(0.0), 0.0)
2264
def test_classmethod_register(self):
    """singledispatchmethod works with classmethod implementations."""
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")
        @t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")
        @t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

    # Pairs, not a dict: 0 and 0.0 would collide as dict keys.
    for value, label in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(A.t(value).arg, label)
2286
def test_callable_register(self):
    """Implementations can be registered after the class body has run."""
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

    # Register through the class attribute instead of inside the body.
    @A.t.register(int)
    @classmethod
    def _(cls, arg):
        return cls("int")

    @A.t.register(str)
    @classmethod
    def _(cls, arg):
        return cls("str")

    int_result = A.t(0)
    self.assertEqual(int_result.arg, "int")
    str_result = A.t('')
    self.assertEqual(str_result.arg, "str")
    fallback_result = A.t(0.0)
    self.assertEqual(fallback_result.arg, "base")
2309
def test_abstractmethod_register(self):
    """singledispatchmethod must expose abstractness to the ABC machinery.

    The class under test has to *use* ABCMeta as its metaclass. Subclassing
    ABCMeta would only define a new metaclass, so ABCMeta would never inspect
    ``add`` and the abstract-method bookkeeping would go untested.
    """
    class Abstract(metaclass=abc.ABCMeta):

        @functools.singledispatchmethod
        @abc.abstractmethod
        def add(self, x, y):
            pass

    self.assertTrue(Abstract.add.__isabstractmethod__)
    # ABCMeta collects class attributes whose __isabstractmethod__ is true,
    # so the singledispatchmethod object itself must report it.
    self.assertEqual(Abstract.__abstractmethods__, frozenset({'add'}))
    with self.assertRaises(TypeError):
        Abstract()
2319
def test_type_ann_register(self):
    """singledispatchmethod.register infers the type from annotations."""
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            return "base"
        @t.register
        def _(self, arg: int):
            return "int"
        @t.register
        def _(self, arg: str):
            return "str"

    obj = A()
    checks = (
        (0, "int"),
        ('', "str"),
        (0.0, "base"),
    )
    for arg, expected in checks:
        self.assertEqual(obj.t(arg), expected)
2336
def test_invalid_registrations(self):
    """register() rejects non-types, unannotated functions and generics."""
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    @functools.singledispatch
    def dispatch_fn(arg):
        return "base"

    # Case 1: an explicit argument that is not a type.
    with self.assertRaises(TypeError) as cm:
        @dispatch_fn.register(42)
        def _(arg):
            return "I annotated with a non-type"
    message = str(cm.exception)
    self.assertTrue(message.startswith(msg_prefix + "42"))
    self.assertTrue(message.endswith(msg_suffix))

    # Case 2: plain @register on a function that has no annotation.
    with self.assertRaises(TypeError) as cm:
        @dispatch_fn.register
        def _(arg):
            return "I forgot to annotate"
    message = str(cm.exception)
    self.assertTrue(message.startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(message.endswith(msg_suffix))

    # Case 3: a typing generic as the annotation. Dispatching on generics is
    # impossible at runtime; annotate with regular classes or ABCs instead.
    with self.assertRaises(TypeError) as cm:
        @dispatch_fn.register
        def _(arg: typing.Iterable[str]):
            return "I annotated with a generic collection"
    message = str(cm.exception)
    self.assertTrue(message.startswith(
        "Invalid annotation for 'arg'."
    ))
    self.assertTrue(message.endswith(
        'typing.Iterable[str] is not a class.'
    ))
Łukasz Langae5697532017-12-11 13:56:31 -08002375
def test_invalid_positional_argument(self):
    """Calling a singledispatch function with no arguments is a TypeError."""
    @functools.singledispatch
    def f(*args):
        pass

    expected_message = 'f requires at least 1 positional argument'
    self.assertRaisesRegex(TypeError, expected_message, f)
Łukasz Langa6f692512013-06-05 12:20:24 +02002383
Carl Meyerd658dea2018-08-28 01:11:56 -06002384
class CachedCostItem:
    """Item whose ``cost`` is computed once, under a lock, and then cached."""

    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2397
2398
class OptionallyCachedCostItem:
    """Offers one computation both uncached and cached.

    ``get_cost()`` recomputes on every call, while ``cached_cost`` wraps the
    same function in a cached_property bound to a different attribute name.
    """

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2408
2409
class CachedCostItemWait:
    """Cached-cost item whose getter first waits on an event.

    Used by the threaded test so that several threads can enter ``cost``
    concurrently once the event is set.
    """

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Block until the event is set (or 1 second passes) so the threads
        # reach this point together.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2423
2424
class CachedCostItemWithSlots:
    """Class without an instance ``__dict__``, so cached_property cannot store.

    ``__slots__`` is a one-element tuple (note the trailing comma). Writing
    ``('_cost')`` would give a parenthesised string, which Python happens to
    accept as a single slot name.
    """
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2434
2435
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python ``functools.cached_property``."""

    def test_cached(self):
        item = CachedCostItem()
        first = item.cost
        self.assertEqual(first, 2)
        self.assertEqual(item.cost, 2)  # served from the cache, not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        # Recomputing via the plain method leaves the cached value untouched.
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        saved_interval = sys.getswitchinterval()
        # A tiny switch interval makes thread interleaving (and any race in
        # the getter) much more likely.
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # The class's __dict__ is read-only, so the value cannot be cached.
        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Binding one cached_property to two names in a class is rejected."""
        with self.assertRaises(RuntimeError) as cm:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(cm.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)  # still cached on a

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attached after class creation, so __set_name__ never runs.
        Foo.cp = cp

        expected = "Cannot use cached_property instance without calling __set_name__ on it."
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        docstring = CachedCostItem.cost.__doc__
        self.assertEqual(docstring, "The cost of the item.")
2548
2549
if __name__ == "__main__":
    # Running the file directly executes the tests defined in this module.
    unittest.main()