blob: 85c65d183260673b02d61c339b96c7499be8a420 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
Antoine Pitroub5b37142012-11-13 21:35:40 +010021py_functools = support.import_fresh_module('functools', blocked=['_functools'])
22c_functools = support.import_fresh_module('functools', fresh=['_functools'])
23
Łukasz Langa6f692512013-06-05 12:20:24 +020024decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The entry that was registered under *name* when the context was entered
    is put back on exit, whether or not the managed block raised.  *name*
    must already be present in ``sys.modules``; otherwise the lookup raises
    ``KeyError`` on entry.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body of the ``with`` block failed.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Return the received arguments as an ``(args, kw)`` pair.

    ``args`` is the tuple of positional arguments and ``kw`` the dict of
    keyword arguments, exactly as they were passed in.
    """
    received = (args, kw)
    return received
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Summarise a partial object as a 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)`` read straight from
    *part*, which makes two partial objects easy to compare in assertions.
    """
    func = part.func
    args = part.args
    keywords = part.keywords
    attributes = part.__dict__
    return (func, args, keywords, attributes)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain ``tuple`` subclass with no added behaviour (used by the setstate tests)."""
46
class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` misbehaves by producing a ``list``."""

    def __add__(self, other):
        # Deliberately return a list rather than a tuple.
        return [*self, *other]
50
class MyDict(dict):
    """Plain ``dict`` subclass with no added behaviour (used by the setstate tests)."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Shared ``partial`` tests, written as a mixin.

    Concrete test cases combine this class with ``unittest.TestCase`` and
    provide two attributes:

    * ``partial`` -- the partial implementation under test;
    * ``AllowPickle`` -- a context manager class entered around pickling.
    """

    def _repr_name(self):
        """Return the name expected at the start of ``repr(partial_obj)``.

        The stock implementations report ``functools.partial``; anything else
        (e.g. a subclass) reports its own ``__name__``.  ``c_functools`` may
        be ``None`` when the C accelerator is unavailable, so its ``partial``
        attribute is looked up defensively.
        """
        stock = (getattr(c_functools, 'partial', None), py_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertEqual calls report which half differed.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Do not rely on reference counting alone to finalise the referent.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword ordering is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then resets the state so the reference cycle is broken again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the attribute, args and keywords objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates the containers and their mutable contents.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the signature check below was already disabled in the
        # original; the reason is not recorded here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances must be stored as exact tuple / dict objects.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``c_functools.partial``.

    The whole class is skipped when ``c_functools`` is falsy.
    """

    # Guarded so the class body still executes when c_functools is falsy.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """Do-nothing context manager used around the pickling tests."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            # Never suppress an exception raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        # Calling with a non-string keyword name must fail.
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                # Swap the stored value while the key is being formatted.
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that swaps ``py_functools`` in as ``functools``.

        All work is delegated to ``replaced_module("functools", py_functools)``.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            entered = self._swap.__enter__()
            return entered

        def __exit__(self, type, value, tb):
            suppressed = self._swap.__exit__(type, value, tb)
            return suppressed
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
# The subclass can only be defined when c_functools is available.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of ``c_functools.partial``."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of ``py_functools.partial``."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the ``TestPartialC`` tests using ``CPartialSubclass``."""

    # CPartialSubclass only exists when c_functools is available.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Repeat the ``TestPartialPy`` tests using ``PyPartialSubclass``."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` using the fixture class ``A``."""

    class A(object):
        # Every attribute wraps ``capture`` so calls report what they received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        inst = self.a

        self.assertEqual(inst.nothing(), ((inst,), {}))
        self.assertEqual(inst.nothing(5), ((inst, 5), {}))
        self.assertEqual(inst.nothing(c=6), ((inst,), {'c': 6}))
        self.assertEqual(inst.nothing(5, c=6), ((inst, 5), {'c': 6}))

        self.assertEqual(inst.positional(), ((inst, 1), {}))
        self.assertEqual(inst.positional(5), ((inst, 1, 5), {}))
        self.assertEqual(inst.positional(c=6), ((inst, 1), {'c': 6}))
        self.assertEqual(inst.positional(5, c=6), ((inst, 1, 5), {'c': 6}))

        self.assertEqual(inst.keywords(), ((inst,), {'a': 2}))
        self.assertEqual(inst.keywords(5), ((inst, 5), {'a': 2}))
        self.assertEqual(inst.keywords(c=6), ((inst,), {'a': 2, 'c': 6}))
        self.assertEqual(inst.keywords(5, c=6), ((inst, 5), {'a': 2, 'c': 6}))

        self.assertEqual(inst.both(), ((inst, 3), {'b': 4}))
        self.assertEqual(inst.both(5), ((inst, 3, 5), {'b': 4}))
        self.assertEqual(inst.both(c=6), ((inst, 3), {'b': 4, 'c': 6}))
        self.assertEqual(inst.both(5, c=6), ((inst, 3, 5), {'b': 4, 'c': 6}))

        # Access through the class requires the instance to be passed in.
        self.assertEqual(self.A.both(inst, 5, c=6),
                         ((inst, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(inst.spec_keywords(),
                         ((inst,), {'self': 1, 'func': 2}))

    def test_nested(self):
        inst = self.a
        self.assertEqual(inst.nested(), ((inst, 1, 5), {}))
        self.assertEqual(inst.nested(6), ((inst, 1, 5, 6), {}))
        self.assertEqual(inst.nested(d=7), ((inst, 1, 5), {'d': 7}))
        self.assertEqual(inst.nested(6, d=7), ((inst, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(inst, 6, d=7),
                         ((inst, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        inst = self.a
        self.assertEqual(inst.over_partial(), ((inst, 7), {'c': 6}))
        self.assertEqual(inst.over_partial(5), ((inst, 7, 5), {'c': 6}))
        self.assertEqual(inst.over_partial(d=8), ((inst, 7), {'c': 6, 'd': 8}))
        self.assertEqual(inst.over_partial(5, d=8),
                         ((inst, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(inst, 5, d=8),
                         ((inst, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        inst = self.a
        self.assertIs(inst.both.__self__, inst)
        self.assertIs(inst.nested.__self__, inst)
        self.assertIs(inst.over_partial.__self__, inst)
        # The classmethod-based entry is bound to the class in both cases.
        self.assertIs(inst.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        owner = self.A
        self.assertFalse(hasattr(owner.both, "__self__"))
        self.assertFalse(hasattr(owner.nested, "__self__"))
        self.assertFalse(hasattr(owner.over_partial, "__self__"))
        self.assertFalse(hasattr(owner.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # Results must be the same via the class and via an instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        inst = self.a
        self.assertEqual(inst.keywords(a=3), ((inst,), {'a': 3}))
        self.assertEqual(self.A.keywords(inst, a=3), ((inst,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        # Passing ``func`` by keyword is expected to warn but still work.
        with self.assertWarns(DeprecationWarning):
            class B:
                method = functools.partialmethod(func=capture, a=1)
        b = B()
        self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))

    def test_repr(self):
        # Read the raw descriptor from the class namespace (no binding).
        descriptor = vars(self.A)['both']
        expected = 'functools.partialmethod({}, 3, b=4)'.format(capture)
        self.assertEqual(repr(descriptor), expected)

    def test_abstract(self):
        # NOTE(review): Abstract subclasses ABCMeta rather than using it as a
        # metaclass; the assertions below only read __isabstractmethod__.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        concrete = [self.A.static, self.A.cls, self.A.over_partial,
                    self.A.nested, self.A.both]
        for func in concrete:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
590
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000591
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the expected state of *wrapped*.

        Attributes named in *assigned* must be the very same objects on both
        functions; for attributes named in *updated*, every entry of the
        wrapped mapping must be present, by identity, in the wrapper's.
        """
        # Check attributes were assigned
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        # Check attributes were updated
        for attr_name in updated:
            on_wrapper = getattr(wrapper, attr_name)
            on_wrapped = getattr(wrapped, attr_name)
            for key in on_wrapped:
                # __wrapped__ is overwritten by the update code
                if attr_name == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(on_wrapped[key], on_wrapper[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build a wrapper/wrapped pair updated with the default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # The wrapped function's annotations replace the wrapper's own.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def wrapper():
            pass
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Empty tuples mean nothing is assigned and nothing is updated.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def wrapper():
            pass
        wrapper.dict_attr = {}
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # An attribute of the wrong kind (an int here) fails the same way.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000704
Łukasz Langa6f692512013-06-05 12:20:24 +0200705
class TestWraps(TestUpdateWrapper):
    # Exercises the functools.wraps() decorator with the same scenarios
    # as the update_wrapper() tests in the base class.

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its default
        # 'assigned' and 'updated' arguments.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty 'assigned' and 'updated' tuples: nothing is copied over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Give the wrapper an empty dict for wraps() to update.
            func.dict_attr = {}
            return func

        assigned = ('attr',)
        updated = ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
766
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000767
class TestReduce:
    # Mixin: the concrete test classes provide the ``reduce``
    # implementation under test as a class attribute.

    def test_reduce(self):
        class Squares:
            # Sequence of squares 0, 1, 4, ... computed lazily on demand.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                # Extend the memo up to and including index i.
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n * n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.reduce

        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w'],
        )
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(
            reduce(lambda x, y: x*y, range(2, 21), 1),
            2432902008176640000,
        )
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence: iterable only through __getitem__.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add

        reduce = self.reduce
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict walks its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
845
846
# Runs the TestReduce cases against the reduce() taken from c_functools;
# the whole class is skipped when c_functools is falsy.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Guarded so that the class body itself does not fail with an
    # AttributeError when c_functools is not available.
    # NOTE(review): no staticmethod() wrapper here, unlike TestReducePy --
    # presumably the C function does not bind as a method; confirm.
    if c_functools:
        reduce = c_functools.reduce
851
852
# Runs the TestReduce cases against the reduce() taken from py_functools.
class TestReducePy(TestReduce, unittest.TestCase):
    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.reduce.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000855
Łukasz Langa6f692512013-06-05 12:20:24 +0200856
class TestCmpToKey:
    # Mixin: the concrete test classes provide the ``cmp_to_key``
    # implementation under test as a class attribute.

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        # The comparison function may accept heterogeneous arguments.
        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                         # too few args
        with self.assertRaises(TypeError):
            key(None, None)               # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the *result* of the
        # comparison function with zero propagates as well.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        # The key must be rebuilt from cmp2, otherwise BadCmp is never
        # exercised; '<' is the operator that reaches BadCmp.__lt__.
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        # A reversed comparison yields a descending sort.
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # Key objects are expected to be unhashable.
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000938
Łukasz Langa6f692512013-06-05 12:20:24 +0200939
# Runs the TestCmpToKey cases against the cmp_to_key() taken from
# c_functools; the whole class is skipped when c_functools is falsy.
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Guarded so that the class body itself does not fail with an
    # AttributeError when c_functools is not available.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100944
Łukasz Langa6f692512013-06-05 12:20:24 +0200945
# Runs the TestCmpToKey cases against the cmp_to_key() taken from
# py_functools.
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
948
Łukasz Langa6f692512013-06-05 12:20:24 +0200949
class TestTotalOrdering(unittest.TestCase):

    def _check_full_ordering(self, cls):
        # All four rich comparisons must agree with the natural order of
        # the values the instances are built from.
        self.assertLess(cls(1), cls(2))
        self.assertGreater(cls(2), cls(1))
        self.assertLessEqual(cls(1), cls(2))
        self.assertGreaterEqual(cls(2), cls(1))
        self.assertLessEqual(cls(2), cls(2))
        self.assertGreaterEqual(cls(2), cls(2))

    def test_total_ordering_lt(self):
        # Only __lt__ and __eq__ are supplied by the class.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # Only __le__ and __eq__ are supplied by the class.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # Only __gt__ and __eq__ are supplied by the class.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # Only __ge__ and __eq__ are supplied by the class.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_full_ordering(A)

    def test_no_operations_defined(self):
        # Decorating a class without any ordering method is an error.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThan)
                        and self.value == other.value)
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThan)
                        and self.value == other.value)
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThanEqualTo)
                        and self.value == other.value)
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThanEqualTo)
                        and self.value == other.value)
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ComparatorNotImplemented)
                        and self.value == other.value)
            def __lt__(self, other):
                return NotImplemented

        # Each comparison mixes operand types that do not know about each
        # other, so every one of them has to end in a TypeError.  The
        # operands are created lazily, inside the assertRaises() block.
        unsupported = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in unsupported:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal operands must not rescue a comparison whose root method
        # always returns NotImplemented.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # The comparison methods of a decorated module-level class must
        # round-trip through pickle to the very same object.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    payload = pickle.dumps(method, proto)
                    self.assertIs(pickle.loads(payload), method)
1152
# Kept at module level (see TestTotalOrdering.test_pickle, which pickles
# its comparison methods).
@functools.total_ordering
class Orderable_LT:
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1161
1162
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001163class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001164
1165 def test_lru(self):
1166 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001167 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001168 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001169 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001170 self.assertEqual(maxsize, 20)
1171 self.assertEqual(currsize, 0)
1172 self.assertEqual(hits, 0)
1173 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001174
1175 domain = range(5)
1176 for i in range(1000):
1177 x, y = choice(domain), choice(domain)
1178 actual = f(x, y)
1179 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001180 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001181 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001182 self.assertTrue(hits > misses)
1183 self.assertEqual(hits + misses, 1000)
1184 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001185
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001186 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001187 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001188 self.assertEqual(hits, 0)
1189 self.assertEqual(misses, 0)
1190 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001191 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001192 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(hits, 0)
1194 self.assertEqual(misses, 1)
1195 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001196
Nick Coghlan98876832010-08-17 06:17:18 +00001197 # Test bypassing the cache
1198 self.assertIs(f.__wrapped__, orig)
1199 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001200 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001201 self.assertEqual(hits, 0)
1202 self.assertEqual(misses, 1)
1203 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001204
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001206 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001207 def f():
1208 nonlocal f_cnt
1209 f_cnt += 1
1210 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001211 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001212 f_cnt = 0
1213 for i in range(5):
1214 self.assertEqual(f(), 20)
1215 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001216 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001217 self.assertEqual(hits, 0)
1218 self.assertEqual(misses, 5)
1219 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001220
1221 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001222 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001223 def f():
1224 nonlocal f_cnt
1225 f_cnt += 1
1226 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001227 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001228 f_cnt = 0
1229 for i in range(5):
1230 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001231 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001232 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001233 self.assertEqual(hits, 4)
1234 self.assertEqual(misses, 1)
1235 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001236
Raymond Hettingerf3098282010-08-15 03:30:45 +00001237 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001238 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001239 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001240 nonlocal f_cnt
1241 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001242 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001243 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001244 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001245 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1246 # * * * *
1247 self.assertEqual(f(x), x*10)
1248 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001249 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001250 self.assertEqual(hits, 12)
1251 self.assertEqual(misses, 4)
1252 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001253
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001254 def test_lru_bug_35780(self):
1255 # C version of the lru_cache was not checking to see if
1256 # the user function call has already modified the cache
1257 # (this arises in recursive calls and in multi-threading).
1258 # This cause the cache to have orphan links not referenced
1259 # by the cache dictionary.
1260
1261 once = True # Modified by f(x) below
1262
1263 @self.module.lru_cache(maxsize=10)
1264 def f(x):
1265 nonlocal once
1266 rv = f'.{x}.'
1267 if x == 20 and once:
1268 once = False
1269 rv = f(x)
1270 return rv
1271
1272 # Fill the cache
1273 for x in range(15):
1274 self.assertEqual(f(x), f'.{x}.')
1275 self.assertEqual(f.cache_info().currsize, 10)
1276
1277 # Make a recursive call and make sure the cache remains full
1278 self.assertEqual(f(20), '.20.')
1279 self.assertEqual(f.cache_info().currsize, 10)
1280
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001281 def test_lru_bug_36650(self):
1282 # C version of lru_cache was treating a call with an empty **kwargs
1283 # dictionary as being distinct from a call with no keywords at all.
1284 # This did not result in an incorrect answer, but it did trigger
1285 # an unexpected cache miss.
1286
1287 @self.module.lru_cache()
1288 def f(x):
1289 pass
1290
1291 f(0)
1292 f(0, **{})
1293 self.assertEqual(f.cache_info().hits, 1)
1294
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001295 def test_lru_hash_only_once(self):
1296 # To protect against weird reentrancy bugs and to improve
1297 # efficiency when faced with slow __hash__ methods, the
1298 # LRU cache guarantees that it will only call __hash__
1299 # only once per use as an argument to the cached function.
1300
1301 @self.module.lru_cache(maxsize=1)
1302 def f(x, y):
1303 return x * 3 + y
1304
1305 # Simulate the integer 5
1306 mock_int = unittest.mock.Mock()
1307 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1308 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1309
1310 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001311 self.assertEqual(f(mock_int, 1), 16)
1312 self.assertEqual(mock_int.__hash__.call_count, 1)
1313 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001314
1315 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001316 self.assertEqual(f(mock_int, 1), 16)
1317 self.assertEqual(mock_int.__hash__.call_count, 2)
1318 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001319
Ville Skyttä49b27342017-08-03 09:00:59 +03001320 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001321 self.assertEqual(f(6, 2), 20)
1322 self.assertEqual(mock_int.__hash__.call_count, 2)
1323 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001324
1325 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001326 self.assertEqual(f(mock_int, 1), 16)
1327 self.assertEqual(mock_int.__hash__.call_count, 3)
1328 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001329
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001330 def test_lru_reentrancy_with_len(self):
1331 # Test to make sure the LRU cache code isn't thrown-off by
1332 # caching the built-in len() function. Since len() can be
1333 # cached, we shouldn't use it inside the lru code itself.
1334 old_len = builtins.len
1335 try:
1336 builtins.len = self.module.lru_cache(4)(len)
1337 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1338 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1339 finally:
1340 builtins.len = old_len
1341
Raymond Hettinger605a4472017-01-09 07:50:19 -08001342 def test_lru_star_arg_handling(self):
1343 # Test regression that arose in ea064ff3c10f
1344 @functools.lru_cache()
1345 def f(*args):
1346 return args
1347
1348 self.assertEqual(f(1, 2), (1, 2))
1349 self.assertEqual(f((1, 2)), ((1, 2),))
1350
Yury Selivanov46a02db2016-11-09 18:55:45 -05001351 def test_lru_type_error(self):
1352 # Regression test for issue #28653.
1353 # lru_cache was leaking when one of the arguments
1354 # wasn't cacheable.
1355
1356 @functools.lru_cache(maxsize=None)
1357 def infinite_cache(o):
1358 pass
1359
1360 @functools.lru_cache(maxsize=10)
1361 def limited_cache(o):
1362 pass
1363
1364 with self.assertRaises(TypeError):
1365 infinite_cache([])
1366
1367 with self.assertRaises(TypeError):
1368 limited_cache([])
1369
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001370 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001371 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001372 def fib(n):
1373 if n < 2:
1374 return n
1375 return fib(n-1) + fib(n-2)
1376 self.assertEqual([fib(n) for n in range(16)],
1377 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1378 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001379 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001380 fib.cache_clear()
1381 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001382 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1383
1384 def test_lru_with_maxsize_negative(self):
1385 @self.module.lru_cache(maxsize=-10)
1386 def eq(n):
1387 return n
1388 for i in (0, 1):
1389 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1390 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001391 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001392
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001393 def test_lru_with_exceptions(self):
1394 # Verify that user_function exceptions get passed through without
1395 # creating a hard-to-read chained exception.
1396 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001397 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001398 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001399 def func(i):
1400 return 'abc'[i]
1401 self.assertEqual(func(0), 'a')
1402 with self.assertRaises(IndexError) as cm:
1403 func(15)
1404 self.assertIsNone(cm.exception.__context__)
1405 # Verify that the previous exception did not result in a cached entry
1406 with self.assertRaises(IndexError):
1407 func(15)
1408
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001409 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001410 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001411 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001412 def square(x):
1413 return x * x
1414 self.assertEqual(square(3), 9)
1415 self.assertEqual(type(square(3)), type(9))
1416 self.assertEqual(square(3.0), 9.0)
1417 self.assertEqual(type(square(3.0)), type(9.0))
1418 self.assertEqual(square(x=3), 9)
1419 self.assertEqual(type(square(x=3)), type(9))
1420 self.assertEqual(square(x=3.0), 9.0)
1421 self.assertEqual(type(square(x=3.0)), type(9.0))
1422 self.assertEqual(square.cache_info().hits, 4)
1423 self.assertEqual(square.cache_info().misses, 4)
1424
Antoine Pitroub5b37142012-11-13 21:35:40 +01001425 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001426 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001427 def fib(n):
1428 if n < 2:
1429 return n
1430 return fib(n=n-1) + fib(n=n-2)
1431 self.assertEqual(
1432 [fib(n=number) for number in range(16)],
1433 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1434 )
1435 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001437 fib.cache_clear()
1438 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001439 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001440
1441 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001442 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001443 def fib(n):
1444 if n < 2:
1445 return n
1446 return fib(n=n-1) + fib(n=n-2)
1447 self.assertEqual([fib(n=number) for number in range(16)],
1448 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1449 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001450 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001451 fib.cache_clear()
1452 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001453 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1454
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001455 def test_kwargs_order(self):
1456 # PEP 468: Preserving Keyword Argument Order
1457 @self.module.lru_cache(maxsize=10)
1458 def f(**kwargs):
1459 return list(kwargs.items())
1460 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1461 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1462 self.assertEqual(f.cache_info(),
1463 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1464
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001465 def test_lru_cache_decoration(self):
1466 def f(zomg: 'zomg_annotation'):
1467 """f doc string"""
1468 return 42
1469 g = self.module.lru_cache()(f)
1470 for attr in self.module.WRAPPER_ASSIGNMENTS:
1471 self.assertEqual(getattr(g, attr), getattr(f, attr))
1472
def test_lru_cache_threaded(self):
    # Call one cached function from several threads, first only filling
    # the cache and then filling it while another thread keeps clearing it.
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    cached = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(cached.cache_info().currsize, 0)

    go = threading.Event()

    def fill(key):
        # Wait for the common start signal, then call m times with one key.
        go.wait(10)
        for _ in range(m):
            self.assertEqual(cached(key, 0), orig(key, 0))

    def wipe():
        go.wait(10)
        for _ in range(2*m):
            cached.cache_clear()

    saved_interval = sys.getswitchinterval()
    # A very small switch interval makes thread switches much more frequent.
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        workers = [threading.Thread(target=fill, args=[key])
                   for key in range(n)]
        with support.start_threads(workers):
            go.set()

        info = cached.cache_info()
        if self.module is py_functools:
            # XXX: it is not understood why the counts can be lower here.
            self.assertLessEqual(info.misses, n)
            self.assertLessEqual(info.hits, m*n - info.misses)
        else:
            self.assertEqual(info.misses, n)
            self.assertEqual(info.hits, m*n - info.misses)
        self.assertEqual(info.currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        workers = [threading.Thread(target=wipe)]
        workers.extend(threading.Thread(target=fill, args=[key])
                       for key in range(n))
        go.clear()
        with support.start_threads(workers):
            go.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001520
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    n, m = 5, 7
    parties = n + 1          # the n workers plus this controlling thread
    start = threading.Barrier(parties)
    pause = threading.Barrier(parties)
    stop = threading.Barrier(parties)

    @self.module.lru_cache(maxsize=m*n)
    def triple(x):
        # Block inside the user function until the controller joins.
        pause.wait(10)
        return 3 * x

    self.assertEqual(triple.cache_info(), (0, 0, m*n, 0))

    def worker():
        for value in range(m):
            start.wait(10)
            self.assertEqual(triple(value), 3 * value)
            stop.wait(10)

    workers = [threading.Thread(target=worker) for _ in range(n)]
    with support.start_threads(workers):
        for round_no in range(m):
            # The order of the barrier operations below is significant;
            # each barrier is reset while the workers wait on another one.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # Expected per round: n more misses, no hits, one more entry.
            self.assertEqual(
                triple.cache_info(),
                (0, (round_no+1)*n, m*n, round_no+1))
1547
def test_lru_cache_threaded3(self):
    # Five threads call a slow cached function (maxsize=2) with the
    # arguments 1, 2, 2, 3, 2; every thread must get the right result.
    @self.module.lru_cache(maxsize=2)
    def slow_triple(x):
        time.sleep(.01)
        return 3 * x

    def check(index, value):
        with self.subTest(thread=index):
            self.assertEqual(slow_triple(value), 3 * value, index)

    arguments = [1, 2, 2, 3, 2]
    workers = [threading.Thread(target=check, args=(index, value))
               for index, value in enumerate(arguments)]
    # Entering the context runs the threads; nothing else to do meanwhile.
    with support.start_threads(workers):
        pass
1560
Raymond Hettinger03923422013-03-04 02:52:50 -05001561 def test_need_for_rlock(self):
1562 # This will deadlock on an LRU cache that uses a regular lock
1563
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001564 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001565 def test_func(x):
1566 'Used to demonstrate a reentrant lru_cache call within a single thread'
1567 return x
1568
1569 class DoubleEq:
1570 'Demonstrate a reentrant lru_cache call within a single thread'
1571 def __init__(self, x):
1572 self.x = x
1573 def __hash__(self):
1574 return self.x
1575 def __eq__(self, other):
1576 if self.x == 2:
1577 test_func(DoubleEq(1))
1578 return self.x == other.x
1579
1580 test_func(DoubleEq(1)) # Load the cache
1581 test_func(DoubleEq(2)) # Load the cache
1582 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1583 DoubleEq(2)) # Verify the correct return value
1584
Raymond Hettinger4d588972014-08-12 12:44:52 -07001585 def test_early_detection_of_bad_call(self):
1586 # Issue #22184
1587 with self.assertRaises(TypeError):
1588 @functools.lru_cache
1589 def f():
1590 pass
1591
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001592 def test_lru_method(self):
1593 class X(int):
1594 f_cnt = 0
1595 @self.module.lru_cache(2)
1596 def f(self, x):
1597 self.f_cnt += 1
1598 return x*10+self
1599 a = X(5)
1600 b = X(5)
1601 c = X(7)
1602 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1603
1604 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1605 self.assertEqual(a.f(x), x*10 + 5)
1606 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1607 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1608
1609 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1610 self.assertEqual(b.f(x), x*10 + 5)
1611 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1612 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1613
1614 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1615 self.assertEqual(c.f(x), x*10 + 7)
1616 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1617 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1618
1619 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1620 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1621 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1622
def test_pickle(self):
    # Round-tripping a cached function through pickle, at every protocol,
    # must give back the very same object.
    owner = self.__class__
    candidates = (owner.cached_func[0],
                  owner.cached_meth,
                  owner.cached_staticmeth)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for func in candidates:
        for proto in protocols:
            with self.subTest(proto=proto, func=func):
                restored = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(restored, func)
1630
1631 def test_copy(self):
1632 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001633 def orig(x, y):
1634 return 3 * x + y
1635 part = self.module.partial(orig, 2)
1636 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1637 self.module.lru_cache(2)(part))
1638 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001639 with self.subTest(func=f):
1640 f_copy = copy.copy(f)
1641 self.assertIs(f_copy, f)
1642
1643 def test_deepcopy(self):
1644 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001645 def orig(x, y):
1646 return 3 * x + y
1647 part = self.module.partial(orig, 2)
1648 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1649 self.module.lru_cache(2)(part))
1650 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001651 with self.subTest(func=f):
1652 f_copy = copy.deepcopy(f)
1653 self.assertIs(f_copy, f)
1654
1655
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3 * x + y, cached by the lru_cache of py_functools."""
    tripled = 3 * x
    return tripled + y
1659
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3 * x + y, cached by the lru_cache of c_functools."""
    tripled = 3 * x
    return tripled + y
1663
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001664
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU cases with ``module`` set to py_functools.
    module = py_functools

    # Stored in a 1-tuple; the tests read it back as ``cached_func[0]``.
    cached_func = (py_cached_func,)

    # Cached method and static method used by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1677
1678
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU cases with ``module`` set to c_functools.
    module = c_functools

    # Stored in a 1-tuple; the tests read it back as ``cached_func[0]``.
    cached_func = (c_cached_func,)

    # Cached method and static method used by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001691
Raymond Hettinger03923422013-03-04 02:52:50 -05001692
Łukasz Langa6f692512013-06-05 12:20:24 +02001693class TestSingleDispatch(unittest.TestCase):
1694 def test_simple_overloads(self):
1695 @functools.singledispatch
1696 def g(obj):
1697 return "base"
1698 def g_int(i):
1699 return "integer"
1700 g.register(int, g_int)
1701 self.assertEqual(g("str"), "base")
1702 self.assertEqual(g(1), "integer")
1703 self.assertEqual(g([1,2,3]), "base")
1704
1705 def test_mro(self):
1706 @functools.singledispatch
1707 def g(obj):
1708 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001709 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001710 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001711 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001712 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001713 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001714 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001715 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001716 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001717 def g_A(a):
1718 return "A"
1719 def g_B(b):
1720 return "B"
1721 g.register(A, g_A)
1722 g.register(B, g_B)
1723 self.assertEqual(g(A()), "A")
1724 self.assertEqual(g(B()), "B")
1725 self.assertEqual(g(C()), "A")
1726 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001727
1728 def test_register_decorator(self):
1729 @functools.singledispatch
1730 def g(obj):
1731 return "base"
1732 @g.register(int)
1733 def g_int(i):
1734 return "int %s" % (i,)
1735 self.assertEqual(g(""), "base")
1736 self.assertEqual(g(12), "int 12")
1737 self.assertIs(g.dispatch(int), g_int)
1738 self.assertIs(g.dispatch(object), g.dispatch(str))
1739 # Note: in the assert above this is not g.
1740 # @singledispatch returns the wrapper.
1741
1742 def test_wrapping_attributes(self):
1743 @functools.singledispatch
1744 def g(obj):
1745 "Simple test"
1746 return "Test"
1747 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001748 if sys.flags.optimize < 2:
1749 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001750
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes taken from the ``decimal`` module:
    # first only the common base class is registered, then a more
    # specific subclass.
    @functools.singledispatch
    def g(obj):
        return "base"

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")

    @g.register(decimal.DecimalException)
    def _(obj):
        return obj.args

    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def _(obj):
        return "Too small to care."

    # Only Subnormal picks up the new, more specific implementation.
    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1769
1770 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001771 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001772 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001773 mro = functools._compose_mro
1774 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1775 for haystack in permutations(bases):
1776 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001777 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1778 c.Collection, c.Sized, c.Iterable,
1779 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001780 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001781 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001782 m = mro(collections.ChainMap, haystack)
1783 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001784 c.Collection, c.Sized, c.Iterable,
1785 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001786
1787 # If there's a generic function with implementations registered for
1788 # both Sized and Container, passing a defaultdict to it results in an
1789 # ambiguous dispatch which will cause a RuntimeError (see
1790 # test_mro_conflicts).
1791 bases = [c.Container, c.Sized, str]
1792 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001793 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1794 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1795 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001796
1797 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001798 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001799 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001800 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001801 pass
1802 c.MutableSequence.register(D)
1803 bases = [c.MutableSequence, c.MutableMapping]
1804 for haystack in permutations(bases):
1805 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001806 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001807 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001808 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001809 object])
1810
1811 # Container and Callable are registered on different base classes and
1812 # a generic function supporting both should always pick the Callable
1813 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001814 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001815 def __call__(self):
1816 pass
1817 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1818 for haystack in permutations(bases):
1819 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001820 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001821 c.Collection, c.Sized, c.Iterable,
1822 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001823
1824 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001825 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001826 d = {"a": "b"}
1827 l = [1, 2, 3]
1828 s = {object(), None}
1829 f = frozenset(s)
1830 t = (1, 2, 3)
1831 @functools.singledispatch
1832 def g(obj):
1833 return "base"
1834 self.assertEqual(g(d), "base")
1835 self.assertEqual(g(l), "base")
1836 self.assertEqual(g(s), "base")
1837 self.assertEqual(g(f), "base")
1838 self.assertEqual(g(t), "base")
1839 g.register(c.Sized, lambda obj: "sized")
1840 self.assertEqual(g(d), "sized")
1841 self.assertEqual(g(l), "sized")
1842 self.assertEqual(g(s), "sized")
1843 self.assertEqual(g(f), "sized")
1844 self.assertEqual(g(t), "sized")
1845 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1846 self.assertEqual(g(d), "mutablemapping")
1847 self.assertEqual(g(l), "sized")
1848 self.assertEqual(g(s), "sized")
1849 self.assertEqual(g(f), "sized")
1850 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001851 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001852 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1853 self.assertEqual(g(l), "sized")
1854 self.assertEqual(g(s), "sized")
1855 self.assertEqual(g(f), "sized")
1856 self.assertEqual(g(t), "sized")
1857 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1858 self.assertEqual(g(d), "mutablemapping")
1859 self.assertEqual(g(l), "mutablesequence")
1860 self.assertEqual(g(s), "sized")
1861 self.assertEqual(g(f), "sized")
1862 self.assertEqual(g(t), "sized")
1863 g.register(c.MutableSet, lambda obj: "mutableset")
1864 self.assertEqual(g(d), "mutablemapping")
1865 self.assertEqual(g(l), "mutablesequence")
1866 self.assertEqual(g(s), "mutableset")
1867 self.assertEqual(g(f), "sized")
1868 self.assertEqual(g(t), "sized")
1869 g.register(c.Mapping, lambda obj: "mapping")
1870 self.assertEqual(g(d), "mutablemapping") # not specific enough
1871 self.assertEqual(g(l), "mutablesequence")
1872 self.assertEqual(g(s), "mutableset")
1873 self.assertEqual(g(f), "sized")
1874 self.assertEqual(g(t), "sized")
1875 g.register(c.Sequence, lambda obj: "sequence")
1876 self.assertEqual(g(d), "mutablemapping")
1877 self.assertEqual(g(l), "mutablesequence")
1878 self.assertEqual(g(s), "mutableset")
1879 self.assertEqual(g(f), "sized")
1880 self.assertEqual(g(t), "sequence")
1881 g.register(c.Set, lambda obj: "set")
1882 self.assertEqual(g(d), "mutablemapping")
1883 self.assertEqual(g(l), "mutablesequence")
1884 self.assertEqual(g(s), "mutableset")
1885 self.assertEqual(g(f), "set")
1886 self.assertEqual(g(t), "sequence")
1887 g.register(dict, lambda obj: "dict")
1888 self.assertEqual(g(d), "dict")
1889 self.assertEqual(g(l), "mutablesequence")
1890 self.assertEqual(g(s), "mutableset")
1891 self.assertEqual(g(f), "set")
1892 self.assertEqual(g(t), "sequence")
1893 g.register(list, lambda obj: "list")
1894 self.assertEqual(g(d), "dict")
1895 self.assertEqual(g(l), "list")
1896 self.assertEqual(g(s), "mutableset")
1897 self.assertEqual(g(f), "set")
1898 self.assertEqual(g(t), "sequence")
1899 g.register(set, lambda obj: "concrete-set")
1900 self.assertEqual(g(d), "dict")
1901 self.assertEqual(g(l), "list")
1902 self.assertEqual(g(s), "concrete-set")
1903 self.assertEqual(g(f), "set")
1904 self.assertEqual(g(t), "sequence")
1905 g.register(frozenset, lambda obj: "frozen-set")
1906 self.assertEqual(g(d), "dict")
1907 self.assertEqual(g(l), "list")
1908 self.assertEqual(g(s), "concrete-set")
1909 self.assertEqual(g(f), "frozen-set")
1910 self.assertEqual(g(t), "sequence")
1911 g.register(tuple, lambda obj: "tuple")
1912 self.assertEqual(g(d), "dict")
1913 self.assertEqual(g(l), "list")
1914 self.assertEqual(g(s), "concrete-set")
1915 self.assertEqual(g(f), "frozen-set")
1916 self.assertEqual(g(t), "tuple")
1917
Łukasz Langa3720c772013-07-01 16:00:38 +02001918 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001919 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001920 mro = functools._c3_mro
1921 class A(object):
1922 pass
1923 class B(A):
1924 def __len__(self):
1925 return 0 # implies Sized
1926 @c.Container.register
1927 class C(object):
1928 pass
1929 class D(object):
1930 pass # unrelated
1931 class X(D, C, B):
1932 def __call__(self):
1933 pass # implies Callable
1934 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1935 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1936 self.assertEqual(mro(X, abcs=abcs), expected)
1937 # unrelated ABCs don't appear in the resulting MRO
1938 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1939 self.assertEqual(mro(X, abcs=many_abcs), expected)
1940
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001941 def test_false_meta(self):
1942 # see issue23572
1943 class MetaA(type):
1944 def __len__(self):
1945 return 0
1946 class A(metaclass=MetaA):
1947 pass
1948 class AA(A):
1949 pass
1950 @functools.singledispatch
1951 def fun(a):
1952 return 'base A'
1953 @fun.register(A)
1954 def _(a):
1955 return 'fun A'
1956 aa = AA()
1957 self.assertEqual(fun(aa), 'fun A')
1958
Łukasz Langa6f692512013-06-05 12:20:24 +02001959 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001960 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001961 @functools.singledispatch
1962 def g(arg):
1963 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001964 class O(c.Sized):
1965 def __len__(self):
1966 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001967 o = O()
1968 self.assertEqual(g(o), "base")
1969 g.register(c.Iterable, lambda arg: "iterable")
1970 g.register(c.Container, lambda arg: "container")
1971 g.register(c.Sized, lambda arg: "sized")
1972 g.register(c.Set, lambda arg: "set")
1973 self.assertEqual(g(o), "sized")
1974 c.Iterable.register(O)
1975 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1976 c.Container.register(O)
1977 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001978 c.Set.register(O)
1979 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1980 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001981 class P:
1982 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001983 p = P()
1984 self.assertEqual(g(p), "base")
1985 c.Iterable.register(P)
1986 self.assertEqual(g(p), "iterable")
1987 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001988 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001989 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07001990 self.assertIn(
1991 str(re_one.exception),
1992 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1993 "or <class 'collections.abc.Iterable'>"),
1994 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1995 "or <class 'collections.abc.Container'>")),
1996 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001997 class Q(c.Sized):
1998 def __len__(self):
1999 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002000 q = Q()
2001 self.assertEqual(g(q), "sized")
2002 c.Iterable.register(Q)
2003 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2004 c.Set.register(Q)
2005 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002006 # c.Sized and c.Iterable
2007 @functools.singledispatch
2008 def h(arg):
2009 return "base"
2010 @h.register(c.Sized)
2011 def _(arg):
2012 return "sized"
2013 @h.register(c.Container)
2014 def _(arg):
2015 return "container"
2016 # Even though Sized and Container are explicit bases of MutableMapping,
2017 # this ABC is implicitly registered on defaultdict which makes all of
2018 # MutableMapping's bases implicit as well from defaultdict's
2019 # perspective.
2020 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002021 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002022 self.assertIn(
2023 str(re_two.exception),
2024 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2025 "or <class 'collections.abc.Sized'>"),
2026 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2027 "or <class 'collections.abc.Container'>")),
2028 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002029 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002030 pass
2031 c.MutableSequence.register(R)
2032 @functools.singledispatch
2033 def i(arg):
2034 return "base"
2035 @i.register(c.MutableMapping)
2036 def _(arg):
2037 return "mapping"
2038 @i.register(c.MutableSequence)
2039 def _(arg):
2040 return "sequence"
2041 r = R()
2042 self.assertEqual(i(r), "sequence")
2043 class S:
2044 pass
2045 class T(S, c.Sized):
2046 def __len__(self):
2047 return 0
2048 t = T()
2049 self.assertEqual(h(t), "sized")
2050 c.Container.register(T)
2051 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2052 class U:
2053 def __len__(self):
2054 return 0
2055 u = U()
2056 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2057 # from the existence of __len__()
2058 c.Container.register(U)
2059 # There is no preference for registered versus inferred ABCs.
2060 with self.assertRaises(RuntimeError) as re_three:
2061 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002062 self.assertIn(
2063 str(re_three.exception),
2064 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2065 "or <class 'collections.abc.Sized'>"),
2066 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2067 "or <class 'collections.abc.Container'>")),
2068 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002069 class V(c.Sized, S):
2070 def __len__(self):
2071 return 0
2072 @functools.singledispatch
2073 def j(arg):
2074 return "base"
2075 @j.register(S)
2076 def _(arg):
2077 return "s"
2078 @j.register(c.Container)
2079 def _(arg):
2080 return "container"
2081 v = V()
2082 self.assertEqual(j(v), "s")
2083 c.Container.register(V)
2084 self.assertEqual(j(v), "container") # because it ends up right after
2085 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002086
def test_cache_invalidation(self):
    # Replace the dispatch cache that singledispatch() creates through
    # weakref.WeakKeyDictionary with a dict that records every successful
    # item read and every item write.  Each step below then checks exactly
    # which dispatches were served from the cache and which refilled it.
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # get_ops / set_ops list the keys seen by __getitem__ / __setitem__
        # in call order.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []

        def __getitem__(self, key):
            # Only successful lookups are recorded: a missing key raises
            # KeyError before the append is reached.
            value = self.data[key]
            self.get_ops.append(key)
            return value

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            self.data.clear()

    cache = TracingDict()
    check = self.assertEqual
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: cache):
        cabc = collections.abc

        @functools.singledispatch
        def g(arg):
            return "base"

        mapping = {}
        seq = []

        # First dispatch per type misses the cache and stores the result.
        check(len(cache), 0)
        check(g(mapping), "base")
        check(len(cache), 1)
        check(cache.get_ops, [])
        check(cache.set_ops, [dict])
        check(cache.data[dict], g.registry[object])
        check(g(seq), "base")
        check(len(cache), 2)
        check(cache.get_ops, [])
        check(cache.set_ops, [dict, list])
        check(cache.data[dict], g.registry[object])
        check(cache.data[list], g.registry[object])
        check(cache.data[dict], cache.data[list])

        # Repeated dispatches are cache hits: reads only, no new writes.
        check(g(seq), "base")
        check(g(mapping), "base")
        check(cache.get_ops, [list, dict])
        check(cache.set_ops, [dict, list])

        # Registering a new implementation empties the cache.
        g.register(list, lambda arg: "list")
        check(cache.get_ops, [list, dict])
        check(len(cache), 0)
        check(g(mapping), "base")
        check(len(cache), 1)
        check(cache.get_ops, [list, dict])
        check(cache.set_ops, [dict, list, dict])
        check(cache.data[dict], functools._find_impl(dict, g.registry))
        check(g(seq), "list")
        check(len(cache), 2)
        check(cache.get_ops, [list, dict])
        check(cache.set_ops, [dict, list, dict, list])
        check(cache.data[list], functools._find_impl(list, g.registry))

        class X:
            pass

        # Will not invalidate the cache: no ABC has been registered with
        # g yet.
        cabc.MutableMapping.register(X)
        check(g(mapping), "base")
        check(g(seq), "list")
        check(cache.get_ops, [list, dict, dict, list])
        check(cache.set_ops, [dict, list, dict, list])

        # Registering an ABC implementation clears the cache again.
        g.register(cabc.Sized, lambda arg: "sized")
        check(len(cache), 0)
        check(g(mapping), "sized")
        check(len(cache), 1)
        check(cache.get_ops, [list, dict, dict, list])
        check(cache.set_ops, [dict, list, dict, list, dict])
        check(g(seq), "list")
        check(len(cache), 2)
        check(cache.get_ops, [list, dict, dict, list])
        check(cache.set_ops, [dict, list, dict, list, dict, list])
        check(g(seq), "list")
        check(g(mapping), "sized")
        check(cache.get_ops, [list, dict, dict, list, list, dict])
        check(cache.set_ops, [dict, list, dict, list, dict, list])

        # Explicit dispatch() calls are served from the cache as well.
        g.dispatch(list)
        g.dispatch(dict)
        check(cache.get_ops, [list, dict, dict, list, list, dict,
                              list, dict])
        check(cache.set_ops, [dict, list, dict, list, dict, list])

        cabc.MutableSet.register(X)     # Will invalidate the cache.
        check(len(cache), 2)            # Stale cache.
        check(g(seq), "list")
        check(len(cache), 1)

        g.register(cabc.MutableMapping, lambda arg: "mutablemapping")
        check(len(cache), 0)
        check(g(mapping), "mutablemapping")
        check(len(cache), 1)
        check(g(seq), "list")
        check(len(cache), 2)

        g.register(dict, lambda arg: "dict")
        check(g(mapping), "dict")
        check(g(seq), "list")
        g._clear_cache()
        check(len(cache), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002188
Łukasz Langae5697532017-12-11 13:56:31 -08002189 def test_annotations(self):
2190 @functools.singledispatch
2191 def i(arg):
2192 return "base"
2193 @i.register
2194 def _(arg: collections.abc.Mapping):
2195 return "mapping"
2196 @i.register
2197 def _(arg: "collections.abc.Sequence"):
2198 return "sequence"
2199 self.assertEqual(i(None), "base")
2200 self.assertEqual(i({"a": 1}), "mapping")
2201 self.assertEqual(i([1, 2, 3]), "sequence")
2202 self.assertEqual(i((1, 2, 3)), "sequence")
2203 self.assertEqual(i("str"), "sequence")
2204
2205 # Registering classes as callables doesn't work with annotations,
2206 # you need to pass the type explicitly.
2207 @i.register(str)
2208 class _:
2209 def __init__(self, arg):
2210 self.arg = arg
2211
2212 def __eq__(self, other):
2213 return self.arg == other
2214 self.assertEqual(i("str"), "str")
2215
Ethan Smithc6512752018-05-26 16:38:33 -04002216 def test_method_register(self):
2217 class A:
2218 @functools.singledispatchmethod
2219 def t(self, arg):
2220 self.arg = "base"
2221 @t.register(int)
2222 def _(self, arg):
2223 self.arg = "int"
2224 @t.register(str)
2225 def _(self, arg):
2226 self.arg = "str"
2227 a = A()
2228
2229 a.t(0)
2230 self.assertEqual(a.arg, "int")
2231 aa = A()
2232 self.assertFalse(hasattr(aa, 'arg'))
2233 a.t('')
2234 self.assertEqual(a.arg, "str")
2235 aa = A()
2236 self.assertFalse(hasattr(aa, 'arg'))
2237 a.t(0.0)
2238 self.assertEqual(a.arg, "base")
2239 aa = A()
2240 self.assertFalse(hasattr(aa, 'arg'))
2241
2242 def test_staticmethod_register(self):
2243 class A:
2244 @functools.singledispatchmethod
2245 @staticmethod
2246 def t(arg):
2247 return arg
2248 @t.register(int)
2249 @staticmethod
2250 def _(arg):
2251 return isinstance(arg, int)
2252 @t.register(str)
2253 @staticmethod
2254 def _(arg):
2255 return isinstance(arg, str)
2256 a = A()
2257
2258 self.assertTrue(A.t(0))
2259 self.assertTrue(A.t(''))
2260 self.assertEqual(A.t(0.0), 0.0)
2261
2262 def test_classmethod_register(self):
2263 class A:
2264 def __init__(self, arg):
2265 self.arg = arg
2266
2267 @functools.singledispatchmethod
2268 @classmethod
2269 def t(cls, arg):
2270 return cls("base")
2271 @t.register(int)
2272 @classmethod
2273 def _(cls, arg):
2274 return cls("int")
2275 @t.register(str)
2276 @classmethod
2277 def _(cls, arg):
2278 return cls("str")
2279
2280 self.assertEqual(A.t(0).arg, "int")
2281 self.assertEqual(A.t('').arg, "str")
2282 self.assertEqual(A.t(0.0).arg, "base")
2283
2284 def test_callable_register(self):
2285 class A:
2286 def __init__(self, arg):
2287 self.arg = arg
2288
2289 @functools.singledispatchmethod
2290 @classmethod
2291 def t(cls, arg):
2292 return cls("base")
2293
2294 @A.t.register(int)
2295 @classmethod
2296 def _(cls, arg):
2297 return cls("int")
2298 @A.t.register(str)
2299 @classmethod
2300 def _(cls, arg):
2301 return cls("str")
2302
2303 self.assertEqual(A.t(0).arg, "int")
2304 self.assertEqual(A.t('').arg, "str")
2305 self.assertEqual(A.t(0.0).arg, "base")
2306
2307 def test_abstractmethod_register(self):
2308 class Abstract(abc.ABCMeta):
2309
2310 @functools.singledispatchmethod
2311 @abc.abstractmethod
2312 def add(self, x, y):
2313 pass
2314
2315 self.assertTrue(Abstract.add.__isabstractmethod__)
2316
2317 def test_type_ann_register(self):
2318 class A:
2319 @functools.singledispatchmethod
2320 def t(self, arg):
2321 return "base"
2322 @t.register
2323 def _(self, arg: int):
2324 return "int"
2325 @t.register
2326 def _(self, arg: str):
2327 return "str"
2328 a = A()
2329
2330 self.assertEqual(a.t(0), "int")
2331 self.assertEqual(a.t(''), "str")
2332 self.assertEqual(a.t(0.0), "base")
2333
def test_invalid_registrations(self):
    # register() must raise TypeError when its first argument is neither a
    # class nor an annotated function; the message names the offending
    # object and ends with a hint about the two supported spellings.
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )
    # Start of the repr of the nested functions defined below.
    func_msg_prefix = (
        msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    )

    @functools.singledispatch
    def i(arg):
        return "base"

    # A non-class object passed explicitly.
    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))

    # Bare @register on a function without annotations.
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    self.assertTrue(str(exc.exception).startswith(func_msg_prefix))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))

    # FIXME: The following will only work after PEP 560 is implemented.
    return

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    self.assertTrue(str(exc.exception).startswith(func_msg_prefix))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
2373
Dong-hee Na445f1b32018-07-10 16:26:36 +09002374 def test_invalid_positional_argument(self):
2375 @functools.singledispatch
2376 def f(*args):
2377 pass
2378 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002379 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002380 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002381
Carl Meyerd658dea2018-08-28 01:11:56 -06002382
class CachedCostItem:
    """Fixture whose ``cost`` getter increments a counter on every run."""

    # Class-level starting value; the getter stores the incremented value
    # on the instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2395
2396
class OptionallyCachedCostItem:
    """Fixture exposing one getter both directly and as a cached property."""

    # Class-level starting value; get_cost() stores the incremented value
    # on the instance.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # Same function wrapped under an attribute name that differs from the
    # function's own name.
    cached_cost = py_functools.cached_property(get_cost)
2406
2407
class CachedCostItemWait:
    """Fixture whose ``cost`` getter waits on an event before counting."""

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        # Wait (at most one second) for the event so that several threads
        # can be inside the getter at the same time.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2421
2422
class CachedCostItemWithSlots:
    """Fixture without an instance ``__dict__``, so nothing can be cached."""

    # One-element tuple.  The previous spelling ('_cost') was just a
    # parenthesised string; Python accepts a bare string as a single slot
    # name, so behaviour is unchanged, but the tuple states the intent.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2432
2433
class TestCachedProperty(unittest.TestCase):
    # Tests for the pure-Python functools.cached_property.

    def test_cached(self):
        item = CachedCostItem()
        # The getter increments the counter on every run, so reading 2
        # twice shows that it ran only once.
        for _ in range(2):
            self.assertEqual(item.cost, 2)  # never 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # Direct calls always recompute; the cached attribute keeps the
        # value from its first access.
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Use a very small switch interval while the threads run, and
        # restore the previous value afterwards.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with support.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # However many threads raced, the getter must have run only once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = (
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        )
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        # The property lives on a metaclass, so its "instance" is a class.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = (
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        )
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError is compared through the __context__ of the
        # RuntimeError raised while the class is being created.
        expected = TypeError(
            "Cannot assign the same cached_property to two different names ('a' and 'b')."
        )
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        first = A()
        second = B()

        # Each instance caches its own value, so the second read of
        # first.cp does not advance the counter.
        self.assertEqual(first.cp, 1)
        self.assertEqual(second.cp, 2)
        self.assertEqual(first.cp, 1)

    def test_set_name_not_called(self):
        prop = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Assigned after the class statement, not inside the class body.
        Foo.cp = prop

        expected = (
            "Cannot use cached_property instance without calling __set_name__ on it."
        )
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        # Looked up on the class, the descriptor itself is returned.
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        # The wrapped function's docstring is exposed on the descriptor.
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2546
2547
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()