blob: e122fe0b33340277591f7bdc88a6b08e2c531a59 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo99e6c262020-01-23 15:29:52 +00006from itertools import permutations, chain
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Pablo Galindo99e6c262020-01-23 15:29:52 +000022from test.support.script_helper import assert_python_ok
23
Antoine Pitroub5b37142012-11-13 21:35:40 +010024import functools
25
Antoine Pitroub5b37142012-11-13 21:35:40 +010026py_functools = support.import_fresh_module('functools', blocked=['_functools'])
27c_functools = support.import_fresh_module('functools', fresh=['_functools'])
28
Łukasz Langa6f692512013-06-05 12:20:24 +020029decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
30
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module object previously stored under *name* is put back when the
    ``with`` block exits, including when the block raises.

    Raises KeyError on entry if *name* is not present in ``sys.modules``.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, so a failing test cannot leak the replacement.
        sys.modules[name] = original_module
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def capture(*args, **kw):
    """Capture all positional and keyword arguments.

    Returns the pair ``(args, kw)``: the positional arguments as a tuple
    and the keyword arguments as a dict.
    """
    return args, kw
43
Łukasz Langa6f692512013-06-05 12:20:24 +020044
def signature(part):
    """Return the signature of a partial object.

    The result is the tuple ``(func, args, keywords, __dict__)`` read from
    the attributes of *part*.
    """
    return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000048
class MyTuple(tuple):
    """Plain tuple subclass with no overrides."""
51
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple."""

    def __add__(self, other):
        # Deliberately returns a list rather than a tuple.
        return list(self) + list(other)
55
class MyDict(dict):
    """Plain dict subclass with no overrides."""
58
Łukasz Langa6f692512013-06-05 12:20:24 +020059
class TestPartial:
    """Tests run against an implementation of ``partial``.

    This class does not derive from ``unittest.TestCase`` itself; it reads
    two attributes that subclasses are expected to supply:

    * ``partial`` -- the callable under test.
    * ``AllowPickle`` -- a context manager class entered around pickling.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(
                TypeError, msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertions so a failure shows which half differs.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Do not rely on reference counting to finalize the partial at once.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself, then resets the
        # state in ``finally`` so the self-reference does not outlive it.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check was already disabled in the
        # original; the reason is not visible here -- confirm before enabling.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances must be stored as exact tuple / dict.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000383
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the TestPartial tests against ``c_functools.partial``.

    The whole class is skipped when ``c_functools`` is falsy.
    """

    # Guarded so the class body still evaluates when c_functools is falsy.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """Context manager that does nothing on entry or exit."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
434
435
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the TestPartial tests against ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that swaps in ``py_functools`` as "functools".

        Entering and exiting are delegated to ``replaced_module``, so
        ``sys.modules["functools"]`` is restored on exit.
        """
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, type, value, tb):
            return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200446
# Defined only when c_functools is truthy.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Subclass of ``c_functools.partial`` with no overrides."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100450
class PyPartialSubclass(py_functools.partial):
    """Subclass of ``py_functools.partial`` with no overrides."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200453
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the TestPartialC tests against ``CPartialSubclass``."""

    # CPartialSubclass exists only when c_functools is truthy.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
461
class TestPartialPySubclass(TestPartialPy):
    """Run the TestPartialPy tests against ``PyPartialSubclass``."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200464
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod``.

    The nested class ``A`` defines partialmethods wrapping ``capture``; the
    tests call them on the shared instance ``a`` and on ``A`` itself.
    """

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        # 'self' and 'func' are passed as ordinary keyword arguments here.
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # Each class body below must fail while it is being evaluated.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        # vars() fetches the raw partialmethod, bypassing attribute binding.
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
593
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000594
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000595class TestUpdateWrapper(unittest.TestCase):
596
597 def check_wrapper(self, wrapper, wrapped,
598 assigned=functools.WRAPPER_ASSIGNMENTS,
599 updated=functools.WRAPPER_UPDATES):
600 # Check attributes were assigned
601 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000602 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000603 # Check attributes were updated
604 for name in updated:
605 wrapper_attr = getattr(wrapper, name)
606 wrapped_attr = getattr(wrapped, name)
607 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000608 if name == "__dict__" and key == "__wrapped__":
609 # __wrapped__ is overwritten by the update code
610 continue
611 self.assertIs(wrapped_attr[key], wrapper_attr[key])
612 # Check __wrapped__
613 self.assertIs(wrapper.__wrapped__, wrapped)
614
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000615
R. David Murray378c0cf2010-02-24 01:46:21 +0000616 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000617 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618 """This is a test"""
619 pass
620 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000621 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000622 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000623 pass
624 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000625 return wrapper, f
626
627 def test_default_update(self):
628 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000629 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000630 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000631 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600632 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000633 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000634 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
635 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636
R. David Murray378c0cf2010-02-24 01:46:21 +0000637 @unittest.skipIf(sys.flags.optimize >= 2,
638 "Docstrings are omitted with -O2 and above")
639 def test_default_update_doc(self):
640 wrapper, f = self._default_update()
641 self.assertEqual(wrapper.__doc__, 'This is a test')
642
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000643 def test_no_update(self):
644 def f():
645 """This is a test"""
646 pass
647 f.attr = 'This is also a test'
648 def wrapper():
649 pass
650 functools.update_wrapper(wrapper, f, (), ())
651 self.check_wrapper(wrapper, f, (), ())
652 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600653 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000654 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000655 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000656 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000657
def test_selective_update(self):
    # Copy one custom attribute and merge one custom dict attribute,
    # leaving the standard metadata alone.
    def f():
        pass
    f.attr = 'This is a different test'
    f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

    def wrapper():
        pass
    wrapper.dict_attr = {}

    to_assign = ('attr',)
    to_update = ('dict_attr',)
    functools.update_wrapper(wrapper, f, to_assign, to_update)
    self.check_wrapper(wrapper, f, to_assign, to_update)

    # Name, qualname and docstring were not in the "assigned" tuple.
    self.assertEqual(wrapper.__name__, 'wrapper')
    self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    self.assertEqual(wrapper.__doc__, None)
    self.assertEqual(wrapper.attr, 'This is a different test')
    self.assertEqual(wrapper.dict_attr, f.dict_attr)
675
Nick Coghlan98876832010-08-17 06:17:18 +0000676 def test_missing_attributes(self):
677 def f():
678 pass
679 def wrapper():
680 pass
681 wrapper.dict_attr = {}
682 assign = ('attr',)
683 update = ('dict_attr',)
684 # Missing attributes on wrapped object are ignored
685 functools.update_wrapper(wrapper, f, assign, update)
686 self.assertNotIn('attr', wrapper.__dict__)
687 self.assertEqual(wrapper.dict_attr, {})
688 # Wrapper must have expected attributes for updating
689 del wrapper.dict_attr
690 with self.assertRaises(AttributeError):
691 functools.update_wrapper(wrapper, f, assign, update)
692 wrapper.dict_attr = 1
693 with self.assertRaises(AttributeError):
694 functools.update_wrapper(wrapper, f, assign, update)
695
@support.requires_docstrings
@unittest.skipIf(sys.flags.optimize >= 2,
                 "Docstrings are omitted with -O2 and above")
def test_builtin_update(self):
    # Test for bug #1576241
    def wrapper():
        pass

    # The wrapped object here is the builtin max(), not a Python function.
    functools.update_wrapper(wrapper, max)

    self.assertEqual(wrapper.__name__, 'max')
    docstring = wrapper.__doc__
    self.assertTrue(docstring.startswith('max('))
    self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000707
Łukasz Langa6f692512013-06-05 12:20:24 +0200708
class TestWraps(TestUpdateWrapper):
    # Same scenarios as the base class, but applied through the
    # functools.wraps() decorator.

    def _default_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # A pre-existing __wrapped__ on the wrapped function.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        decorated, original = self._default_update()
        self.check_wrapper(decorated, original)
        self.assertEqual(decorated.__name__, 'f')
        self.assertEqual(decorated.__qualname__, original.__qualname__)
        self.assertEqual(decorated.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        decorated, _ = self._default_update()
        self.assertEqual(decorated.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty "assigned" and "updated" tuples: wraps() copies nothing.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Runs before wraps(), so the dict to update already exists.
            func.dict_attr = {}
            return func

        to_assign = ('attr',)
        to_update = ('dict_attr',)

        @functools.wraps(f, to_assign, to_update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
769
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000770
madman-bobe25d5fc2018-10-25 15:02:10 +0100771class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000772 def test_reduce(self):
773 class Squares:
774 def __init__(self, max):
775 self.max = max
776 self.sofar = []
777
778 def __len__(self):
779 return len(self.sofar)
780
781 def __getitem__(self, i):
782 if not 0 <= i < self.max: raise IndexError
783 n = len(self.sofar)
784 while n <= i:
785 self.sofar.append(n*n)
786 n += 1
787 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000788 def add(x, y):
789 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100790 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000791 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100792 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000793 ['a','c','d','w']
794 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100797 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000798 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.assertEqual(self.reduce(add, Squares(10)), 285)
801 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
802 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
803 self.assertRaises(TypeError, self.reduce)
804 self.assertRaises(TypeError, self.reduce, 42, 42)
805 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
806 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
807 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
808 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
809 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
810 self.assertRaises(TypeError, self.reduce, add, "")
811 self.assertRaises(TypeError, self.reduce, add, ())
812 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000813
814 class TestFailingIter:
815 def __iter__(self):
816 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100817 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000818
madman-bobe25d5fc2018-10-25 15:02:10 +0100819 self.assertEqual(self.reduce(add, [], None), None)
820 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000821
822 class BadSeq:
823 def __getitem__(self, index):
824 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100825 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000826
827 # Test reduce()'s use of iterators.
828 def test_iterator_usage(self):
829 class SequenceClass:
830 def __init__(self, n):
831 self.n = n
832 def __getitem__(self, i):
833 if 0 <= i < self.n:
834 return i
835 else:
836 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000837
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000838 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100839 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
840 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
841 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
842 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
843 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
844 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000845
846 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100847 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
848
849
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Runs the TestReduce cases against the C implementation.  The guard
    # keeps the class body importable when c_functools is falsy.
    if c_functools:
        reduce = c_functools.reduce
854
855
class TestReducePy(TestReduce, unittest.TestCase):
    # Runs the TestReduce cases against the pure-Python implementation.
    # staticmethod() stops the function from being bound to the test
    # instance when accessed as self.reduce.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000858
Łukasz Langa6f692512013-06-05 12:20:24 +0200859
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200860class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000862 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 def cmp1(x, y):
864 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 self.assertEqual(key(3), key(3))
867 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 self.assertGreaterEqual(key(3), key(3))
869
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700870 def cmp2(x, y):
871 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 self.assertEqual(key(4.0), key('4'))
874 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 self.assertLessEqual(key(2), key('35'))
876 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877
878 def test_cmp_to_key_arguments(self):
879 def cmp1(x, y):
880 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100881 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700882 self.assertEqual(key(obj=3), key(obj=3))
883 self.assertGreater(key(obj=3), key(obj=1))
884 with self.assertRaises((TypeError, AttributeError)):
885 key(3) > 1 # rhs is not a K object
886 with self.assertRaises((TypeError, AttributeError)):
887 1 < key(3) # lhs is not a K object
888 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100889 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700890 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200891 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
894 key() # too few args
895 with self.assertRaises(TypeError):
896 key(None, None) # too many args
897
898 def test_bad_cmp(self):
899 def cmp1(x, y):
900 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100901 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700902 with self.assertRaises(ZeroDivisionError):
903 key(3) > key(1)
904
905 class BadCmp:
906 def __lt__(self, other):
907 raise ZeroDivisionError
908 def cmp1(x, y):
909 return BadCmp()
910 with self.assertRaises(ZeroDivisionError):
911 key(3) > key(1)
912
913 def test_obj_field(self):
914 def cmp1(x, y):
915 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100916 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700917 self.assertEqual(key(50).obj, 50)
918
919 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000920 def mycmp(x, y):
921 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100922 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000924
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700925 def test_sort_int_str(self):
926 def mycmp(x, y):
927 x, y = int(x), int(y)
928 return (x > y) - (x < y)
929 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100930 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700931 self.assertEqual([int(value) for value in values],
932 [0, 1, 1, 2, 3, 4, 5, 7, 10])
933
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000934 def test_hash(self):
935 def mycmp(x, y):
936 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100937 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000938 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700939 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300940 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941
Łukasz Langa6f692512013-06-05 12:20:24 +0200942
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey cases against the C implementation.  The guard
    # keeps the class body importable when c_functools is falsy.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100947
Łukasz Langa6f692512013-06-05 12:20:24 +0200948
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey cases against the pure-Python implementation.
    # staticmethod() stops the function from being bound to the test
    # instance when accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
951
Łukasz Langa6f692512013-06-05 12:20:24 +0200952
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000953class TestTotalOrdering(unittest.TestCase):
954
955 def test_total_ordering_lt(self):
956 @functools.total_ordering
957 class A:
958 def __init__(self, value):
959 self.value = value
960 def __lt__(self, other):
961 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000962 def __eq__(self, other):
963 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000964 self.assertTrue(A(1) < A(2))
965 self.assertTrue(A(2) > A(1))
966 self.assertTrue(A(1) <= A(2))
967 self.assertTrue(A(2) >= A(1))
968 self.assertTrue(A(2) <= A(2))
969 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000970 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000971
972 def test_total_ordering_le(self):
973 @functools.total_ordering
974 class A:
975 def __init__(self, value):
976 self.value = value
977 def __le__(self, other):
978 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000979 def __eq__(self, other):
980 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000981 self.assertTrue(A(1) < A(2))
982 self.assertTrue(A(2) > A(1))
983 self.assertTrue(A(1) <= A(2))
984 self.assertTrue(A(2) >= A(1))
985 self.assertTrue(A(2) <= A(2))
986 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000987 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000988
989 def test_total_ordering_gt(self):
990 @functools.total_ordering
991 class A:
992 def __init__(self, value):
993 self.value = value
994 def __gt__(self, other):
995 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000996 def __eq__(self, other):
997 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000998 self.assertTrue(A(1) < A(2))
999 self.assertTrue(A(2) > A(1))
1000 self.assertTrue(A(1) <= A(2))
1001 self.assertTrue(A(2) >= A(1))
1002 self.assertTrue(A(2) <= A(2))
1003 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001004 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001005
1006 def test_total_ordering_ge(self):
1007 @functools.total_ordering
1008 class A:
1009 def __init__(self, value):
1010 self.value = value
1011 def __ge__(self, other):
1012 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001013 def __eq__(self, other):
1014 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +00001015 self.assertTrue(A(1) < A(2))
1016 self.assertTrue(A(2) > A(1))
1017 self.assertTrue(A(1) <= A(2))
1018 self.assertTrue(A(2) >= A(1))
1019 self.assertTrue(A(2) <= A(2))
1020 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001021 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001022
1023 def test_total_ordering_no_overwrite(self):
1024 # new methods should not overwrite existing
1025 @functools.total_ordering
1026 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001027 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001028 self.assertTrue(A(1) < A(2))
1029 self.assertTrue(A(2) > A(1))
1030 self.assertTrue(A(1) <= A(2))
1031 self.assertTrue(A(2) >= A(1))
1032 self.assertTrue(A(2) <= A(2))
1033 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001034
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001035 def test_no_operations_defined(self):
1036 with self.assertRaises(ValueError):
1037 @functools.total_ordering
1038 class A:
1039 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001040
Nick Coghlanf05d9812013-10-02 00:02:03 +10001041 def test_type_error_when_not_implemented(self):
1042 # bug 10042; ensure stack overflow does not occur
1043 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001044 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001045 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001046 def __init__(self, value):
1047 self.value = value
1048 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001049 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001050 return self.value == other.value
1051 return False
1052 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001053 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001054 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001055 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001056
Nick Coghlanf05d9812013-10-02 00:02:03 +10001057 @functools.total_ordering
1058 class ImplementsGreaterThan:
1059 def __init__(self, value):
1060 self.value = value
1061 def __eq__(self, other):
1062 if isinstance(other, ImplementsGreaterThan):
1063 return self.value == other.value
1064 return False
1065 def __gt__(self, other):
1066 if isinstance(other, ImplementsGreaterThan):
1067 return self.value > other.value
1068 return NotImplemented
1069
1070 @functools.total_ordering
1071 class ImplementsLessThanEqualTo:
1072 def __init__(self, value):
1073 self.value = value
1074 def __eq__(self, other):
1075 if isinstance(other, ImplementsLessThanEqualTo):
1076 return self.value == other.value
1077 return False
1078 def __le__(self, other):
1079 if isinstance(other, ImplementsLessThanEqualTo):
1080 return self.value <= other.value
1081 return NotImplemented
1082
1083 @functools.total_ordering
1084 class ImplementsGreaterThanEqualTo:
1085 def __init__(self, value):
1086 self.value = value
1087 def __eq__(self, other):
1088 if isinstance(other, ImplementsGreaterThanEqualTo):
1089 return self.value == other.value
1090 return False
1091 def __ge__(self, other):
1092 if isinstance(other, ImplementsGreaterThanEqualTo):
1093 return self.value >= other.value
1094 return NotImplemented
1095
1096 @functools.total_ordering
1097 class ComparatorNotImplemented:
1098 def __init__(self, value):
1099 self.value = value
1100 def __eq__(self, other):
1101 if isinstance(other, ComparatorNotImplemented):
1102 return self.value == other.value
1103 return False
1104 def __lt__(self, other):
1105 return NotImplemented
1106
1107 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1108 ImplementsLessThan(-1) < 1
1109
1110 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1111 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1112
1113 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1114 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1115
1116 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1117 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1118
1119 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1120 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1121
1122 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1123 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1124
1125 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1126 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1127
1128 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1129 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1130
1131 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1132 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1133
1134 with self.subTest("GE when equal"):
1135 a = ComparatorNotImplemented(8)
1136 b = ComparatorNotImplemented(8)
1137 self.assertEqual(a, b)
1138 with self.assertRaises(TypeError):
1139 a >= b
1140
1141 with self.subTest("LE when equal"):
1142 a = ComparatorNotImplemented(9)
1143 b = ComparatorNotImplemented(9)
1144 self.assertEqual(a, b)
1145 with self.assertRaises(TypeError):
1146 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001147
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001148 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001149 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001150 for name in '__lt__', '__gt__', '__le__', '__ge__':
1151 with self.subTest(method=name, proto=proto):
1152 method = getattr(Orderable_LT, name)
1153 method_copy = pickle.loads(pickle.dumps(method, proto))
1154 self.assertIs(method_copy, method)
1155
@functools.total_ordering
class Orderable_LT:
    # Module-level helper for TestTotalOrdering.test_pickle: it defines
    # only __eq__ and __lt__ and relies on total_ordering for the rest.
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1164
1165
class TestTopologicalSort(unittest.TestCase):

    def _test_graph(self, graph, expected):
        # Check *graph* against *expected*, an iterable of node groups,
        # both group by group and as one flat static order.

        def static_order_with_groups(ts):
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                for node in nodes:
                    ts.done(node)
                yield nodes

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(static_order_with_groups(ts)), list(expected))

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(ts.static_order()), list(chain(*expected)))

    def _assert_cycle(self, graph, cycle):
        # Check that prepare() raises CycleError for *graph* and that the
        # reported node sequence contains *cycle*.
        ts = functools.TopologicalSorter()
        for node, dependson in graph.items():
            ts.add(node, *dependson)
        # assertRaises() fails the test with a clear message when no
        # CycleError is raised.
        with self.assertRaises(functools.CycleError) as caught:
            ts.prepare()
        _, seq = caught.exception.args
        # The reported sequence is doubled before the substring search so
        # that *cycle* is found whichever node the report starts from.
        self.assertIn(' '.join(map(str, cycle)),
                      ' '.join(map(str, seq * 2)))

    def test_simple_cases(self):
        self._test_graph(
            {2: {11},
             9: {11, 8},
             10: {11, 3},
             11: {7, 5},
             8: {7, 3}},
            [(3, 5, 7), (11, 8), (2, 10, 9)]
        )

        self._test_graph({1: {}}, [(1,)])

        self._test_graph({x: {x+1} for x in range(10)},
                         [(x,) for x in range(10, -1, -1)])

        self._test_graph({2: {3}, 3: {4}, 4: {5}, 5: {1},
                          11: {12}, 12: {13}, 13: {14}, 14: {15}},
                         [(1, 15), (5, 14), (4, 13), (3, 12), (2, 11)])

        self._test_graph({
                0: [1, 2],
                1: [3],
                2: [5, 6],
                3: [4],
                4: [9],
                5: [3],
                6: [7],
                7: [8],
                8: [4],
                9: []
            },
            [(9,), (4,), (3, 8), (1, 5, 7), (6,), (2,), (0,)]
        )

        self._test_graph({
                0: [1, 2],
                1: [],
                2: [3],
                3: []
            },
            [(1, 3), (2,), (0,)]
        )

        self._test_graph({
                0: [1, 2],
                1: [],
                2: [3],
                3: [],
                4: [5],
                5: [6],
                6: []
            },
            [(1, 3, 6), (2, 5), (0, 4)]
        )

    def test_no_dependencies(self):
        self._test_graph(
            {1: {2},
             3: {4},
             5: {6}},
            [(2, 4, 6), (1, 3, 5)]
        )

        self._test_graph(
            {1: set(),
             3: set(),
             5: set()},
            [(1, 3, 5)]
        )

    def test_the_node_multiple_times(self):
        # Test same node multiple times in dependencies
        self._test_graph({1: {2}, 3: {4}, 0: [2, 4, 4, 4, 4, 4]},
                         [(2, 4), (1, 3, 0)])

        # Test adding the same dependency multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 2)
        ts.add(1, 2)
        self.assertEqual([*ts.static_order()], [2, 1])

    def test_graph_with_iterables(self):
        # Dependencies may be given as a one-shot generator.
        dependson = (2*x + 1 for x in range(5))
        ts = functools.TopologicalSorter({0: dependson})
        self.assertEqual(list(ts.static_order()), [1, 3, 5, 7, 9, 0])

    def test_add_dependencies_for_same_node_incrementally(self):
        # Test same node multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 3)
        ts.add(1, 4)
        ts.add(1, 5)

        ts2 = functools.TopologicalSorter({1: {2, 3, 4, 5}})
        self.assertEqual([*ts.static_order()], [*ts2.static_order()])

    def test_empty(self):
        self._test_graph({}, [])

    def test_cycle(self):
        # Self cycle
        self._assert_cycle({1: {1}}, [1, 1])
        # Simple cycle
        self._assert_cycle({1: {2}, 2: {1}}, [1, 2, 1])
        # Indirect cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}}, [1, 3, 2, 1])
        # not all elements involved in a cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}, 5: {4}, 4: {6}}, [1, 3, 2, 1])
        # Multiple cycles
        self._assert_cycle({1: {2}, 2: {1}, 3: {4}, 4: {5}, 6: {7}, 7: {6}},
                           [1, 2, 1])
        # Cycle in the middle of the graph
        self._assert_cycle({1: {2}, 2: {3}, 3: {2, 4}, 4: {5}}, [3, 2])

    def test_calls_before_prepare(self):
        ts = functools.TopologicalSorter()

        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.get_ready()
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.done(3)
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.is_active()

    def test_prepare_multiple_times(self):
        ts = functools.TopologicalSorter()
        ts.prepare()
        with self.assertRaisesRegex(ValueError, r"cannot prepare\(\) more than once"):
            ts.prepare()

    def test_invalid_nodes_in_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3, 4)
        ts.prepare()
        ts.get_ready()

        with self.assertRaisesRegex(ValueError, "node 2 was not passed out"):
            ts.done(2)
        with self.assertRaisesRegex(ValueError, r"node 24 was not added using add\(\)"):
            ts.done(24)

    def test_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3)
        ts.prepare()

        self.assertEqual(ts.get_ready(), (3, 4))
        # If we don't mark anything as done, get_ready() returns nothing
        self.assertEqual(ts.get_ready(), ())
        ts.done(3)
        # Now 2 becomes available as 3 is done
        self.assertEqual(ts.get_ready(), (2,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(4)
        ts.done(2)
        # Only 1 is missing
        self.assertEqual(ts.get_ready(), (1,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(1)
        self.assertEqual(ts.get_ready(), ())
        self.assertFalse(ts.is_active())

    def test_is_active(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.prepare()

        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (2,))
        self.assertTrue(ts.is_active())
        ts.done(2)
        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (1,))
        self.assertTrue(ts.is_active())
        ts.done(1)
        self.assertFalse(ts.is_active())

    def test_not_hashable_nodes(self):
        ts = functools.TopologicalSorter()
        self.assertRaises(TypeError, ts.add, dict(), 1)
        self.assertRaises(TypeError, ts.add, 1, dict())
        self.assertRaises(TypeError, ts.add, dict(), dict())

    def test_order_of_insertion_does_not_matter_between_groups(self):
        def get_groups(ts):
            # Groups are compared as sets, so order inside a group is
            # ignored.
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                ts.done(*nodes)
                yield set(nodes)

        ts = functools.TopologicalSorter()
        ts.add(3, 2, 1)
        ts.add(1, 0)
        ts.add(4, 5)
        ts.add(6, 7)
        ts.add(4, 7)

        ts2 = functools.TopologicalSorter()
        ts2.add(1, 0)
        ts2.add(3, 2, 1)
        ts2.add(4, 7)
        ts2.add(6, 7)
        ts2.add(4, 5)

        self.assertEqual(list(get_groups(ts)), list(get_groups(ts2)))

    def test_static_order_does_not_change_with_the_hash_seed(self):
        def check_order_with_hash_seed(seed):
            # Run the snippet in a child interpreter with PYTHONHASHSEED
            # set to *seed* and return its captured stdout.
            code = """if 1:
                import functools
                ts = functools.TopologicalSorter()
                ts.add('blech', 'bluch', 'hola')
                ts.add('abcd', 'blech', 'bluch', 'a', 'b')
                ts.add('a', 'a string', 'something', 'b')
                ts.add('bluch', 'hola', 'abcde', 'a', 'b')
                print(list(ts.static_order()))
                """
            env = os.environ.copy()
            # signal to assert_python not to do a copy
            # of os.environ on its own
            env['__cleanenv'] = True
            env['PYTHONHASHSEED'] = str(seed)
            # The helper returns (rc, out, err); only stdout holds the
            # printed order.
            rc, out, err = assert_python_ok('-c', code, **env)
            return out

        run1 = check_order_with_hash_seed(1234)
        run2 = check_order_with_hash_seed(31415)

        # stdout is compared with an empty bytes literal so that the
        # non-emptiness checks can actually fail.
        self.assertNotEqual(run1, b"")
        self.assertNotEqual(run2, b"")
        self.assertEqual(run1, run2)
1433
1434
Raymond Hettinger21cdb712020-05-11 17:00:53 -07001435class TestCache:
1436 # This tests that the pass-through is working as designed.
1437 # The underlying functionality is tested in TestLRU.
1438
1439 def test_cache(self):
1440 @self.module.cache
1441 def fib(n):
1442 if n < 2:
1443 return n
1444 return fib(n-1) + fib(n-2)
1445 self.assertEqual([fib(n) for n in range(16)],
1446 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1447 self.assertEqual(fib.cache_info(),
1448 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1449 fib.cache_clear()
1450 self.assertEqual(fib.cache_info(),
1451 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1452
1453
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001454class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001455
1456 def test_lru(self):
1457 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001458 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001459 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001460 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001461 self.assertEqual(maxsize, 20)
1462 self.assertEqual(currsize, 0)
1463 self.assertEqual(hits, 0)
1464 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001465
1466 domain = range(5)
1467 for i in range(1000):
1468 x, y = choice(domain), choice(domain)
1469 actual = f(x, y)
1470 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001471 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001472 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001473 self.assertTrue(hits > misses)
1474 self.assertEqual(hits + misses, 1000)
1475 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001476
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001477 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001478 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001479 self.assertEqual(hits, 0)
1480 self.assertEqual(misses, 0)
1481 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001482 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001483 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001484 self.assertEqual(hits, 0)
1485 self.assertEqual(misses, 1)
1486 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001487
Nick Coghlan98876832010-08-17 06:17:18 +00001488 # Test bypassing the cache
1489 self.assertIs(f.__wrapped__, orig)
1490 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001491 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001492 self.assertEqual(hits, 0)
1493 self.assertEqual(misses, 1)
1494 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001495
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001496 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001497 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001498 def f():
1499 nonlocal f_cnt
1500 f_cnt += 1
1501 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001502 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001503 f_cnt = 0
1504 for i in range(5):
1505 self.assertEqual(f(), 20)
1506 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001507 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001508 self.assertEqual(hits, 0)
1509 self.assertEqual(misses, 5)
1510 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001511
1512 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001513 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001514 def f():
1515 nonlocal f_cnt
1516 f_cnt += 1
1517 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001518 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001519 f_cnt = 0
1520 for i in range(5):
1521 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001522 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001523 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001524 self.assertEqual(hits, 4)
1525 self.assertEqual(misses, 1)
1526 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001527
Raymond Hettingerf3098282010-08-15 03:30:45 +00001528 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001529 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001530 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001531 nonlocal f_cnt
1532 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001533 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001534 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001535 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001536 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1537 # * * * *
1538 self.assertEqual(f(x), x*10)
1539 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001540 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001541 self.assertEqual(hits, 12)
1542 self.assertEqual(misses, 4)
1543 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001544
Raymond Hettingerb8218682019-05-26 11:27:35 -07001545 def test_lru_no_args(self):
1546 @self.module.lru_cache
1547 def square(x):
1548 return x ** 2
1549
1550 self.assertEqual(list(map(square, [10, 20, 10])),
1551 [100, 400, 100])
1552 self.assertEqual(square.cache_info().hits, 1)
1553 self.assertEqual(square.cache_info().misses, 2)
1554 self.assertEqual(square.cache_info().maxsize, 128)
1555 self.assertEqual(square.cache_info().currsize, 2)
1556
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001557 def test_lru_bug_35780(self):
1558 # C version of the lru_cache was not checking to see if
1559 # the user function call has already modified the cache
1560 # (this arises in recursive calls and in multi-threading).
1561 # This cause the cache to have orphan links not referenced
1562 # by the cache dictionary.
1563
1564 once = True # Modified by f(x) below
1565
1566 @self.module.lru_cache(maxsize=10)
1567 def f(x):
1568 nonlocal once
1569 rv = f'.{x}.'
1570 if x == 20 and once:
1571 once = False
1572 rv = f(x)
1573 return rv
1574
1575 # Fill the cache
1576 for x in range(15):
1577 self.assertEqual(f(x), f'.{x}.')
1578 self.assertEqual(f.cache_info().currsize, 10)
1579
1580 # Make a recursive call and make sure the cache remains full
1581 self.assertEqual(f(20), '.20.')
1582 self.assertEqual(f.cache_info().currsize, 10)
1583
Raymond Hettinger14adbd42019-04-20 07:20:44 -10001584 def test_lru_bug_36650(self):
1585 # C version of lru_cache was treating a call with an empty **kwargs
1586 # dictionary as being distinct from a call with no keywords at all.
1587 # This did not result in an incorrect answer, but it did trigger
1588 # an unexpected cache miss.
1589
1590 @self.module.lru_cache()
1591 def f(x):
1592 pass
1593
1594 f(0)
1595 f(0, **{})
1596 self.assertEqual(f.cache_info().hits, 1)
1597
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001598 def test_lru_hash_only_once(self):
1599 # To protect against weird reentrancy bugs and to improve
1600 # efficiency when faced with slow __hash__ methods, the
1601 # LRU cache guarantees that it will only call __hash__
1602 # only once per use as an argument to the cached function.
1603
1604 @self.module.lru_cache(maxsize=1)
1605 def f(x, y):
1606 return x * 3 + y
1607
1608 # Simulate the integer 5
1609 mock_int = unittest.mock.Mock()
1610 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1611 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1612
1613 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001614 self.assertEqual(f(mock_int, 1), 16)
1615 self.assertEqual(mock_int.__hash__.call_count, 1)
1616 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001617
1618 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001619 self.assertEqual(f(mock_int, 1), 16)
1620 self.assertEqual(mock_int.__hash__.call_count, 2)
1621 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001622
Ville Skyttä49b27342017-08-03 09:00:59 +03001623 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001624 self.assertEqual(f(6, 2), 20)
1625 self.assertEqual(mock_int.__hash__.call_count, 2)
1626 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001627
1628 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001629 self.assertEqual(f(mock_int, 1), 16)
1630 self.assertEqual(mock_int.__hash__.call_count, 3)
1631 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001632
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001633 def test_lru_reentrancy_with_len(self):
1634 # Test to make sure the LRU cache code isn't thrown-off by
1635 # caching the built-in len() function. Since len() can be
1636 # cached, we shouldn't use it inside the lru code itself.
1637 old_len = builtins.len
1638 try:
1639 builtins.len = self.module.lru_cache(4)(len)
1640 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1641 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1642 finally:
1643 builtins.len = old_len
1644
Raymond Hettinger605a4472017-01-09 07:50:19 -08001645 def test_lru_star_arg_handling(self):
1646 # Test regression that arose in ea064ff3c10f
1647 @functools.lru_cache()
1648 def f(*args):
1649 return args
1650
1651 self.assertEqual(f(1, 2), (1, 2))
1652 self.assertEqual(f((1, 2)), ((1, 2),))
1653
Yury Selivanov46a02db2016-11-09 18:55:45 -05001654 def test_lru_type_error(self):
1655 # Regression test for issue #28653.
1656 # lru_cache was leaking when one of the arguments
1657 # wasn't cacheable.
1658
1659 @functools.lru_cache(maxsize=None)
1660 def infinite_cache(o):
1661 pass
1662
1663 @functools.lru_cache(maxsize=10)
1664 def limited_cache(o):
1665 pass
1666
1667 with self.assertRaises(TypeError):
1668 infinite_cache([])
1669
1670 with self.assertRaises(TypeError):
1671 limited_cache([])
1672
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001673 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001674 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001675 def fib(n):
1676 if n < 2:
1677 return n
1678 return fib(n-1) + fib(n-2)
1679 self.assertEqual([fib(n) for n in range(16)],
1680 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1681 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001682 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001683 fib.cache_clear()
1684 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001685 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1686
1687 def test_lru_with_maxsize_negative(self):
1688 @self.module.lru_cache(maxsize=-10)
1689 def eq(n):
1690 return n
1691 for i in (0, 1):
1692 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1693 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001694 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001695
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001696 def test_lru_with_exceptions(self):
1697 # Verify that user_function exceptions get passed through without
1698 # creating a hard-to-read chained exception.
1699 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001700 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001701 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001702 def func(i):
1703 return 'abc'[i]
1704 self.assertEqual(func(0), 'a')
1705 with self.assertRaises(IndexError) as cm:
1706 func(15)
1707 self.assertIsNone(cm.exception.__context__)
1708 # Verify that the previous exception did not result in a cached entry
1709 with self.assertRaises(IndexError):
1710 func(15)
1711
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001712 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001713 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001714 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001715 def square(x):
1716 return x * x
1717 self.assertEqual(square(3), 9)
1718 self.assertEqual(type(square(3)), type(9))
1719 self.assertEqual(square(3.0), 9.0)
1720 self.assertEqual(type(square(3.0)), type(9.0))
1721 self.assertEqual(square(x=3), 9)
1722 self.assertEqual(type(square(x=3)), type(9))
1723 self.assertEqual(square(x=3.0), 9.0)
1724 self.assertEqual(type(square(x=3.0)), type(9.0))
1725 self.assertEqual(square.cache_info().hits, 4)
1726 self.assertEqual(square.cache_info().misses, 4)
1727
Antoine Pitroub5b37142012-11-13 21:35:40 +01001728 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001729 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001730 def fib(n):
1731 if n < 2:
1732 return n
1733 return fib(n=n-1) + fib(n=n-2)
1734 self.assertEqual(
1735 [fib(n=number) for number in range(16)],
1736 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1737 )
1738 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001739 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001740 fib.cache_clear()
1741 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001742 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001743
1744 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001745 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001746 def fib(n):
1747 if n < 2:
1748 return n
1749 return fib(n=n-1) + fib(n=n-2)
1750 self.assertEqual([fib(n=number) for number in range(16)],
1751 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1752 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001753 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001754 fib.cache_clear()
1755 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001756 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1757
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001758 def test_kwargs_order(self):
1759 # PEP 468: Preserving Keyword Argument Order
1760 @self.module.lru_cache(maxsize=10)
1761 def f(**kwargs):
1762 return list(kwargs.items())
1763 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1764 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1765 self.assertEqual(f.cache_info(),
1766 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1767
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001768 def test_lru_cache_decoration(self):
1769 def f(zomg: 'zomg_annotation'):
1770 """f doc string"""
1771 return 42
1772 g = self.module.lru_cache()(f)
1773 for attr in self.module.WRAPPER_ASSIGNMENTS:
1774 self.assertEqual(getattr(g, attr), getattr(f, attr))
1775
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001776 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001777 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001778 def orig(x, y):
1779 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001780 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001781 hits, misses, maxsize, currsize = f.cache_info()
1782 self.assertEqual(currsize, 0)
1783
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001784 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001785 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001786 start.wait(10)
1787 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001788 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001789
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001790 def clear():
1791 start.wait(10)
1792 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001793 f.cache_clear()
1794
1795 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001796 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001797 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001798 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001799 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001800 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001801 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001802 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001803
1804 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001805 if self.module is py_functools:
1806 # XXX: Why can be not equal?
1807 self.assertLessEqual(misses, n)
1808 self.assertLessEqual(hits, m*n - misses)
1809 else:
1810 self.assertEqual(misses, n)
1811 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001812 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001813
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001814 # create n threads in order to fill cache and 1 to clear it
1815 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001816 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001817 for k in range(n)]
1818 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001819 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001820 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001821 finally:
1822 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001823
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001824 def test_lru_cache_threaded2(self):
1825 # Simultaneous call with the same arguments
1826 n, m = 5, 7
1827 start = threading.Barrier(n+1)
1828 pause = threading.Barrier(n+1)
1829 stop = threading.Barrier(n+1)
1830 @self.module.lru_cache(maxsize=m*n)
1831 def f(x):
1832 pause.wait(10)
1833 return 3 * x
1834 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1835 def test():
1836 for i in range(m):
1837 start.wait(10)
1838 self.assertEqual(f(i), 3 * i)
1839 stop.wait(10)
1840 threads = [threading.Thread(target=test) for k in range(n)]
1841 with support.start_threads(threads):
1842 for i in range(m):
1843 start.wait(10)
1844 stop.reset()
1845 pause.wait(10)
1846 start.reset()
1847 stop.wait(10)
1848 pause.reset()
1849 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1850
Serhiy Storchaka67796522017-01-12 18:34:33 +02001851 def test_lru_cache_threaded3(self):
1852 @self.module.lru_cache(maxsize=2)
1853 def f(x):
1854 time.sleep(.01)
1855 return 3 * x
1856 def test(i, x):
1857 with self.subTest(thread=i):
1858 self.assertEqual(f(x), 3 * x, i)
1859 threads = [threading.Thread(target=test, args=(i, v))
1860 for i, v in enumerate([1, 2, 2, 3, 2])]
1861 with support.start_threads(threads):
1862 pass
1863
Raymond Hettinger03923422013-03-04 02:52:50 -05001864 def test_need_for_rlock(self):
1865 # This will deadlock on an LRU cache that uses a regular lock
1866
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001867 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001868 def test_func(x):
1869 'Used to demonstrate a reentrant lru_cache call within a single thread'
1870 return x
1871
1872 class DoubleEq:
1873 'Demonstrate a reentrant lru_cache call within a single thread'
1874 def __init__(self, x):
1875 self.x = x
1876 def __hash__(self):
1877 return self.x
1878 def __eq__(self, other):
1879 if self.x == 2:
1880 test_func(DoubleEq(1))
1881 return self.x == other.x
1882
1883 test_func(DoubleEq(1)) # Load the cache
1884 test_func(DoubleEq(2)) # Load the cache
1885 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1886 DoubleEq(2)) # Verify the correct return value
1887
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001888 def test_lru_method(self):
1889 class X(int):
1890 f_cnt = 0
1891 @self.module.lru_cache(2)
1892 def f(self, x):
1893 self.f_cnt += 1
1894 return x*10+self
1895 a = X(5)
1896 b = X(5)
1897 c = X(7)
1898 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1899
1900 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1901 self.assertEqual(a.f(x), x*10 + 5)
1902 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1903 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1904
1905 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1906 self.assertEqual(b.f(x), x*10 + 5)
1907 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1908 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1909
1910 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1911 self.assertEqual(c.f(x), x*10 + 7)
1912 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1913 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1914
1915 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1916 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1917 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1918
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001919 def test_pickle(self):
1920 cls = self.__class__
1921 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1922 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1923 with self.subTest(proto=proto, func=f):
1924 f_copy = pickle.loads(pickle.dumps(f, proto))
1925 self.assertIs(f_copy, f)
1926
1927 def test_copy(self):
1928 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001929 def orig(x, y):
1930 return 3 * x + y
1931 part = self.module.partial(orig, 2)
1932 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1933 self.module.lru_cache(2)(part))
1934 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001935 with self.subTest(func=f):
1936 f_copy = copy.copy(f)
1937 self.assertIs(f_copy, f)
1938
1939 def test_deepcopy(self):
1940 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001941 def orig(x, y):
1942 return 3 * x + y
1943 part = self.module.partial(orig, 2)
1944 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1945 self.module.lru_cache(2)(part))
1946 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001947 with self.subTest(func=f):
1948 f_copy = copy.deepcopy(f)
1949 self.assertIs(f_copy, f)
1950
Manjusaka051ff522019-11-12 15:30:18 +08001951 def test_lru_cache_parameters(self):
1952 @self.module.lru_cache(maxsize=2)
1953 def f():
1954 return 1
1955 self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})
1956
1957 @self.module.lru_cache(maxsize=1000, typed=True)
1958 def f():
1959 return 1
1960 self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})
1961
Dennis Sweeney1253c3e2020-05-05 17:14:32 -04001962 def test_lru_cache_weakrefable(self):
1963 @self.module.lru_cache
1964 def test_function(x):
1965 return x
1966
1967 class A:
1968 @self.module.lru_cache
1969 def test_method(self, x):
1970 return (self, x)
1971
1972 @staticmethod
1973 @self.module.lru_cache
1974 def test_staticmethod(x):
1975 return (self, x)
1976
1977 refs = [weakref.ref(test_function),
1978 weakref.ref(A.test_method),
1979 weakref.ref(A.test_staticmethod)]
1980
1981 for ref in refs:
1982 self.assertIsNotNone(ref())
1983
1984 del A
1985 del test_function
1986 gc.collect()
1987
1988 for ref in refs:
1989 self.assertIsNone(ref())
1990
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001991
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level cached function built with py_functools; TestLRUPy uses
    # it as its cached_func fixture.
    scaled = 3 * x
    return scaled + y
1995
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level cached function built with c_functools; TestLRUC uses it
    # as its cached_func fixture.
    scaled = 3 * x
    return scaled + y
1999
Serhiy Storchaka46c56112015-05-24 21:53:49 +03002000
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the TestLRU suite against py_functools.
    module = py_functools
    # Stored in a 1-tuple and read back as cached_func[0] by the
    # pickle/copy tests.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
2013
2014
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the TestLRU suite against c_functools.
    module = c_functools
    # Stored in a 1-tuple and read back as cached_func[0] by the
    # pickle/copy tests.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        scaled = 3 * x
        return scaled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        scaled = 3 * x
        return scaled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03002027
Raymond Hettinger03923422013-03-04 02:52:50 -05002028
Łukasz Langa6f692512013-06-05 12:20:24 +02002029class TestSingleDispatch(unittest.TestCase):
2030 def test_simple_overloads(self):
2031 @functools.singledispatch
2032 def g(obj):
2033 return "base"
2034 def g_int(i):
2035 return "integer"
2036 g.register(int, g_int)
2037 self.assertEqual(g("str"), "base")
2038 self.assertEqual(g(1), "integer")
2039 self.assertEqual(g([1,2,3]), "base")
2040
2041 def test_mro(self):
2042 @functools.singledispatch
2043 def g(obj):
2044 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002045 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02002046 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002047 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002048 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002049 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002050 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002051 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02002052 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002053 def g_A(a):
2054 return "A"
2055 def g_B(b):
2056 return "B"
2057 g.register(A, g_A)
2058 g.register(B, g_B)
2059 self.assertEqual(g(A()), "A")
2060 self.assertEqual(g(B()), "B")
2061 self.assertEqual(g(C()), "A")
2062 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02002063
2064 def test_register_decorator(self):
2065 @functools.singledispatch
2066 def g(obj):
2067 return "base"
2068 @g.register(int)
2069 def g_int(i):
2070 return "int %s" % (i,)
2071 self.assertEqual(g(""), "base")
2072 self.assertEqual(g(12), "int 12")
2073 self.assertIs(g.dispatch(int), g_int)
2074 self.assertIs(g.dispatch(object), g.dispatch(str))
2075 # Note: in the assert above this is not g.
2076 # @singledispatch returns the wrapper.
2077
2078 def test_wrapping_attributes(self):
2079 @functools.singledispatch
2080 def g(obj):
2081 "Simple test"
2082 return "Test"
2083 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02002084 if sys.flags.optimize < 2:
2085 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02002086
2087 @unittest.skipUnless(decimal, 'requires _decimal')
2088 @support.cpython_only
2089 def test_c_classes(self):
2090 @functools.singledispatch
2091 def g(obj):
2092 return "base"
2093 @g.register(decimal.DecimalException)
2094 def _(obj):
2095 return obj.args
2096 subn = decimal.Subnormal("Exponent < Emin")
2097 rnd = decimal.Rounded("Number got rounded")
2098 self.assertEqual(g(subn), ("Exponent < Emin",))
2099 self.assertEqual(g(rnd), ("Number got rounded",))
2100 @g.register(decimal.Subnormal)
2101 def _(obj):
2102 return "Too small to care."
2103 self.assertEqual(g(subn), "Too small to care.")
2104 self.assertEqual(g(rnd), ("Number got rounded",))
2105
2106 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02002107 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002108 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002109 mro = functools._compose_mro
2110 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
2111 for haystack in permutations(bases):
2112 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07002113 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
2114 c.Collection, c.Sized, c.Iterable,
2115 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002116 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02002117 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002118 m = mro(collections.ChainMap, haystack)
2119 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002120 c.Collection, c.Sized, c.Iterable,
2121 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002122
2123 # If there's a generic function with implementations registered for
2124 # both Sized and Container, passing a defaultdict to it results in an
2125 # ambiguous dispatch which will cause a RuntimeError (see
2126 # test_mro_conflicts).
2127 bases = [c.Container, c.Sized, str]
2128 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002129 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
2130 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
2131 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002132
2133 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00002134 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02002135 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002136 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002137 pass
2138 c.MutableSequence.register(D)
2139 bases = [c.MutableSequence, c.MutableMapping]
2140 for haystack in permutations(bases):
2141 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07002142 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002143 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002144 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02002145 object])
2146
2147 # Container and Callable are registered on different base classes and
2148 # a generic function supporting both should always pick the Callable
2149 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002150 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002151 def __call__(self):
2152 pass
2153 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
2154 for haystack in permutations(bases):
2155 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002156 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002157 c.Collection, c.Sized, c.Iterable,
2158 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02002159
def test_register_abc(self):
    # Register progressively more specific ABCs and concrete classes and,
    # after every registration, check which implementation each sample
    # object resolves to.
    cabc = collections.abc
    a_set = {object(), None}
    # Sample order: dict, list, set, frozenset, tuple.
    samples = ({"a": "b"}, [1, 2, 3], a_set, frozenset(a_set), (1, 2, 3))

    @functools.singledispatch
    def g(obj):
        return "base"

    def constant(label):
        # Build an implementation that ignores its argument.
        return lambda obj: label

    def check(*expected):
        for sample, label in zip(samples, expected):
            self.assertEqual(g(sample), label)

    check("base", "base", "base", "base", "base")

    # (class to register, label it returns, expected label per sample)
    steps = (
        (cabc.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (cabc.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # irrelevant ABCs registered
        (collections.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (cabc.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (cabc.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        # not specific enough
        (cabc.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        (cabc.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sequence")),
        (cabc.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set",
          "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    )
    for cls, label, expected in steps:
        g.register(cls, constant(label))
        check(*expected)
2253
def test_c3_abc(self):
    # _c3_mro() must produce the same linearization whatever order the
    # candidate ABCs are supplied in, and must leave out unrelated ABCs.
    cabc = collections.abc
    c3_mro = functools._c3_mro

    class A:
        pass

    class B(A):
        def __len__(self):  # implies Sized
            return 0

    class C:
        pass
    cabc.Container.register(C)

    class D:  # unrelated
        pass

    class X(D, C, B):
        def __call__(self):  # implies Callable
            pass

    expected = [X, cabc.Callable, D, C, cabc.Container, B, cabc.Sized, A,
                object]
    relevant = [cabc.Sized, cabc.Callable, cabc.Container]
    for ordering in permutations(relevant):
        self.assertEqual(c3_mro(X, abcs=ordering), expected)
    # unrelated ABCs don't appear in the resulting MRO
    superset = [cabc.Mapping, *relevant, cabc.Iterable]
    self.assertEqual(c3_mro(X, abcs=superset), expected)
2276
def test_false_meta(self):
    # see issue23572
    class MetaA(type):
        # Every class built by this metaclass reports a length of zero,
        # so the class objects themselves evaluate as false.
        def __len__(self):
            return 0

    class A(metaclass=MetaA):
        pass

    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    fun.register(A, lambda a: 'fun A')

    # Dispatch on a subclass instance must still find A's implementation.
    self.assertEqual(fun(AA()), 'fun A')
2294
def test_mro_conflicts(self):
    # Exercises how dispatch ranks ABCs that are explicit bases, registered
    # virtual bases or merely inferred, and when it reports an ambiguity.
    cabc = collections.abc
    container_or_sized = (
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Sized'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
         "or <class 'collections.abc.Container'>"),
    )

    def assert_ambiguous(dispatcher, arg, messages):
        # Either ordering of the two candidates in the message is accepted.
        with self.assertRaises(RuntimeError) as caught:
            dispatcher(arg)
        self.assertIn(str(caught.exception), messages)

    @functools.singledispatch
    def g(arg):
        return "base"

    class O(cabc.Sized):
        def __len__(self):
            return 0

    o = O()
    self.assertEqual(g(o), "base")

    @g.register(cabc.Iterable)
    def _(arg):
        return "iterable"

    @g.register(cabc.Container)
    def _(arg):
        return "container"

    @g.register(cabc.Sized)
    def _(arg):
        return "sized"

    @g.register(cabc.Set)
    def _(arg):
        return "set"

    self.assertEqual(g(o), "sized")
    cabc.Iterable.register(O)
    # because it's explicitly in __mro__
    self.assertEqual(g(o), "sized")
    cabc.Container.register(O)
    # see above: Sized is in __mro__
    self.assertEqual(g(o), "sized")
    cabc.Set.register(O)
    # because c.Set is a subclass of c.Sized and c.Container
    self.assertEqual(g(o), "set")

    class P:
        pass

    p = P()
    self.assertEqual(g(p), "base")
    cabc.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    cabc.Container.register(P)
    assert_ambiguous(g, p, (
        ("Ambiguous dispatch: <class 'collections.abc.Container'> "
         "or <class 'collections.abc.Iterable'>"),
        ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
         "or <class 'collections.abc.Container'>"),
    ))

    class Q(cabc.Sized):
        def __len__(self):
            return 0

    q = Q()
    self.assertEqual(g(q), "sized")
    cabc.Iterable.register(Q)
    # because it's explicitly in __mro__
    self.assertEqual(g(q), "sized")
    cabc.Set.register(Q)
    # because c.Set is a subclass of c.Sized and c.Iterable
    self.assertEqual(g(q), "set")

    @functools.singledispatch
    def h(arg):
        return "base"

    h.register(cabc.Sized, lambda arg: "sized")
    h.register(cabc.Container, lambda arg: "container")
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    assert_ambiguous(h, collections.defaultdict(lambda: 0),
                     container_or_sized)

    class R(collections.defaultdict):
        pass

    cabc.MutableSequence.register(R)

    @functools.singledispatch
    def i(arg):
        return "base"

    i.register(cabc.MutableMapping, lambda arg: "mapping")
    i.register(cabc.MutableSequence, lambda arg: "sequence")
    self.assertEqual(i(R()), "sequence")

    class S:
        pass

    class T(S, cabc.Sized):
        def __len__(self):
            return 0

    t = T()
    self.assertEqual(h(t), "sized")
    cabc.Container.register(T)
    # because it's explicitly in the MRO
    self.assertEqual(h(t), "sized")

    class U:
        def __len__(self):
            return 0

    u = U()
    # implicit Sized subclass inferred from the existence of __len__()
    self.assertEqual(h(u), "sized")
    cabc.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    assert_ambiguous(h, u, container_or_sized)

    class V(cabc.Sized, S):
        def __len__(self):
            return 0

    @functools.singledispatch
    def j(arg):
        return "base"

    j.register(S, lambda arg: "s")
    j.register(cabc.Container, lambda arg: "container")
    v = V()
    self.assertEqual(j(v), "s")
    cabc.Container.register(V)
    # because it ends up right after Sized in the MRO
    self.assertEqual(j(v), "container")
Łukasz Langa6f692512013-06-05 12:20:24 +02002422
def test_cache_invalidation(self):
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # Records the key of every item read and item write made through
        # the mapping protocol, so the test can inspect cache traffic.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []

        def __getitem__(self, key):
            value = self.data[key]
            self.get_ops.append(key)
            return value

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            self.data.clear()

    td = TracingDict()

    def check_ops(gets, sets):
        self.assertEqual(td.get_ops, gets)
        self.assertEqual(td.set_ops, sets)

    # Make singledispatch use the tracing dict as its dispatch cache.
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        cabc = collections.abc

        @functools.singledispatch
        def g(arg):
            return "base"

        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        check_ops([], [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        check_ops([], [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        check_ops([list, dict], [dict, list])

        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        check_ops([list, dict], [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        check_ops([list, dict], [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))

        class X:
            pass

        # Will not invalidate the cache, not using ABCs yet.
        cabc.MutableMapping.register(X)
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        check_ops([list, dict, dict, list], [dict, list, dict, list])

        g.register(cabc.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        check_ops([list, dict, dict, list],
                  [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        check_ops([list, dict, dict, list],
                  [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        check_ops([list, dict, dict, list, list, dict],
                  [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        check_ops([list, dict, dict, list, list, dict, list, dict],
                  [dict, list, dict, list, dict, list])

        cabc.MutableSet.register(X)  # Will invalidate the cache.
        self.assertEqual(len(td), 2)  # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)

        g.register(cabc.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)

        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002524
def test_annotations(self):
    # A bare @register takes the dispatch type from the annotation of the
    # first parameter, given either as an object or as a string.
    @functools.singledispatch
    def i(arg):
        return "base"

    @i.register
    def on_mapping(arg: collections.abc.Mapping):
        return "mapping"

    @i.register
    def on_sequence(arg: "collections.abc.Sequence"):
        return "sequence"

    cases = (
        (None, "base"),
        ({"a": 1}, "mapping"),
        ([1, 2, 3], "sequence"),
        ((1, 2, 3), "sequence"),
        ("str", "sequence"),
    )
    for value, expected in cases:
        self.assertEqual(i(value), expected)

    # Registering classes as callables doesn't work with annotations,
    # you need to pass the type explicitly.
    @i.register(str)
    class StrBox:
        def __init__(self, arg):
            self.arg = arg

        def __eq__(self, other):
            return self.arg == other

    self.assertEqual(i("str"), "str")
2551
def test_method_register(self):
    # singledispatchmethod on an ordinary method dispatches on the first
    # argument after self and acts on the bound instance only.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            self.arg = "base"

        @t.register(int)
        def _on_int(self, arg):
            self.arg = "int"

        @t.register(str)
        def _on_str(self, arg):
            self.arg = "str"

    a = A()
    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        a.t(value)
        self.assertEqual(a.arg, expected)
        # A brand-new instance is unaffected by calls made on ``a``.
        aa = A()
        self.assertFalse(hasattr(aa, 'arg'))
2577
def test_staticmethod_register(self):
    # singledispatchmethod stacked on staticmethod: dispatch happens on the
    # sole argument and the method is called through the class.
    class A:
        @functools.singledispatchmethod
        @staticmethod
        def t(arg):
            return arg

        @t.register(int)
        @staticmethod
        def _(arg):
            return isinstance(arg, int)

        @t.register(str)
        @staticmethod
        def _(arg):
            return isinstance(arg, str)

    # The int/str overloads return True for their own type; anything else
    # falls through to the base implementation, which echoes its argument.
    self.assertTrue(A.t(0))
    self.assertTrue(A.t(''))
    self.assertEqual(A.t(0.0), 0.0)
2597
def test_classmethod_register(self):
    # singledispatchmethod stacked on classmethod: every overload receives
    # the class and builds an instance tagged with the overload chosen.
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

        @t.register(int)
        @classmethod
        def _from_int(cls, arg):
            return cls("int")

        @t.register(str)
        @classmethod
        def _from_str(cls, arg):
            return cls("str")

    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(A.t(value).arg, expected)
2619
def test_callable_register(self):
    # Overloads can also be registered after the class body has finished,
    # through the ``register`` attribute of the class-level accessor.
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

    def from_int(cls, arg):
        return cls("int")

    def from_str(cls, arg):
        return cls("str")

    A.t.register(int, classmethod(from_int))
    A.t.register(str, classmethod(from_str))

    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(A.t(value).arg, expected)
2642
def test_abstractmethod_register(self):
    # singledispatchmethod must expose __isabstractmethod__ from the
    # wrapped function so that ABCMeta treats the method as abstract.
    #
    # Abstract uses ABCMeta as its *metaclass*; deriving from ABCMeta
    # instead would create a metaclass and never exercise the abstract
    # method machinery.
    class Abstract(metaclass=abc.ABCMeta):

        @functools.singledispatchmethod
        @abc.abstractmethod
        def add(self, x, y):
            pass

    self.assertTrue(Abstract.add.__isabstractmethod__)
    self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)

    # With an abstract method left unimplemented, instantiation must fail.
    with self.assertRaises(TypeError):
        Abstract()
2652
def test_type_ann_register(self):
    # A bare @register on a singledispatchmethod reads the dispatch type
    # from the annotation on the first parameter after self.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            return "base"

        @t.register
        def _on_int(self, arg: int):
            return "int"

        @t.register
        def _on_str(self, arg: str):
            return "str"

    a = A()
    for value, expected in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(a.t(value), expected)
2669
def test_invalid_registrations(self):
    # register() must raise TypeError, with a descriptive message, for a
    # non-class argument, an unannotated function, and an annotation that
    # is not a class.
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    @functools.singledispatch
    def i(arg):
        return "base"

    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    message = str(exc.exception)
    self.assertTrue(message.startswith(msg_prefix + "42"))
    self.assertTrue(message.endswith(msg_suffix))

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    message = str(exc.exception)
    self.assertTrue(message.startswith(
        msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(message.endswith(msg_suffix))

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    message = str(exc.exception)
    self.assertTrue(message.startswith("Invalid annotation for 'arg'."))
    self.assertTrue(message.endswith('typing.Iterable[str] is not a class.'))
Łukasz Langae5697532017-12-11 13:56:31 -08002708
def test_invalid_positional_argument(self):
    # Calling a generic function with no positional argument gives it
    # nothing to dispatch on and must raise a TypeError naming the function.
    def f(*args):
        pass

    generic = functools.singledispatch(f)
    with self.assertRaisesRegex(
            TypeError, 'f requires at least 1 positional argument'):
        generic()
Łukasz Langa6f692512013-06-05 12:20:24 +02002716
Carl Meyerd658dea2018-08-28 01:11:56 -06002717
class CachedCostItem:
    # Fixture: ``cost`` bumps ``_cost`` every time its body runs, so the
    # tests can tell whether the value was recomputed or served from cache.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2730
2731
class OptionallyCachedCostItem:
    # Fixture: the same computation is reachable uncached (``get_cost``)
    # and cached under a different attribute name (``cached_cost``).
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2741
2742
class CachedCostItemWait:
    # Fixture: ``cost`` waits on the supplied event (for up to one second)
    # before computing, so several threads can be released together.

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
        return self._cost
2756
2757
class CachedCostItemWithSlots:
    # Fixture: an object without an instance __dict__, used to check that
    # cached_property refuses to operate on it.
    #
    # The trailing comma matters: ('_cost') is just the string '_cost',
    # not a one-element tuple.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2767
2768
class TestCachedProperty(unittest.TestCase):
    # Tests for the pure-Python cached_property, using the fixture classes
    # defined earlier in this module.

    def test_cached(self):
        item = CachedCostItem()
        first = item.cost
        second = item.cost
        self.assertEqual(first, 2)
        self.assertEqual(second, 2)  # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # The plain method recomputes each time; the cached attribute keeps
        # the value from its first access.
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        def read_cost():
            return item.cost

        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            threads = [threading.Thread(target=read_cost)
                       for _ in range(num_threads)]
            with support.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        # However many threads raced, the value was computed only once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = (
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        )
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = (
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        )
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError raised while naming ``b`` is the RuntimeError's context.
        self.assertEqual(
            str(ctx.exception.__context__),
            "Cannot assign the same cached_property to two different names ('a' and 'b')."
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        def bump(_self):
            nonlocal counter
            counter += 1
            return counter

        _cp = py_functools.cached_property(bump)

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        class Foo:
            pass

        # Attaching after class creation bypasses __set_name__.
        Foo.cp = py_functools.cached_property(lambda s: None)

        expected = (
            "Cannot use cached_property instance without calling __set_name__ on it."
        )
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2881
2882
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()