blob: 72b7765853bc046f42ceddb7a092d41aab50ee3e [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo99e6c262020-01-23 15:29:52 +00006from itertools import permutations, chain
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Hai Shie80697d2020-05-28 06:10:27 +080022from test.support import threading_helper
Pablo Galindo99e6c262020-01-23 15:29:52 +000023from test.support.script_helper import assert_python_ok
24
Antoine Pitroub5b37142012-11-13 21:35:40 +010025import functools
26
# Two separately imported copies of functools, so every test can run against
# both implementations:
#   py_functools -- imported with the _functools accelerator blocked, i.e.
#                   the pure-Python implementation;
#   c_functools  -- imported together with a fresh _functools.  The C-only
#                   tests guard on its truthiness (skipUnless / `if
#                   c_functools`), presumably because it is None when the
#                   accelerator cannot be imported.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# A fresh import of decimal together with its _decimal accelerator.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
31
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily bind ``sys.modules[name]`` to *replacement*.

    On exit, whether normal or via an exception, the previous entry is put
    back. If *name* was not present in ``sys.modules`` beforehand, the
    temporary entry is removed again rather than failing with ``KeyError``
    on entry.
    """
    missing = object()  # sentinel: distinguishes "absent" from a stored None
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
Łukasz Langa6f692512013-06-05 12:20:24 +020040
def capture(*args, **kw):
    """Echo the call back: return the positional tuple and keyword dict."""
    captured = (args, kw)
    return captured
44
Łukasz Langa6f692512013-06-05 12:20:24 +020045
def signature(part):
    """Summarize a partial object's full state as a 4-tuple.

    The result is ``(func, args, keywords, __dict__)``. The tests compare
    these tuples to check that copies, pickles and ``__setstate__`` preserve
    the whole state.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000049
class MyTuple(tuple):
    """Plain tuple subclass; partial.__setstate__ should turn it into a tuple."""
52
class BadTuple(tuple):
    """tuple subclass whose ``+`` yields a list instead of a tuple.

    The setstate tests check that partial still produces plain tuples
    when given one of these as its stored positional arguments.
    """

    def __add__(self, other):
        return [*self, *other]
56
class MyDict(dict):
    """Plain dict subclass; partial.__setstate__ should turn it into a dict."""
59
Łukasz Langa6f692512013-06-05 12:20:24 +020060
class TestPartial:
    """Implementation-independent tests for partial objects.

    This is a mixin. Concrete test cases also inherit from
    unittest.TestCase and provide two class attributes:
      partial     -- the partial type under test;
      AllowPickle -- a context-manager class entered around pickling tests.
    """

    def _repr_name(self):
        """Return the type name that repr() should show for self.partial."""
        # c_functools may be unavailable (the C-only tests guard on it), so
        # only consult it when it exists.
        implementations = [py_functools.partial]
        if c_functools:
            implementations.append(c_functools.partial)
        if self.partial in implementations:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # The first argument must be callable.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Make sure the referent is reclaimed even on implementations that
        # do not free an object as soon as its last reference is dropped.
        gc.collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # partial(partial(f, a), b=...) should flatten to partial(f, a, b=...)
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned down, accept either
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case builds a self-referencing partial; the finally clause
        # resets its state so the reference cycle is broken again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # a shallow copy shares args, keywords and instance attributes
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # a deep copy duplicates args, keywords, their contents and attributes
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # state is (func, args, keywords, __dict__); None is accepted for
        # keywords and __dict__
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # wrong length, wrong container type, non-callable func and
        # wrongly typed args/keywords are all rejected
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are stored as exact tuple/dict
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence-like state object that is not a tuple must be rejected
        # with TypeError.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000384
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific checks, run on the C partial."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: the C tests need no pickling setup."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, tb):
            return False  # never suppress exceptions

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # ... and the instance __dict__ may not be deleted
        p = self.partial(hex)
        with self.assertRaises(
                TypeError, msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        # repr() must still work, but calling must reject the bad keyword
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
435
436
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Install py_functools as sys.modules['functools'] while active.

        Pickle locates classes by module and qualified name. While this
        manager is active, ``functools.partial`` names the pure-Python
        class, so its instances can be pickled and unpickled.
        """

        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._cm.__enter__()

        def __exit__(self, *exc_info):
            return self._cm.__exit__(*exc_info)
Łukasz Langa6f692512013-06-05 12:20:24 +0200447
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial, for subclass behavior tests."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100451
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial, for subclass tests."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200454
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a trivial subclass of the C partial."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the inherited
    # flattening check does not apply. Skip it explicitly (rather than
    # shadowing it with None) so the skip shows up in test reports.
    @unittest.skip('partial subclasses are not optimized for nested calls')
    def test_nested_optimization(self):
        pass
462
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a trivial subclass."""
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200465
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' must be accepted as ordinary keyword arguments.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        # A partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # A partialmethod wrapping a plain partial object.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethods wrapping staticmethod / classmethod descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # The wrapped object must be callable ...
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        # ... it is required ...
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        # ... and it cannot be passed by keyword as 'func'.
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # partialmethod must propagate abstractness from the method it wraps.
        # Abstract has to be a real abstract base class (metaclass ABCMeta);
        # subclassing ABCMeta itself would make it a metaclass instead.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # Both are therefore recorded as abstract, blocking instantiation.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
594
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000595
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000596class TestUpdateWrapper(unittest.TestCase):
597
598 def check_wrapper(self, wrapper, wrapped,
599 assigned=functools.WRAPPER_ASSIGNMENTS,
600 updated=functools.WRAPPER_UPDATES):
601 # Check attributes were assigned
602 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000603 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000604 # Check attributes were updated
605 for name in updated:
606 wrapper_attr = getattr(wrapper, name)
607 wrapped_attr = getattr(wrapped, name)
608 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000609 if name == "__dict__" and key == "__wrapped__":
610 # __wrapped__ is overwritten by the update code
611 continue
612 self.assertIs(wrapped_attr[key], wrapper_attr[key])
613 # Check __wrapped__
614 self.assertIs(wrapper.__wrapped__, wrapped)
615
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000616
R. David Murray378c0cf2010-02-24 01:46:21 +0000617 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000618 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000619 """This is a test"""
620 pass
621 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000622 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000623 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000624 pass
625 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000626 return wrapper, f
627
628 def test_default_update(self):
629 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000630 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000631 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600633 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000634 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000635 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
636 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000637
R. David Murray378c0cf2010-02-24 01:46:21 +0000638 @unittest.skipIf(sys.flags.optimize >= 2,
639 "Docstrings are omitted with -O2 and above")
640 def test_default_update_doc(self):
641 wrapper, f = self._default_update()
642 self.assertEqual(wrapper.__doc__, 'This is a test')
643
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000644 def test_no_update(self):
645 def f():
646 """This is a test"""
647 pass
648 f.attr = 'This is also a test'
649 def wrapper():
650 pass
651 functools.update_wrapper(wrapper, f, (), ())
652 self.check_wrapper(wrapper, f, (), ())
653 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600654 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000655 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000656 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000657 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000658
659 def test_selective_update(self):
660 def f():
661 pass
662 f.attr = 'This is a different test'
663 f.dict_attr = dict(a=1, b=2, c=3)
664 def wrapper():
665 pass
666 wrapper.dict_attr = {}
667 assign = ('attr',)
668 update = ('dict_attr',)
669 functools.update_wrapper(wrapper, f, assign, update)
670 self.check_wrapper(wrapper, f, assign, update)
671 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600672 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000673 self.assertEqual(wrapper.__doc__, None)
674 self.assertEqual(wrapper.attr, 'This is a different test')
675 self.assertEqual(wrapper.dict_attr, f.dict_attr)
676
Nick Coghlan98876832010-08-17 06:17:18 +0000677 def test_missing_attributes(self):
678 def f():
679 pass
680 def wrapper():
681 pass
682 wrapper.dict_attr = {}
683 assign = ('attr',)
684 update = ('dict_attr',)
685 # Missing attributes on wrapped object are ignored
686 functools.update_wrapper(wrapper, f, assign, update)
687 self.assertNotIn('attr', wrapper.__dict__)
688 self.assertEqual(wrapper.dict_attr, {})
689 # Wrapper must have expected attributes for updating
690 del wrapper.dict_attr
691 with self.assertRaises(AttributeError):
692 functools.update_wrapper(wrapper, f, assign, update)
693 wrapper.dict_attr = 1
694 with self.assertRaises(AttributeError):
695 functools.update_wrapper(wrapper, f, assign, update)
696
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200697 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000698 @unittest.skipIf(sys.flags.optimize >= 2,
699 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000700 def test_builtin_update(self):
701 # Test for bug #1576241
702 def wrapper():
703 pass
704 functools.update_wrapper(wrapper, max)
705 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000706 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000707 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000708
Łukasz Langa6f692512013-06-05 12:20:24 +0200709
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Deliberately bogus __wrapped__ on f.
        # NOTE(review): presumably check_wrapper() confirms that the decorated
        # wrapper's __wrapped__ is still f itself, not this string.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps(f, (), ()) copies none of f's metadata or attributes.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        # Gives the wrapper an empty dict for wraps() to merge into.
        def add_dict_attr(func):
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
770
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000771
madman-bobe25d5fc2018-10-25 15:02:10 +0100772class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000773 def test_reduce(self):
774 class Squares:
775 def __init__(self, max):
776 self.max = max
777 self.sofar = []
778
779 def __len__(self):
780 return len(self.sofar)
781
782 def __getitem__(self, i):
783 if not 0 <= i < self.max: raise IndexError
784 n = len(self.sofar)
785 while n <= i:
786 self.sofar.append(n*n)
787 n += 1
788 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000789 def add(x, y):
790 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100791 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000792 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100793 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000794 ['a','c','d','w']
795 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100796 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000797 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000799 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000800 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100801 self.assertEqual(self.reduce(add, Squares(10)), 285)
802 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
803 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
804 self.assertRaises(TypeError, self.reduce)
805 self.assertRaises(TypeError, self.reduce, 42, 42)
806 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
807 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
808 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
809 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
810 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
811 self.assertRaises(TypeError, self.reduce, add, "")
812 self.assertRaises(TypeError, self.reduce, add, ())
813 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000814
815 class TestFailingIter:
816 def __iter__(self):
817 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100818 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000819
madman-bobe25d5fc2018-10-25 15:02:10 +0100820 self.assertEqual(self.reduce(add, [], None), None)
821 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000822
823 class BadSeq:
824 def __getitem__(self, index):
825 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100826 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000827
828 # Test reduce()'s use of iterators.
829 def test_iterator_usage(self):
830 class SequenceClass:
831 def __init__(self, n):
832 self.n = n
833 def __getitem__(self, i):
834 if 0 <= i < self.n:
835 return i
836 else:
837 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000838
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000839 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100840 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
841 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
842 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
843 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
844 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
845 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000846
847 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100848 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
849
850
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""
    if c_functools:
        # A builtin function is not a descriptor, so it does not bind to
        # self and needs no staticmethod() wrapper.
        reduce = c_functools.reduce
855
856
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python version."""
    # staticmethod() keeps the plain Python function from binding to self.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000859
Łukasz Langa6f692512013-06-05 12:20:24 +0200860
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200861class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700862
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000863 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864 def cmp1(x, y):
865 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100866 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700867 self.assertEqual(key(3), key(3))
868 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100869 self.assertGreaterEqual(key(3), key(3))
870
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700871 def cmp2(x, y):
872 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100873 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700874 self.assertEqual(key(4.0), key('4'))
875 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100876 self.assertLessEqual(key(2), key('35'))
877 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700878
879 def test_cmp_to_key_arguments(self):
880 def cmp1(x, y):
881 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100882 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700883 self.assertEqual(key(obj=3), key(obj=3))
884 self.assertGreater(key(obj=3), key(obj=1))
885 with self.assertRaises((TypeError, AttributeError)):
886 key(3) > 1 # rhs is not a K object
887 with self.assertRaises((TypeError, AttributeError)):
888 1 < key(3) # lhs is not a K object
889 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100890 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700891 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200892 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100893 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700894 with self.assertRaises(TypeError):
895 key() # too few args
896 with self.assertRaises(TypeError):
897 key(None, None) # too many args
898
899 def test_bad_cmp(self):
900 def cmp1(x, y):
901 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100902 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700903 with self.assertRaises(ZeroDivisionError):
904 key(3) > key(1)
905
906 class BadCmp:
907 def __lt__(self, other):
908 raise ZeroDivisionError
909 def cmp1(x, y):
910 return BadCmp()
911 with self.assertRaises(ZeroDivisionError):
912 key(3) > key(1)
913
914 def test_obj_field(self):
915 def cmp1(x, y):
916 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100917 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700918 self.assertEqual(key(50).obj, 50)
919
920 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000921 def mycmp(x, y):
922 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100923 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000924 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000925
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700926 def test_sort_int_str(self):
927 def mycmp(x, y):
928 x, y = int(x), int(y)
929 return (x > y) - (x < y)
930 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100931 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700932 self.assertEqual([int(value) for value in values],
933 [0, 1, 1, 2, 3, 4, 5, 7, 10])
934
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000935 def test_hash(self):
936 def mycmp(x, y):
937 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100938 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000939 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700940 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300941 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000942
Łukasz Langa6f692512013-06-05 12:20:24 +0200943
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""
    if c_functools:
        # A builtin function is not a descriptor, so no staticmethod() needed.
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100948
Łukasz Langa6f692512013-06-05 12:20:24 +0200949
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""
    # staticmethod() keeps the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
952
Łukasz Langa6f692512013-06-05 12:20:24 +0200953
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000954class TestTotalOrdering(unittest.TestCase):
955
956 def test_total_ordering_lt(self):
957 @functools.total_ordering
958 class A:
959 def __init__(self, value):
960 self.value = value
961 def __lt__(self, other):
962 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000963 def __eq__(self, other):
964 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000965 self.assertTrue(A(1) < A(2))
966 self.assertTrue(A(2) > A(1))
967 self.assertTrue(A(1) <= A(2))
968 self.assertTrue(A(2) >= A(1))
969 self.assertTrue(A(2) <= A(2))
970 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000971 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000972
973 def test_total_ordering_le(self):
974 @functools.total_ordering
975 class A:
976 def __init__(self, value):
977 self.value = value
978 def __le__(self, other):
979 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000980 def __eq__(self, other):
981 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000982 self.assertTrue(A(1) < A(2))
983 self.assertTrue(A(2) > A(1))
984 self.assertTrue(A(1) <= A(2))
985 self.assertTrue(A(2) >= A(1))
986 self.assertTrue(A(2) <= A(2))
987 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000988 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000989
990 def test_total_ordering_gt(self):
991 @functools.total_ordering
992 class A:
993 def __init__(self, value):
994 self.value = value
995 def __gt__(self, other):
996 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000997 def __eq__(self, other):
998 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000999 self.assertTrue(A(1) < A(2))
1000 self.assertTrue(A(2) > A(1))
1001 self.assertTrue(A(1) <= A(2))
1002 self.assertTrue(A(2) >= A(1))
1003 self.assertTrue(A(2) <= A(2))
1004 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001005 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001006
1007 def test_total_ordering_ge(self):
1008 @functools.total_ordering
1009 class A:
1010 def __init__(self, value):
1011 self.value = value
1012 def __ge__(self, other):
1013 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001014 def __eq__(self, other):
1015 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001016 self.assertTrue(A(1) < A(2))
1017 self.assertTrue(A(2) > A(1))
1018 self.assertTrue(A(1) <= A(2))
1019 self.assertTrue(A(2) >= A(1))
1020 self.assertTrue(A(2) <= A(2))
1021 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001022 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001023
1024 def test_total_ordering_no_overwrite(self):
1025 # new methods should not overwrite existing
1026 @functools.total_ordering
1027 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001028 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001029 self.assertTrue(A(1) < A(2))
1030 self.assertTrue(A(2) > A(1))
1031 self.assertTrue(A(1) <= A(2))
1032 self.assertTrue(A(2) >= A(1))
1033 self.assertTrue(A(2) <= A(2))
1034 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001035
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001036 def test_no_operations_defined(self):
1037 with self.assertRaises(ValueError):
1038 @functools.total_ordering
1039 class A:
1040 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001041
Nick Coghlanf05d9812013-10-02 00:02:03 +10001042 def test_type_error_when_not_implemented(self):
1043 # bug 10042; ensure stack overflow does not occur
1044 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001045 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001046 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001047 def __init__(self, value):
1048 self.value = value
1049 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001050 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001051 return self.value == other.value
1052 return False
1053 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001054 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001055 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001056 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001057
Nick Coghlanf05d9812013-10-02 00:02:03 +10001058 @functools.total_ordering
1059 class ImplementsGreaterThan:
1060 def __init__(self, value):
1061 self.value = value
1062 def __eq__(self, other):
1063 if isinstance(other, ImplementsGreaterThan):
1064 return self.value == other.value
1065 return False
1066 def __gt__(self, other):
1067 if isinstance(other, ImplementsGreaterThan):
1068 return self.value > other.value
1069 return NotImplemented
1070
1071 @functools.total_ordering
1072 class ImplementsLessThanEqualTo:
1073 def __init__(self, value):
1074 self.value = value
1075 def __eq__(self, other):
1076 if isinstance(other, ImplementsLessThanEqualTo):
1077 return self.value == other.value
1078 return False
1079 def __le__(self, other):
1080 if isinstance(other, ImplementsLessThanEqualTo):
1081 return self.value <= other.value
1082 return NotImplemented
1083
1084 @functools.total_ordering
1085 class ImplementsGreaterThanEqualTo:
1086 def __init__(self, value):
1087 self.value = value
1088 def __eq__(self, other):
1089 if isinstance(other, ImplementsGreaterThanEqualTo):
1090 return self.value == other.value
1091 return False
1092 def __ge__(self, other):
1093 if isinstance(other, ImplementsGreaterThanEqualTo):
1094 return self.value >= other.value
1095 return NotImplemented
1096
1097 @functools.total_ordering
1098 class ComparatorNotImplemented:
1099 def __init__(self, value):
1100 self.value = value
1101 def __eq__(self, other):
1102 if isinstance(other, ComparatorNotImplemented):
1103 return self.value == other.value
1104 return False
1105 def __lt__(self, other):
1106 return NotImplemented
1107
1108 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1109 ImplementsLessThan(-1) < 1
1110
1111 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1112 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1113
1114 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1115 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1116
1117 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1118 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1119
1120 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1121 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1122
1123 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1124 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1125
1126 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1127 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1128
1129 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1130 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1131
1132 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1133 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1134
1135 with self.subTest("GE when equal"):
1136 a = ComparatorNotImplemented(8)
1137 b = ComparatorNotImplemented(8)
1138 self.assertEqual(a, b)
1139 with self.assertRaises(TypeError):
1140 a >= b
1141
1142 with self.subTest("LE when equal"):
1143 a = ComparatorNotImplemented(9)
1144 b = ComparatorNotImplemented(9)
1145 self.assertEqual(a, b)
1146 with self.assertRaises(TypeError):
1147 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001148
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001149 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001150 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001151 for name in '__lt__', '__gt__', '__le__', '__ge__':
1152 with self.subTest(method=name, proto=proto):
1153 method = getattr(Orderable_LT, name)
1154 method_copy = pickle.loads(pickle.dumps(method, proto))
1155 self.assertIs(method_copy, method)
1156
@functools.total_ordering
class Orderable_LT:
    """Ordered by ``value`` via ``__eq__`` plus ``__lt__``.

    The remaining comparisons are filled in by total_ordering.  The class is
    defined at module level so that its methods can be found by name (the
    pickling test in TestTotalOrdering relies on this).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1165
1166
class TestTopologicalSort(unittest.TestCase):
    """Tests for functools.TopologicalSorter and functools.CycleError."""

    def _test_graph(self, graph, expected):
        """Check both the grouped (get_ready/done) order and static_order().

        *expected* is a sequence of tuples: each tuple is one batch returned
        by get_ready(); static_order() must yield their concatenation.
        """

        def static_order_with_groups(ts):
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                for node in nodes:
                    ts.done(node)
                yield nodes

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(static_order_with_groups(ts)), list(expected))

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(ts.static_order()), list(chain(*expected)))

    def _assert_cycle(self, graph, cycle):
        """Assert prepare() raises CycleError whose cycle contains *cycle*.

        The reported cycle may start at any node, so it is doubled before the
        substring check to allow for rotation.
        """
        ts = functools.TopologicalSorter()
        for node, dependson in graph.items():
            ts.add(node, *dependson)
        try:
            ts.prepare()
        except functools.CycleError as e:
            _, seq = e.args
            self.assertIn(' '.join(map(str, cycle)),
                          ' '.join(map(str, seq * 2)))
        else:
            # A bare ``raise`` here would have no active exception to
            # re-raise and would surface as a RuntimeError, not a failure.
            self.fail(f"CycleError not raised for graph {graph!r}")

    def test_simple_cases(self):
        self._test_graph(
            {2: {11},
             9: {11, 8},
             10: {11, 3},
             11: {7, 5},
             8: {7, 3}},
            [(3, 5, 7), (11, 8), (2, 10, 9)]
        )

        self._test_graph({1: {}}, [(1,)])

        self._test_graph({x: {x+1} for x in range(10)},
                         [(x,) for x in range(10, -1, -1)])

        self._test_graph({2: {3}, 3: {4}, 4: {5}, 5: {1},
                          11: {12}, 12: {13}, 13: {14}, 14: {15}},
                         [(1, 15), (5, 14), (4, 13), (3, 12), (2, 11)])

        self._test_graph({
                0: [1, 2],
                1: [3],
                2: [5, 6],
                3: [4],
                4: [9],
                5: [3],
                6: [7],
                7: [8],
                8: [4],
                9: []
            },
            [(9,), (4,), (3, 8), (1, 5, 7), (6,), (2,), (0,)]
        )

        self._test_graph({
                0: [1, 2],
                1: [],
                2: [3],
                3: []
            },
            [(1, 3), (2,), (0,)]
        )

        self._test_graph({
                0: [1, 2],
                1: [],
                2: [3],
                3: [],
                4: [5],
                5: [6],
                6: []
            },
            [(1, 3, 6), (2, 5), (0, 4)]
        )

    def test_no_dependencies(self):
        self._test_graph(
            {1: {2},
             3: {4},
             5: {6}},
            [(2, 4, 6), (1, 3, 5)]
        )

        self._test_graph(
            {1: set(),
             3: set(),
             5: set()},
            [(1, 3, 5)]
        )

    def test_the_node_multiple_times(self):
        # Test same node multiple times in dependencies
        self._test_graph({1: {2}, 3: {4}, 0: [2, 4, 4, 4, 4, 4]},
                         [(2, 4), (1, 3, 0)])

        # Test adding the same dependency multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 2)
        ts.add(1, 2)
        self.assertEqual([*ts.static_order()], [2, 1])

    def test_graph_with_iterables(self):
        dependson = (2*x + 1 for x in range(5))
        ts = functools.TopologicalSorter({0: dependson})
        self.assertEqual(list(ts.static_order()), [1, 3, 5, 7, 9, 0])

    def test_add_dependencies_for_same_node_incrementally(self):
        # Test same node multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 3)
        ts.add(1, 4)
        ts.add(1, 5)

        ts2 = functools.TopologicalSorter({1: {2, 3, 4, 5}})
        self.assertEqual([*ts.static_order()], [*ts2.static_order()])

    def test_empty(self):
        self._test_graph({}, [])

    def test_cycle(self):
        # Self cycle
        self._assert_cycle({1: {1}}, [1, 1])
        # Simple cycle
        self._assert_cycle({1: {2}, 2: {1}}, [1, 2, 1])
        # Indirect cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}}, [1, 3, 2, 1])
        # not all elements involved in a cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}, 5: {4}, 4: {6}}, [1, 3, 2, 1])
        # Multiple cycles
        self._assert_cycle({1: {2}, 2: {1}, 3: {4}, 4: {5}, 6: {7}, 7: {6}},
                           [1, 2, 1])
        # Cycle in the middle of the graph
        self._assert_cycle({1: {2}, 2: {3}, 3: {2, 4}, 4: {5}}, [3, 2])

    def test_calls_before_prepare(self):
        ts = functools.TopologicalSorter()

        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.get_ready()
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.done(3)
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.is_active()

    def test_prepare_multiple_times(self):
        ts = functools.TopologicalSorter()
        ts.prepare()
        with self.assertRaisesRegex(ValueError, r"cannot prepare\(\) more than once"):
            ts.prepare()

    def test_invalid_nodes_in_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3, 4)
        ts.prepare()
        ts.get_ready()

        with self.assertRaisesRegex(ValueError, "node 2 was not passed out"):
            ts.done(2)
        with self.assertRaisesRegex(ValueError, r"node 24 was not added using add\(\)"):
            ts.done(24)

    def test_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3)
        ts.prepare()

        self.assertEqual(ts.get_ready(), (3, 4))
        # If we don't mark anything as done, get_ready() returns nothing
        self.assertEqual(ts.get_ready(), ())
        ts.done(3)
        # Now 2 becomes available as 3 is done
        self.assertEqual(ts.get_ready(), (2,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(4)
        ts.done(2)
        # Only 1 is missing
        self.assertEqual(ts.get_ready(), (1,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(1)
        self.assertEqual(ts.get_ready(), ())
        self.assertFalse(ts.is_active())

    def test_is_active(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.prepare()

        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (2,))
        self.assertTrue(ts.is_active())
        ts.done(2)
        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (1,))
        self.assertTrue(ts.is_active())
        ts.done(1)
        self.assertFalse(ts.is_active())

    def test_not_hashable_nodes(self):
        ts = functools.TopologicalSorter()
        self.assertRaises(TypeError, ts.add, dict(), 1)
        self.assertRaises(TypeError, ts.add, 1, dict())
        self.assertRaises(TypeError, ts.add, dict(), dict())

    def test_order_of_insertion_does_not_matter_between_groups(self):
        def get_groups(ts):
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                ts.done(*nodes)
                yield set(nodes)

        ts = functools.TopologicalSorter()
        ts.add(3, 2, 1)
        ts.add(1, 0)
        ts.add(4, 5)
        ts.add(6, 7)
        ts.add(4, 7)

        ts2 = functools.TopologicalSorter()
        ts2.add(1, 0)
        ts2.add(3, 2, 1)
        ts2.add(4, 7)
        ts2.add(6, 7)
        ts2.add(4, 5)

        self.assertEqual(list(get_groups(ts)), list(get_groups(ts2)))

    def test_static_order_does_not_change_with_the_hash_seed(self):
        def check_order_with_hash_seed(seed):
            code = """if 1:
                import functools
                ts = functools.TopologicalSorter()
                ts.add('blech', 'bluch', 'hola')
                ts.add('abcd', 'blech', 'bluch', 'a', 'b')
                ts.add('a', 'a string', 'something', 'b')
                ts.add('bluch', 'hola', 'abcde', 'a', 'b')
                print(list(ts.static_order()))
                """
            env = os.environ.copy()
            # signal to assert_python not to do a copy
            # of os.environ on its own
            env['__cleanenv'] = True
            env['PYTHONHASHSEED'] = str(seed)
            # assert_python_ok() returns (rc, stdout, stderr); only the
            # printed order on stdout (bytes) is compared between runs.
            _, out, _ = assert_python_ok('-c', code, **env)
            return out

        run1 = check_order_with_hash_seed(1234)
        run2 = check_order_with_hash_seed(31415)

        self.assertNotEqual(run1, b"")
        self.assertNotEqual(run2, b"")
        self.assertEqual(run1, run2)
1434
1435
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001436class TestCache:
1437 # This tests that the pass-through is working as designed.
1438 # The underlying functionality is tested in TestLRU.
1439
1440 def test_cache(self):
1441 @self.module.cache
1442 def fib(n):
1443 if n < 2:
1444 return n
1445 return fib(n-1) + fib(n-2)
1446 self.assertEqual([fib(n) for n in range(16)],
1447 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1448 self.assertEqual(fib.cache_info(),
1449 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1450 fib.cache_clear()
1451 self.assertEqual(fib.cache_info(),
1452 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1453
1454
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001455class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001456
1457 def test_lru(self):
1458 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001459 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001460 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001461 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001462 self.assertEqual(maxsize, 20)
1463 self.assertEqual(currsize, 0)
1464 self.assertEqual(hits, 0)
1465 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001466
1467 domain = range(5)
1468 for i in range(1000):
1469 x, y = choice(domain), choice(domain)
1470 actual = f(x, y)
1471 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001472 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001473 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001474 self.assertTrue(hits > misses)
1475 self.assertEqual(hits + misses, 1000)
1476 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001477
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001478 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001479 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001480 self.assertEqual(hits, 0)
1481 self.assertEqual(misses, 0)
1482 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001483 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001484 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001485 self.assertEqual(hits, 0)
1486 self.assertEqual(misses, 1)
1487 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001488
Nick Coghlan98876832010-08-17 06:17:18 +00001489 # Test bypassing the cache
1490 self.assertIs(f.__wrapped__, orig)
1491 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001492 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001493 self.assertEqual(hits, 0)
1494 self.assertEqual(misses, 1)
1495 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001496
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001497 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001498 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001499 def f():
1500 nonlocal f_cnt
1501 f_cnt += 1
1502 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001503 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001504 f_cnt = 0
1505 for i in range(5):
1506 self.assertEqual(f(), 20)
1507 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001508 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001509 self.assertEqual(hits, 0)
1510 self.assertEqual(misses, 5)
1511 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001512
1513 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001514 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001515 def f():
1516 nonlocal f_cnt
1517 f_cnt += 1
1518 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001519 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001520 f_cnt = 0
1521 for i in range(5):
1522 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001523 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001524 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001525 self.assertEqual(hits, 4)
1526 self.assertEqual(misses, 1)
1527 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001528
Raymond Hettingerf3098282010-08-15 03:30:45 +00001529 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001530 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001531 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001532 nonlocal f_cnt
1533 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001534 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001535 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001536 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001537 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1538 # * * * *
1539 self.assertEqual(f(x), x*10)
1540 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001541 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001542 self.assertEqual(hits, 12)
1543 self.assertEqual(misses, 4)
1544 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001545
Raymond Hettingerb8218682019-05-26 11:27:35 -07001546 def test_lru_no_args(self):
1547 @self.module.lru_cache
1548 def square(x):
1549 return x ** 2
1550
1551 self.assertEqual(list(map(square, [10, 20, 10])),
1552 [100, 400, 100])
1553 self.assertEqual(square.cache_info().hits, 1)
1554 self.assertEqual(square.cache_info().misses, 2)
1555 self.assertEqual(square.cache_info().maxsize, 128)
1556 self.assertEqual(square.cache_info().currsize, 2)
1557
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001558 def test_lru_bug_35780(self):
1559 # C version of the lru_cache was not checking to see if
1560 # the user function call has already modified the cache
1561 # (this arises in recursive calls and in multi-threading).
1562 # This cause the cache to have orphan links not referenced
1563 # by the cache dictionary.
1564
1565 once = True # Modified by f(x) below
1566
1567 @self.module.lru_cache(maxsize=10)
1568 def f(x):
1569 nonlocal once
1570 rv = f'.{x}.'
1571 if x == 20 and once:
1572 once = False
1573 rv = f(x)
1574 return rv
1575
1576 # Fill the cache
1577 for x in range(15):
1578 self.assertEqual(f(x), f'.{x}.')
1579 self.assertEqual(f.cache_info().currsize, 10)
1580
1581 # Make a recursive call and make sure the cache remains full
1582 self.assertEqual(f(20), '.20.')
1583 self.assertEqual(f.cache_info().currsize, 10)
1584
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001585 def test_lru_bug_36650(self):
1586 # C version of lru_cache was treating a call with an empty **kwargs
1587 # dictionary as being distinct from a call with no keywords at all.
1588 # This did not result in an incorrect answer, but it did trigger
1589 # an unexpected cache miss.
1590
1591 @self.module.lru_cache()
1592 def f(x):
1593 pass
1594
1595 f(0)
1596 f(0, **{})
1597 self.assertEqual(f.cache_info().hits, 1)
1598
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001599 def test_lru_hash_only_once(self):
1600 # To protect against weird reentrancy bugs and to improve
1601 # efficiency when faced with slow __hash__ methods, the
1602 # LRU cache guarantees that it will only call __hash__
1603 # only once per use as an argument to the cached function.
1604
1605 @self.module.lru_cache(maxsize=1)
1606 def f(x, y):
1607 return x * 3 + y
1608
1609 # Simulate the integer 5
1610 mock_int = unittest.mock.Mock()
1611 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1612 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1613
1614 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001615 self.assertEqual(f(mock_int, 1), 16)
1616 self.assertEqual(mock_int.__hash__.call_count, 1)
1617 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001618
1619 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001620 self.assertEqual(f(mock_int, 1), 16)
1621 self.assertEqual(mock_int.__hash__.call_count, 2)
1622 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001623
Ville Skyttä49b27342017-08-03 09:00:59 +03001624 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001625 self.assertEqual(f(6, 2), 20)
1626 self.assertEqual(mock_int.__hash__.call_count, 2)
1627 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001628
1629 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001630 self.assertEqual(f(mock_int, 1), 16)
1631 self.assertEqual(mock_int.__hash__.call_count, 3)
1632 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001633
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001634 def test_lru_reentrancy_with_len(self):
1635 # Test to make sure the LRU cache code isn't thrown-off by
1636 # caching the built-in len() function. Since len() can be
1637 # cached, we shouldn't use it inside the lru code itself.
1638 old_len = builtins.len
1639 try:
1640 builtins.len = self.module.lru_cache(4)(len)
1641 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1642 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1643 finally:
1644 builtins.len = old_len
1645
Raymond Hettinger605a4472017-01-09 07:50:19 -08001646 def test_lru_star_arg_handling(self):
1647 # Test regression that arose in ea064ff3c10f
1648 @functools.lru_cache()
1649 def f(*args):
1650 return args
1651
1652 self.assertEqual(f(1, 2), (1, 2))
1653 self.assertEqual(f((1, 2)), ((1, 2),))
1654
Yury Selivanov46a02db2016-11-09 18:55:45 -05001655 def test_lru_type_error(self):
1656 # Regression test for issue #28653.
1657 # lru_cache was leaking when one of the arguments
1658 # wasn't cacheable.
1659
1660 @functools.lru_cache(maxsize=None)
1661 def infinite_cache(o):
1662 pass
1663
1664 @functools.lru_cache(maxsize=10)
1665 def limited_cache(o):
1666 pass
1667
1668 with self.assertRaises(TypeError):
1669 infinite_cache([])
1670
1671 with self.assertRaises(TypeError):
1672 limited_cache([])
1673
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001674 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001675 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001676 def fib(n):
1677 if n < 2:
1678 return n
1679 return fib(n-1) + fib(n-2)
1680 self.assertEqual([fib(n) for n in range(16)],
1681 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1682 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001683 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001684 fib.cache_clear()
1685 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001686 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1687
1688 def test_lru_with_maxsize_negative(self):
1689 @self.module.lru_cache(maxsize=-10)
1690 def eq(n):
1691 return n
1692 for i in (0, 1):
1693 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1694 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001695 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001696
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001697 def test_lru_with_exceptions(self):
1698 # Verify that user_function exceptions get passed through without
1699 # creating a hard-to-read chained exception.
1700 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001701 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001702 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001703 def func(i):
1704 return 'abc'[i]
1705 self.assertEqual(func(0), 'a')
1706 with self.assertRaises(IndexError) as cm:
1707 func(15)
1708 self.assertIsNone(cm.exception.__context__)
1709 # Verify that the previous exception did not result in a cached entry
1710 with self.assertRaises(IndexError):
1711 func(15)
1712
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001713 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001714 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001715 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001716 def square(x):
1717 return x * x
1718 self.assertEqual(square(3), 9)
1719 self.assertEqual(type(square(3)), type(9))
1720 self.assertEqual(square(3.0), 9.0)
1721 self.assertEqual(type(square(3.0)), type(9.0))
1722 self.assertEqual(square(x=3), 9)
1723 self.assertEqual(type(square(x=3)), type(9))
1724 self.assertEqual(square(x=3.0), 9.0)
1725 self.assertEqual(type(square(x=3.0)), type(9.0))
1726 self.assertEqual(square.cache_info().hits, 4)
1727 self.assertEqual(square.cache_info().misses, 4)
1728
Antoine Pitroub5b37142012-11-13 21:35:40 +01001729 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001730 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001731 def fib(n):
1732 if n < 2:
1733 return n
1734 return fib(n=n-1) + fib(n=n-2)
1735 self.assertEqual(
1736 [fib(n=number) for number in range(16)],
1737 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1738 )
1739 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001740 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001741 fib.cache_clear()
1742 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001743 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001744
1745 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001746 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001747 def fib(n):
1748 if n < 2:
1749 return n
1750 return fib(n=n-1) + fib(n=n-2)
1751 self.assertEqual([fib(n=number) for number in range(16)],
1752 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1753 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001754 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001755 fib.cache_clear()
1756 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001757 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1758
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001759 def test_kwargs_order(self):
1760 # PEP 468: Preserving Keyword Argument Order
1761 @self.module.lru_cache(maxsize=10)
1762 def f(**kwargs):
1763 return list(kwargs.items())
1764 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1765 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1766 self.assertEqual(f.cache_info(),
1767 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1768
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001769 def test_lru_cache_decoration(self):
1770 def f(zomg: 'zomg_annotation'):
1771 """f doc string"""
1772 return 42
1773 g = self.module.lru_cache()(f)
1774 for attr in self.module.WRAPPER_ASSIGNMENTS:
1775 self.assertEqual(getattr(g, attr), getattr(f, attr))
1776
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001777 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001778 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001779 def orig(x, y):
1780 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001781 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001782 hits, misses, maxsize, currsize = f.cache_info()
1783 self.assertEqual(currsize, 0)
1784
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001785 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001786 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001787 start.wait(10)
1788 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001789 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001790
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001791 def clear():
1792 start.wait(10)
1793 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001794 f.cache_clear()
1795
1796 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001797 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001798 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001799 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001800 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001801 for k in range(n)]
Hai Shie80697d2020-05-28 06:10:27 +08001802 with threading_helper.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001803 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001804
1805 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001806 if self.module is py_functools:
1807 # XXX: Why can be not equal?
1808 self.assertLessEqual(misses, n)
1809 self.assertLessEqual(hits, m*n - misses)
1810 else:
1811 self.assertEqual(misses, n)
1812 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001813 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001814
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001815 # create n threads in order to fill cache and 1 to clear it
1816 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001817 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001818 for k in range(n)]
1819 start.clear()
Hai Shie80697d2020-05-28 06:10:27 +08001820 with threading_helper.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001821 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001822 finally:
1823 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001824
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001825 def test_lru_cache_threaded2(self):
1826 # Simultaneous call with the same arguments
1827 n, m = 5, 7
1828 start = threading.Barrier(n+1)
1829 pause = threading.Barrier(n+1)
1830 stop = threading.Barrier(n+1)
1831 @self.module.lru_cache(maxsize=m*n)
1832 def f(x):
1833 pause.wait(10)
1834 return 3 * x
1835 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1836 def test():
1837 for i in range(m):
1838 start.wait(10)
1839 self.assertEqual(f(i), 3 * i)
1840 stop.wait(10)
1841 threads = [threading.Thread(target=test) for k in range(n)]
Hai Shie80697d2020-05-28 06:10:27 +08001842 with threading_helper.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001843 for i in range(m):
1844 start.wait(10)
1845 stop.reset()
1846 pause.wait(10)
1847 start.reset()
1848 stop.wait(10)
1849 pause.reset()
1850 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1851
Serhiy Storchaka67796522017-01-12 18:34:33 +02001852 def test_lru_cache_threaded3(self):
1853 @self.module.lru_cache(maxsize=2)
1854 def f(x):
1855 time.sleep(.01)
1856 return 3 * x
1857 def test(i, x):
1858 with self.subTest(thread=i):
1859 self.assertEqual(f(x), 3 * x, i)
1860 threads = [threading.Thread(target=test, args=(i, v))
1861 for i, v in enumerate([1, 2, 2, 3, 2])]
Hai Shie80697d2020-05-28 06:10:27 +08001862 with threading_helper.start_threads(threads):
Serhiy Storchaka67796522017-01-12 18:34:33 +02001863 pass
1864
Raymond Hettinger03923422013-03-04 02:52:50 -05001865 def test_need_for_rlock(self):
1866 # This will deadlock on an LRU cache that uses a regular lock
1867
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001868 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001869 def test_func(x):
1870 'Used to demonstrate a reentrant lru_cache call within a single thread'
1871 return x
1872
1873 class DoubleEq:
1874 'Demonstrate a reentrant lru_cache call within a single thread'
1875 def __init__(self, x):
1876 self.x = x
1877 def __hash__(self):
1878 return self.x
1879 def __eq__(self, other):
1880 if self.x == 2:
1881 test_func(DoubleEq(1))
1882 return self.x == other.x
1883
1884 test_func(DoubleEq(1)) # Load the cache
1885 test_func(DoubleEq(2)) # Load the cache
1886 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1887 DoubleEq(2)) # Verify the correct return value
1888
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001889 def test_lru_method(self):
1890 class X(int):
1891 f_cnt = 0
1892 @self.module.lru_cache(2)
1893 def f(self, x):
1894 self.f_cnt += 1
1895 return x*10+self
1896 a = X(5)
1897 b = X(5)
1898 c = X(7)
1899 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1900
1901 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1902 self.assertEqual(a.f(x), x*10 + 5)
1903 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1904 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1905
1906 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1907 self.assertEqual(b.f(x), x*10 + 5)
1908 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1909 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1910
1911 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1912 self.assertEqual(c.f(x), x*10 + 7)
1913 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1914 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1915
1916 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1917 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1918 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1919
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001920 def test_pickle(self):
1921 cls = self.__class__
1922 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1923 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1924 with self.subTest(proto=proto, func=f):
1925 f_copy = pickle.loads(pickle.dumps(f, proto))
1926 self.assertIs(f_copy, f)
1927
1928 def test_copy(self):
1929 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001930 def orig(x, y):
1931 return 3 * x + y
1932 part = self.module.partial(orig, 2)
1933 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1934 self.module.lru_cache(2)(part))
1935 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001936 with self.subTest(func=f):
1937 f_copy = copy.copy(f)
1938 self.assertIs(f_copy, f)
1939
1940 def test_deepcopy(self):
1941 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001942 def orig(x, y):
1943 return 3 * x + y
1944 part = self.module.partial(orig, 2)
1945 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1946 self.module.lru_cache(2)(part))
1947 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001948 with self.subTest(func=f):
1949 f_copy = copy.deepcopy(f)
1950 self.assertIs(f_copy, f)
1951
Manjusaka051ff522019-11-12 15:30:18 +08001952 def test_lru_cache_parameters(self):
1953 @self.module.lru_cache(maxsize=2)
1954 def f():
1955 return 1
1956 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1957
1958 @self.module.lru_cache(maxsize=1000, typed=True)
1959 def f():
1960 return 1
1961 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1962
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001963 def test_lru_cache_weakrefable(self):
1964 @self.module.lru_cache
1965 def test_function(x):
1966 return x
1967
1968 class A:
1969 @self.module.lru_cache
1970 def test_method(self, x):
1971 return (self, x)
1972
1973 @staticmethod
1974 @self.module.lru_cache
1975 def test_staticmethod(x):
1976 return (self, x)
1977
1978 refs = [weakref.ref(test_function),
1979 weakref.ref(A.test_method),
1980 weakref.ref(A.test_staticmethod)]
1981
1982 for ref in refs:
1983 self.assertIsNotNone(ref())
1984
1985 del A
1986 del test_function
1987 gc.collect()
1988
1989 for ref in refs:
1990 self.assertIsNone(ref())
1991
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001992
@py_functools.lru_cache(maxsize=128, typed=False)
def py_cached_func(x, y):
    """Module-level function cached by the pure-Python ``lru_cache``.

    It lives at module scope, where pickle can find it by name; TestLRU's
    pickle/copy/deepcopy tests expect it to round-trip to the identical object.
    """
    return 3 * x + y
1996
@c_functools.lru_cache(maxsize=128, typed=False)
def c_cached_func(x, y):
    """Module-level function cached by the C-accelerated ``lru_cache``.

    It lives at module scope, where pickle can find it by name; TestLRU's
    pickle/copy/deepcopy tests expect it to round-trip to the identical object.
    """
    return 3 * x + y
2000
Serhiy Storchaka46c56112015-05-24 21:53:49 +03002001
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU suite against the pure-Python functools."""

    module = py_functools
    # Wrapped in a 1-tuple -- presumably so the function is kept as plain
    # data and never turned into a bound method by attribute lookup.
    cached_func = (py_cached_func,)

    @module.lru_cache(maxsize=128, typed=False)
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache(maxsize=128, typed=False)
    def cached_staticmeth(x, y):
        return 3 * x + y
2014
2015
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU suite against the C-accelerated functools."""

    module = c_functools
    # Wrapped in a 1-tuple -- presumably so the function is kept as plain
    # data and never turned into a bound method by attribute lookup.
    cached_func = (c_cached_func,)

    @module.lru_cache(maxsize=128, typed=False)
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache(maxsize=128, typed=False)
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03002028
Raymond Hettinger03923422013-03-04 02:52:50 -05002029
Łukasz Langa6f692512013-06-05 12:20:24 +02002030class TestSingleDispatch(unittest.TestCase):
2031 def test_simple_overloads(self):
2032 @functools.singledispatch
2033 def g(obj):
2034 return "base"
2035 def g_int(i):
2036 return "integer"
2037 g.register(int, g_int)
2038 self.assertEqual(g("str"), "base")
2039 self.assertEqual(g(1), "integer")
2040 self.assertEqual(g([1,2,3]), "base")
2041
2042 def test_mro(self):
2043 @functools.singledispatch
2044 def g(obj):
2045 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002046 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02002047 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002048 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002049 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002050 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002051 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002052 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02002053 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002054 def g_A(a):
2055 return "A"
2056 def g_B(b):
2057 return "B"
2058 g.register(A, g_A)
2059 g.register(B, g_B)
2060 self.assertEqual(g(A()), "A")
2061 self.assertEqual(g(B()), "B")
2062 self.assertEqual(g(C()), "A")
2063 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02002064
2065 def test_register_decorator(self):
2066 @functools.singledispatch
2067 def g(obj):
2068 return "base"
2069 @g.register(int)
2070 def g_int(i):
2071 return "int %s" % (i,)
2072 self.assertEqual(g(""), "base")
2073 self.assertEqual(g(12), "int 12")
2074 self.assertIs(g.dispatch(int), g_int)
2075 self.assertIs(g.dispatch(object), g.dispatch(str))
2076 # Note: in the assert above this is not g.
2077 # @singledispatch returns the wrapper.
2078
2079 def test_wrapping_attributes(self):
2080 @functools.singledispatch
2081 def g(obj):
2082 "Simple test"
2083 return "Test"
2084 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02002085 if sys.flags.optimize < 2:
2086 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02002087
2088 @unittest.skipUnless(decimal, 'requires _decimal')
2089 @support.cpython_only
2090 def test_c_classes(self):
2091 @functools.singledispatch
2092 def g(obj):
2093 return "base"
2094 @g.register(decimal.DecimalException)
2095 def _(obj):
2096 return obj.args
2097 subn = decimal.Subnormal("Exponent < Emin")
2098 rnd = decimal.Rounded("Number got rounded")
2099 self.assertEqual(g(subn), ("Exponent < Emin",))
2100 self.assertEqual(g(rnd), ("Number got rounded",))
2101 @g.register(decimal.Subnormal)
2102 def _(obj):
2103 return "Too small to care."
2104 self.assertEqual(g(subn), "Too small to care.")
2105 self.assertEqual(g(rnd), ("Number got rounded",))
2106
2107 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02002108 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002109 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002110 mro = functools._compose_mro
2111 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
2112 for haystack in permutations(bases):
2113 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07002114 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
2115 c.Collection, c.Sized, c.Iterable,
2116 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002117 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02002118 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002119 m = mro(collections.ChainMap, haystack)
2120 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002121 c.Collection, c.Sized, c.Iterable,
2122 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002123
2124 # If there's a generic function with implementations registered for
2125 # both Sized and Container, passing a defaultdict to it results in an
2126 # ambiguous dispatch which will cause a RuntimeError (see
2127 # test_mro_conflicts).
2128 bases = [c.Container, c.Sized, str]
2129 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002130 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
2131 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
2132 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002133
2134 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00002135 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02002136 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002137 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002138 pass
2139 c.MutableSequence.register(D)
2140 bases = [c.MutableSequence, c.MutableMapping]
2141 for haystack in permutations(bases):
2142 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07002143 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002144 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002145 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02002146 object])
2147
2148 # Container and Callable are registered on different base classes and
2149 # a generic function supporting both should always pick the Callable
2150 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002151 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002152 def __call__(self):
2153 pass
2154 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
2155 for haystack in permutations(bases):
2156 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002157 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002158 c.Collection, c.Sized, c.Iterable,
2159 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02002160
2161 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002162 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002163 d = {"a": "b"}
2164 l = [1, 2, 3]
2165 s = {object(), None}
2166 f = frozenset(s)
2167 t = (1, 2, 3)
2168 @functools.singledispatch
2169 def g(obj):
2170 return "base"
2171 self.assertEqual(g(d), "base")
2172 self.assertEqual(g(l), "base")
2173 self.assertEqual(g(s), "base")
2174 self.assertEqual(g(f), "base")
2175 self.assertEqual(g(t), "base")
2176 g.register(c.Sized, lambda obj: "sized")
2177 self.assertEqual(g(d), "sized")
2178 self.assertEqual(g(l), "sized")
2179 self.assertEqual(g(s), "sized")
2180 self.assertEqual(g(f), "sized")
2181 self.assertEqual(g(t), "sized")
2182 g.register(c.MutableMapping, lambda obj: "mutablemapping")
2183 self.assertEqual(g(d), "mutablemapping")
2184 self.assertEqual(g(l), "sized")
2185 self.assertEqual(g(s), "sized")
2186 self.assertEqual(g(f), "sized")
2187 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002188 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02002189 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
2190 self.assertEqual(g(l), "sized")
2191 self.assertEqual(g(s), "sized")
2192 self.assertEqual(g(f), "sized")
2193 self.assertEqual(g(t), "sized")
2194 g.register(c.MutableSequence, lambda obj: "mutablesequence")
2195 self.assertEqual(g(d), "mutablemapping")
2196 self.assertEqual(g(l), "mutablesequence")
2197 self.assertEqual(g(s), "sized")
2198 self.assertEqual(g(f), "sized")
2199 self.assertEqual(g(t), "sized")
2200 g.register(c.MutableSet, lambda obj: "mutableset")
2201 self.assertEqual(g(d), "mutablemapping")
2202 self.assertEqual(g(l), "mutablesequence")
2203 self.assertEqual(g(s), "mutableset")
2204 self.assertEqual(g(f), "sized")
2205 self.assertEqual(g(t), "sized")
2206 g.register(c.Mapping, lambda obj: "mapping")
2207 self.assertEqual(g(d), "mutablemapping") # not specific enough
2208 self.assertEqual(g(l), "mutablesequence")
2209 self.assertEqual(g(s), "mutableset")
2210 self.assertEqual(g(f), "sized")
2211 self.assertEqual(g(t), "sized")
2212 g.register(c.Sequence, lambda obj: "sequence")
2213 self.assertEqual(g(d), "mutablemapping")
2214 self.assertEqual(g(l), "mutablesequence")
2215 self.assertEqual(g(s), "mutableset")
2216 self.assertEqual(g(f), "sized")
2217 self.assertEqual(g(t), "sequence")
2218 g.register(c.Set, lambda obj: "set")
2219 self.assertEqual(g(d), "mutablemapping")
2220 self.assertEqual(g(l), "mutablesequence")
2221 self.assertEqual(g(s), "mutableset")
2222 self.assertEqual(g(f), "set")
2223 self.assertEqual(g(t), "sequence")
2224 g.register(dict, lambda obj: "dict")
2225 self.assertEqual(g(d), "dict")
2226 self.assertEqual(g(l), "mutablesequence")
2227 self.assertEqual(g(s), "mutableset")
2228 self.assertEqual(g(f), "set")
2229 self.assertEqual(g(t), "sequence")
2230 g.register(list, lambda obj: "list")
2231 self.assertEqual(g(d), "dict")
2232 self.assertEqual(g(l), "list")
2233 self.assertEqual(g(s), "mutableset")
2234 self.assertEqual(g(f), "set")
2235 self.assertEqual(g(t), "sequence")
2236 g.register(set, lambda obj: "concrete-set")
2237 self.assertEqual(g(d), "dict")
2238 self.assertEqual(g(l), "list")
2239 self.assertEqual(g(s), "concrete-set")
2240 self.assertEqual(g(f), "set")
2241 self.assertEqual(g(t), "sequence")
2242 g.register(frozenset, lambda obj: "frozen-set")
2243 self.assertEqual(g(d), "dict")
2244 self.assertEqual(g(l), "list")
2245 self.assertEqual(g(s), "concrete-set")
2246 self.assertEqual(g(f), "frozen-set")
2247 self.assertEqual(g(t), "sequence")
2248 g.register(tuple, lambda obj: "tuple")
2249 self.assertEqual(g(d), "dict")
2250 self.assertEqual(g(l), "list")
2251 self.assertEqual(g(s), "concrete-set")
2252 self.assertEqual(g(f), "frozen-set")
2253 self.assertEqual(g(t), "tuple")
2254
Łukasz Langa3720c772013-07-01 16:00:38 +02002255 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002256 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02002257 mro = functools._c3_mro
2258 class A(object):
2259 pass
2260 class B(A):
2261 def __len__(self):
2262 return 0 # implies Sized
2263 @c.Container.register
2264 class C(object):
2265 pass
2266 class D(object):
2267 pass # unrelated
2268 class X(D, C, B):
2269 def __call__(self):
2270 pass # implies Callable
2271 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2272 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2273 self.assertEqual(mro(X, abcs=abcs), expected)
2274 # unrelated ABCs don't appear in the resulting MRO
2275 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2276 self.assertEqual(mro(X, abcs=many_abcs), expected)
2277
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002278 def test_false_meta(self):
2279 # see issue23572
2280 class MetaA(type):
2281 def __len__(self):
2282 return 0
2283 class A(metaclass=MetaA):
2284 pass
2285 class AA(A):
2286 pass
2287 @functools.singledispatch
2288 def fun(a):
2289 return 'base A'
2290 @fun.register(A)
2291 def _(a):
2292 return 'fun A'
2293 aa = AA()
2294 self.assertEqual(fun(aa), 'fun A')
2295
Łukasz Langa6f692512013-06-05 12:20:24 +02002296 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002297 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002298 @functools.singledispatch
2299 def g(arg):
2300 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02002301 class O(c.Sized):
2302 def __len__(self):
2303 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002304 o = O()
2305 self.assertEqual(g(o), "base")
2306 g.register(c.Iterable, lambda arg: "iterable")
2307 g.register(c.Container, lambda arg: "container")
2308 g.register(c.Sized, lambda arg: "sized")
2309 g.register(c.Set, lambda arg: "set")
2310 self.assertEqual(g(o), "sized")
2311 c.Iterable.register(O)
2312 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
2313 c.Container.register(O)
2314 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02002315 c.Set.register(O)
2316 self.assertEqual(g(o), "set") # because c.Set is a subclass of
2317 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02002318 class P:
2319 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02002320 p = P()
2321 self.assertEqual(g(p), "base")
2322 c.Iterable.register(P)
2323 self.assertEqual(g(p), "iterable")
2324 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02002325 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02002326 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002327 self.assertIn(
2328 str(re_one.exception),
2329 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2330 "or <class 'collections.abc.Iterable'>"),
2331 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
2332 "or <class 'collections.abc.Container'>")),
2333 )
Łukasz Langa6f692512013-06-05 12:20:24 +02002334 class Q(c.Sized):
2335 def __len__(self):
2336 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002337 q = Q()
2338 self.assertEqual(g(q), "sized")
2339 c.Iterable.register(Q)
2340 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2341 c.Set.register(Q)
2342 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002343 # c.Sized and c.Iterable
2344 @functools.singledispatch
2345 def h(arg):
2346 return "base"
2347 @h.register(c.Sized)
2348 def _(arg):
2349 return "sized"
2350 @h.register(c.Container)
2351 def _(arg):
2352 return "container"
2353 # Even though Sized and Container are explicit bases of MutableMapping,
2354 # this ABC is implicitly registered on defaultdict which makes all of
2355 # MutableMapping's bases implicit as well from defaultdict's
2356 # perspective.
2357 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002358 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002359 self.assertIn(
2360 str(re_two.exception),
2361 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2362 "or <class 'collections.abc.Sized'>"),
2363 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2364 "or <class 'collections.abc.Container'>")),
2365 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002366 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002367 pass
2368 c.MutableSequence.register(R)
2369 @functools.singledispatch
2370 def i(arg):
2371 return "base"
2372 @i.register(c.MutableMapping)
2373 def _(arg):
2374 return "mapping"
2375 @i.register(c.MutableSequence)
2376 def _(arg):
2377 return "sequence"
2378 r = R()
2379 self.assertEqual(i(r), "sequence")
2380 class S:
2381 pass
2382 class T(S, c.Sized):
2383 def __len__(self):
2384 return 0
2385 t = T()
2386 self.assertEqual(h(t), "sized")
2387 c.Container.register(T)
2388 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2389 class U:
2390 def __len__(self):
2391 return 0
2392 u = U()
2393 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2394 # from the existence of __len__()
2395 c.Container.register(U)
2396 # There is no preference for registered versus inferred ABCs.
2397 with self.assertRaises(RuntimeError) as re_three:
2398 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002399 self.assertIn(
2400 str(re_three.exception),
2401 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2402 "or <class 'collections.abc.Sized'>"),
2403 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2404 "or <class 'collections.abc.Container'>")),
2405 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002406 class V(c.Sized, S):
2407 def __len__(self):
2408 return 0
2409 @functools.singledispatch
2410 def j(arg):
2411 return "base"
2412 @j.register(S)
2413 def _(arg):
2414 return "s"
2415 @j.register(c.Container)
2416 def _(arg):
2417 return "container"
2418 v = V()
2419 self.assertEqual(j(v), "s")
2420 c.Container.register(V)
2421 self.assertEqual(j(v), "container") # because it ends up right after
2422 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002423
2424 def test_cache_invalidation(self):
2425 from collections import UserDict
INADA Naoki9811e802017-09-30 16:13:02 +09002426 import weakref
2427
Łukasz Langa6f692512013-06-05 12:20:24 +02002428 class TracingDict(UserDict):
2429 def __init__(self, *args, **kwargs):
2430 super(TracingDict, self).__init__(*args, **kwargs)
2431 self.set_ops = []
2432 self.get_ops = []
2433 def __getitem__(self, key):
2434 result = self.data[key]
2435 self.get_ops.append(key)
2436 return result
2437 def __setitem__(self, key, value):
2438 self.set_ops.append(key)
2439 self.data[key] = value
2440 def clear(self):
2441 self.data.clear()
INADA Naoki9811e802017-09-30 16:13:02 +09002442
Łukasz Langa6f692512013-06-05 12:20:24 +02002443 td = TracingDict()
INADA Naoki9811e802017-09-30 16:13:02 +09002444 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
2445 c = collections.abc
2446 @functools.singledispatch
2447 def g(arg):
2448 return "base"
2449 d = {}
2450 l = []
2451 self.assertEqual(len(td), 0)
2452 self.assertEqual(g(d), "base")
2453 self.assertEqual(len(td), 1)
2454 self.assertEqual(td.get_ops, [])
2455 self.assertEqual(td.set_ops, [dict])
2456 self.assertEqual(td.data[dict], g.registry[object])
2457 self.assertEqual(g(l), "base")
2458 self.assertEqual(len(td), 2)
2459 self.assertEqual(td.get_ops, [])
2460 self.assertEqual(td.set_ops, [dict, list])
2461 self.assertEqual(td.data[dict], g.registry[object])
2462 self.assertEqual(td.data[list], g.registry[object])
2463 self.assertEqual(td.data[dict], td.data[list])
2464 self.assertEqual(g(l), "base")
2465 self.assertEqual(g(d), "base")
2466 self.assertEqual(td.get_ops, [list, dict])
2467 self.assertEqual(td.set_ops, [dict, list])
2468 g.register(list, lambda arg: "list")
2469 self.assertEqual(td.get_ops, [list, dict])
2470 self.assertEqual(len(td), 0)
2471 self.assertEqual(g(d), "base")
2472 self.assertEqual(len(td), 1)
2473 self.assertEqual(td.get_ops, [list, dict])
2474 self.assertEqual(td.set_ops, [dict, list, dict])
2475 self.assertEqual(td.data[dict],
2476 functools._find_impl(dict, g.registry))
2477 self.assertEqual(g(l), "list")
2478 self.assertEqual(len(td), 2)
2479 self.assertEqual(td.get_ops, [list, dict])
2480 self.assertEqual(td.set_ops, [dict, list, dict, list])
2481 self.assertEqual(td.data[list],
2482 functools._find_impl(list, g.registry))
2483 class X:
2484 pass
2485 c.MutableMapping.register(X) # Will not invalidate the cache,
2486 # not using ABCs yet.
2487 self.assertEqual(g(d), "base")
2488 self.assertEqual(g(l), "list")
2489 self.assertEqual(td.get_ops, [list, dict, dict, list])
2490 self.assertEqual(td.set_ops, [dict, list, dict, list])
2491 g.register(c.Sized, lambda arg: "sized")
2492 self.assertEqual(len(td), 0)
2493 self.assertEqual(g(d), "sized")
2494 self.assertEqual(len(td), 1)
2495 self.assertEqual(td.get_ops, [list, dict, dict, list])
2496 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2497 self.assertEqual(g(l), "list")
2498 self.assertEqual(len(td), 2)
2499 self.assertEqual(td.get_ops, [list, dict, dict, list])
2500 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2501 self.assertEqual(g(l), "list")
2502 self.assertEqual(g(d), "sized")
2503 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2504 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2505 g.dispatch(list)
2506 g.dispatch(dict)
2507 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2508 list, dict])
2509 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2510 c.MutableSet.register(X) # Will invalidate the cache.
2511 self.assertEqual(len(td), 2) # Stale cache.
2512 self.assertEqual(g(l), "list")
2513 self.assertEqual(len(td), 1)
2514 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2515 self.assertEqual(len(td), 0)
2516 self.assertEqual(g(d), "mutablemapping")
2517 self.assertEqual(len(td), 1)
2518 self.assertEqual(g(l), "list")
2519 self.assertEqual(len(td), 2)
2520 g.register(dict, lambda arg: "dict")
2521 self.assertEqual(g(d), "dict")
2522 self.assertEqual(g(l), "list")
2523 g._clear_cache()
2524 self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002525
Łukasz Langae5697532017-12-11 13:56:31 -08002526 def test_annotations(self):
2527 @functools.singledispatch
2528 def i(arg):
2529 return "base"
2530 @i.register
2531 def _(arg: collections.abc.Mapping):
2532 return "mapping"
2533 @i.register
2534 def _(arg: "collections.abc.Sequence"):
2535 return "sequence"
2536 self.assertEqual(i(None), "base")
2537 self.assertEqual(i({"a": 1}), "mapping")
2538 self.assertEqual(i([1, 2, 3]), "sequence")
2539 self.assertEqual(i((1, 2, 3)), "sequence")
2540 self.assertEqual(i("str"), "sequence")
2541
2542 # Registering classes as callables doesn't work with annotations,
2543 # you need to pass the type explicitly.
2544 @i.register(str)
2545 class _:
2546 def __init__(self, arg):
2547 self.arg = arg
2548
2549 def __eq__(self, other):
2550 return self.arg == other
2551 self.assertEqual(i("str"), "str")
2552
Ethan Smithc6512752018-05-26 16:38:33 -04002553 def test_method_register(self):
2554 class A:
2555 @functools.singledispatchmethod
2556 def t(self, arg):
2557 self.arg = "base"
2558 @t.register(int)
2559 def _(self, arg):
2560 self.arg = "int"
2561 @t.register(str)
2562 def _(self, arg):
2563 self.arg = "str"
2564 a = A()
2565
2566 a.t(0)
2567 self.assertEqual(a.arg, "int")
2568 aa = A()
2569 self.assertFalse(hasattr(aa, 'arg'))
2570 a.t('')
2571 self.assertEqual(a.arg, "str")
2572 aa = A()
2573 self.assertFalse(hasattr(aa, 'arg'))
2574 a.t(0.0)
2575 self.assertEqual(a.arg, "base")
2576 aa = A()
2577 self.assertFalse(hasattr(aa, 'arg'))
2578
2579 def test_staticmethod_register(self):
2580 class A:
2581 @functools.singledispatchmethod
2582 @staticmethod
2583 def t(arg):
2584 return arg
2585 @t.register(int)
2586 @staticmethod
2587 def _(arg):
2588 return isinstance(arg, int)
2589 @t.register(str)
2590 @staticmethod
2591 def _(arg):
2592 return isinstance(arg, str)
2593 a = A()
2594
2595 self.assertTrue(A.t(0))
2596 self.assertTrue(A.t(''))
2597 self.assertEqual(A.t(0.0), 0.0)
2598
2599 def test_classmethod_register(self):
2600 class A:
2601 def __init__(self, arg):
2602 self.arg = arg
2603
2604 @functools.singledispatchmethod
2605 @classmethod
2606 def t(cls, arg):
2607 return cls("base")
2608 @t.register(int)
2609 @classmethod
2610 def _(cls, arg):
2611 return cls("int")
2612 @t.register(str)
2613 @classmethod
2614 def _(cls, arg):
2615 return cls("str")
2616
2617 self.assertEqual(A.t(0).arg, "int")
2618 self.assertEqual(A.t('').arg, "str")
2619 self.assertEqual(A.t(0.0).arg, "base")
2620
2621 def test_callable_register(self):
2622 class A:
2623 def __init__(self, arg):
2624 self.arg = arg
2625
2626 @functools.singledispatchmethod
2627 @classmethod
2628 def t(cls, arg):
2629 return cls("base")
2630
2631 @A.t.register(int)
2632 @classmethod
2633 def _(cls, arg):
2634 return cls("int")
2635 @A.t.register(str)
2636 @classmethod
2637 def _(cls, arg):
2638 return cls("str")
2639
2640 self.assertEqual(A.t(0).arg, "int")
2641 self.assertEqual(A.t('').arg, "str")
2642 self.assertEqual(A.t(0.0).arg, "base")
2643
2644 def test_abstractmethod_register(self):
2645 class Abstract(abc.ABCMeta):
2646
2647 @functools.singledispatchmethod
2648 @abc.abstractmethod
2649 def add(self, x, y):
2650 pass
2651
2652 self.assertTrue(Abstract.add.__isabstractmethod__)
2653
2654 def test_type_ann_register(self):
2655 class A:
2656 @functools.singledispatchmethod
2657 def t(self, arg):
2658 return "base"
2659 @t.register
2660 def _(self, arg: int):
2661 return "int"
2662 @t.register
2663 def _(self, arg: str):
2664 return "str"
2665 a = A()
2666
2667 self.assertEqual(a.t(0), "int")
2668 self.assertEqual(a.t(''), "str")
2669 self.assertEqual(a.t(0.0), "base")
2670
Łukasz Langae5697532017-12-11 13:56:31 -08002671 def test_invalid_registrations(self):
2672 msg_prefix = "Invalid first argument to `register()`: "
2673 msg_suffix = (
2674 ". Use either `@register(some_class)` or plain `@register` on an "
2675 "annotated function."
2676 )
2677 @functools.singledispatch
2678 def i(arg):
2679 return "base"
2680 with self.assertRaises(TypeError) as exc:
2681 @i.register(42)
2682 def _(arg):
2683 return "I annotated with a non-type"
2684 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2685 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2686 with self.assertRaises(TypeError) as exc:
2687 @i.register
2688 def _(arg):
2689 return "I forgot to annotate"
2690 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2691 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2692 ))
2693 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2694
Łukasz Langae5697532017-12-11 13:56:31 -08002695 with self.assertRaises(TypeError) as exc:
2696 @i.register
2697 def _(arg: typing.Iterable[str]):
2698 # At runtime, dispatching on generics is impossible.
2699 # When registering implementations with singledispatch, avoid
2700 # types from `typing`. Instead, annotate with regular types
2701 # or ABCs.
2702 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002703 self.assertTrue(str(exc.exception).startswith(
2704 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002705 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002706 self.assertTrue(str(exc.exception).endswith(
2707 'typing.Iterable[str] is not a class.'
2708 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002709
Dong-hee Na445f1b32018-07-10 16:26:36 +09002710 def test_invalid_positional_argument(self):
2711 @functools.singledispatch
2712 def f(*args):
2713 pass
2714 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002715 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002716 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002717
Carl Meyerd658dea2018-08-28 01:11:56 -06002718
class CachedCostItem:
    """Item whose ``cost`` getter increments a counter and is cached."""

    # Class-level starting value; the getter shadows it per instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2731
2732
class OptionallyCachedCostItem:
    """Exposes one computation both uncached (method) and cached (property)."""

    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # Cached under ``cached_cost`` even though the function is ``get_cost``.
    cached_cost = py_functools.cached_property(get_cost)
2742
2743
class CachedCostItemWait:
    """Item whose cached ``cost`` waits on *event* before computing."""

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Wait at most one second for the event, giving competing threads
        # a chance to enter the getter concurrently.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2757
2758
class CachedCostItemWithSlots:
    """Slotted item without an instance ``__dict__``.

    ``cached_property`` needs an instance ``__dict__`` to store its value, so
    reading ``cost`` raises ``TypeError`` before the getter is ever invoked.
    """

    # A one-element tuple; ``('_cost')`` would merely be a parenthesized str.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2768
2769
class TestCachedProperty(unittest.TestCase):
    # Exercises the pure-Python cached_property from py_functools.

    def test_cached(self):
        item = CachedCostItem()
        # The getter runs once; a second read must not recompute (3).
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # get_cost() always recomputes; cached_cost keeps its first result.
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Switch threads very often so the readers interleave aggressively.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with threading_helper.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # The counter advanced exactly once despite concurrent readers.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
            TypeError,
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # The instance here is a class, whose __dict__ rejects assignment.
        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        self.assertEqual(str(ctx.exception.__context__), expected)

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        # Each instance caches its own value independently.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Attached after class creation, so __set_name__ never ran.
        Foo.cp = cp

        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
2882
2883
if __name__ == '__main__':
    # Run this module's test cases when it is executed as a script.
    unittest.main()