blob: b3893a15566fa67513c02c8c832ce7de6a7c74fe [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Pablo Galindo99e6c262020-01-23 15:29:52 +00006from itertools import permutations, chain
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Pablo Galindo99e6c262020-01-23 15:29:52 +000016import os
Dennis Sweeney1253c3e2020-05-05 17:14:32 -040017import weakref
18import gc
Łukasz Langa6f692512013-06-05 12:20:24 +020019from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100020import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000021
Pablo Galindo99e6c262020-01-23 15:29:52 +000022from test.support.script_helper import assert_python_ok
23
Antoine Pitroub5b37142012-11-13 21:35:40 +010024import functools
25
# Two independently imported copies of functools: one with the _functools
# C accelerator blocked (pure-Python implementation) and one that freshly
# imports the accelerator.  c_functools is tested for truthiness further
# down (skipUnless / `if c_functools:`), so it is presumably None when the
# C module is unavailable.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# A fresh import of decimal that also re-imports the _decimal accelerator.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
30
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The module registered under *name* before entry is put back on exit,
    including when the body of the ``with`` statement raises.  *name* must
    already be present in ``sys.modules``; otherwise KeyError is raised
    before anything is changed.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the swap.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def capture(*args, **kw):
    """Echo the call back as ``(args, kw)``.

    Used as the target callable in the partial tests so the exact
    positional and keyword arguments that reached it can be inspected.
    """
    packed = (args, kw)
    return packed
43
Łukasz Langa6f692512013-06-05 12:20:24 +020044
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``.

    Two partials with equal signatures wrap the same callable with the
    same frozen arguments and carry the same extra instance attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000048
class MyTuple(tuple):
    """Plain tuple subclass with no overrides.

    Used to check that partial stores such arguments as exact tuples.
    """
51
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by producing a list.

    Used to check that partial does not depend on the stored argument
    object's own ``__add__`` when combining arguments.
    """

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
55
class MyDict(dict):
    """Plain dict subclass with no overrides.

    Used to check that partial stores such keywords as exact dicts.
    """
58
Łukasz Langa6f692512013-06-05 12:20:24 +020059
class TestPartial:
    """Implementation-agnostic tests for functools.partial.

    Concrete subclasses mix this class into unittest.TestCase and provide:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the tests
      that pickle partial objects.
    """

    def _partial_repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``.

        Both stock implementations report themselves as 'functools.partial';
        subclasses report their own ``__name__``.
        """
        # c_functools is falsy when the C accelerator is unavailable, so its
        # attribute must only be touched after checking it.
        stock_types = [py_functools.partial]
        if c_functools:
            stock_types.append(c_functools.partial)
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                self.assertEqual(p('x'), (args + ('x',), {}))

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                self.assertEqual(p(x=None), ((), {'a': a, 'x': None}))

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned down, accept either
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._partial_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._partial_repr_name()

        # Each case makes f refer to itself, then breaks the cycle in
        # `finally` so the object can be collected.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): signature check for this state is disabled upstream;
        # the reason is not visible here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000383
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation, plus
    checks specific to the C type."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: unlike TestPartialPy, nothing needs to be
        swapped in ``sys.modules`` around pickling for this type."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
434
435
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` point at py_functools while active.

        Pickle records classes by module and qualified name, so the module
        registered as 'functools' has to be the one defining this partial
        type while it is being (un)pickled.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200446
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type; only defined when the
        C accelerator could be imported."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100450
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200453
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a trivial subclass of it."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls,
    # so the inherited test is disabled here.
    test_nested_optimization = None
461
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200464
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)
        spec_keywords = functools.partialmethod(capture, self=1, func=2)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod()
        with self.assertRaises(TypeError):
            class B:
                method = functools.partialmethod(func=capture, a=1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be *created by* ABCMeta (metaclass=) rather than be a
        # subclass of it; only then does class creation collect the
        # namespace's abstract members into __abstractmethods__.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # the partialmethod wrapping an abstract method is itself abstract
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))

    def test_positional_only(self):
        def f(a, b, /):
            return a + b

        p = functools.partial(f, 1)
        self.assertEqual(p(2), f(1, 2))
593
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000594
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    TestWraps derives from this class, so its helpers are shared with the
    functools.wraps tests.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both.  Every mapping named in *updated* must contain each entry of
        the wrapped object's mapping; the ``__wrapped__`` key of
        ``__dict__`` is exempt.  ``wrapper.__wrapped__`` must be *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))

        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            # __wrapped__ is overwritten by the update code
            keys = (k for k in source
                    if not (attr == "__dict__" and k == "__wrapped__"))
            for key in keys:
                self.assertIs(source[key], target[key])

        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Wrap a decorated sample function with the default settings and
        return ``(wrapper, f)``."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        def wrapper(b:'This is the prior annotation'):
            pass
        f.attr = 'This is also a test'
        # update_wrapper must replace this with a reference to f itself.
        f.__wrapped__ = "This is a bald faced lie"
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        def wrapper():
            pass
        f.attr = 'This is also a test'
        no_names = ()
        functools.update_wrapper(wrapper, f, no_names, no_names)
        self.check_wrapper(wrapper, f, no_names, no_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        def wrapper():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        wrapper.dict_attr = {}
        to_assign, to_update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        to_assign, to_update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000707
Łukasz Langa6f692512013-06-05 12:20:24 +0200708
class TestWraps(TestUpdateWrapper):
    # Re-runs the update_wrapper scenarios through the functools.wraps()
    # decorator, overriding the inherited tests whose setup differs.

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # f carries a bogus __wrapped__ that is copied along with f.__dict__.
        # NOTE(review): presumably check_wrapper (outside this view)
        # verifies that the wrapper's __wrapped__ still ends up being f.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # wraps() with empty assigned/updated tuples copies no metadata.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # wraps() updates wrapper.dict_attr in place, so the wrapper
            # needs the attribute before the outer decorator is applied.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
769
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000770
madman-bobe25d5fc2018-10-25 15:02:10 +0100771class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000772 def test_reduce(self):
773 class Squares:
774 def __init__(self, max):
775 self.max = max
776 self.sofar = []
777
778 def __len__(self):
779 return len(self.sofar)
780
781 def __getitem__(self, i):
782 if not 0 <= i < self.max: raise IndexError
783 n = len(self.sofar)
784 while n <= i:
785 self.sofar.append(n*n)
786 n += 1
787 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000788 def add(x, y):
789 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100790 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000791 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100792 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000793 ['a','c','d','w']
794 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100795 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000796 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100797 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000798 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000799 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100800 self.assertEqual(self.reduce(add, Squares(10)), 285)
801 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
802 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
803 self.assertRaises(TypeError, self.reduce)
804 self.assertRaises(TypeError, self.reduce, 42, 42)
805 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
806 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
807 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
808 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
809 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
810 self.assertRaises(TypeError, self.reduce, add, "")
811 self.assertRaises(TypeError, self.reduce, add, ())
812 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000813
814 class TestFailingIter:
815 def __iter__(self):
816 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100817 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000818
madman-bobe25d5fc2018-10-25 15:02:10 +0100819 self.assertEqual(self.reduce(add, [], None), None)
820 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000821
822 class BadSeq:
823 def __getitem__(self, index):
824 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100825 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000826
827 # Test reduce()'s use of iterators.
828 def test_iterator_usage(self):
829 class SequenceClass:
830 def __init__(self, n):
831 self.n = n
832 def __getitem__(self, i):
833 if 0 <= i < self.n:
834 return i
835 else:
836 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000837
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000838 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100839 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
840 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
841 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
842 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
843 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
844 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000845
846 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100847 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
848
849
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    # Shared reduce() tests against the C accelerator.  The builtin does
    # not bind as a method, so unlike the pure-Python variant it needs no
    # staticmethod() wrapper.
    if c_functools is not None:
        reduce = c_functools.reduce
854
855
class TestReducePy(TestReduce, unittest.TestCase):
    # Shared reduce() tests against the pure-Python implementation;
    # staticmethod() keeps the plain function from binding ``self``.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000858
Łukasz Langa6f692512013-06-05 12:20:24 +0200859
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200860class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000862 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 def cmp1(x, y):
864 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 self.assertEqual(key(3), key(3))
867 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 self.assertGreaterEqual(key(3), key(3))
869
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700870 def cmp2(x, y):
871 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 self.assertEqual(key(4.0), key('4'))
874 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 self.assertLessEqual(key(2), key('35'))
876 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877
878 def test_cmp_to_key_arguments(self):
879 def cmp1(x, y):
880 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100881 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700882 self.assertEqual(key(obj=3), key(obj=3))
883 self.assertGreater(key(obj=3), key(obj=1))
884 with self.assertRaises((TypeError, AttributeError)):
885 key(3) > 1 # rhs is not a K object
886 with self.assertRaises((TypeError, AttributeError)):
887 1 < key(3) # lhs is not a K object
888 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100889 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700890 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200891 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100892 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700893 with self.assertRaises(TypeError):
894 key() # too few args
895 with self.assertRaises(TypeError):
896 key(None, None) # too many args
897
898 def test_bad_cmp(self):
899 def cmp1(x, y):
900 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100901 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700902 with self.assertRaises(ZeroDivisionError):
903 key(3) > key(1)
904
905 class BadCmp:
906 def __lt__(self, other):
907 raise ZeroDivisionError
908 def cmp1(x, y):
909 return BadCmp()
910 with self.assertRaises(ZeroDivisionError):
911 key(3) > key(1)
912
913 def test_obj_field(self):
914 def cmp1(x, y):
915 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100916 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700917 self.assertEqual(key(50).obj, 50)
918
919 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000920 def mycmp(x, y):
921 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100922 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000924
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700925 def test_sort_int_str(self):
926 def mycmp(x, y):
927 x, y = int(x), int(y)
928 return (x > y) - (x < y)
929 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100930 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700931 self.assertEqual([int(value) for value in values],
932 [0, 1, 1, 2, 3, 4, 5, 7, 10])
933
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000934 def test_hash(self):
935 def mycmp(x, y):
936 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100937 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000938 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700939 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300940 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941
Łukasz Langa6f692512013-06-05 12:20:24 +0200942
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests against the C accelerator.  The builtin
    # does not bind as a method, so no staticmethod() wrapper is needed.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100947
Łukasz Langa6f692512013-06-05 12:20:24 +0200948
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Shared cmp_to_key() tests against the pure-Python implementation;
    # staticmethod() keeps the plain function from binding ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
951
Łukasz Langa6f692512013-06-05 12:20:24 +0200952
class TestTotalOrdering(unittest.TestCase):

    def _check_derived_ordering(self, cls):
        # Whichever single ordering method cls defines, the full set of
        # comparisons must agree that cls(1) < cls(2) and that cls(2) is
        # both <= and >= itself.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        # With no ordering method to derive from, the decorator refuses.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThan)
                        and self.value == other.value)
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThan)
                        and self.value == other.value)
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThanEqualTo)
                        and self.value == other.value)
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThanEqualTo)
                        and self.value == other.value)
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ComparatorNotImplemented)
                        and self.value == other.value)
            def __lt__(self, other):
                return NotImplemented

        # Comparisons between unrelated types must end in TypeError rather
        # than recursing.  Each operand is built inside assertRaises.
        mixed_cases = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mixed_cases:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even for equal operands, a derived comparison whose root returns
        # NotImplemented must raise TypeError.
        equal_cases = [
            ("GE when equal", 8, lambda x, y: x >= y),
            ("LE when equal", 9, lambda x, y: x <= y),
        ]
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Orderable_LT's comparison methods (the defined __lt__ and the
        # ones total_ordering generates) must unpickle to the very same
        # objects under every pickle protocol.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(restored, method)
1155
@functools.total_ordering
class Orderable_LT:
    """Value wrapper ordered through ``__lt__`` plus ``__eq__``.

    total_ordering fills in the remaining rich comparisons.  The class
    lives at module level because TestTotalOrdering.test_pickle pickles
    its comparison methods.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return self.value < other.value

    def __eq__(self, other):
        return self.value == other.value
1164
1165
class TestTopologicalSort(unittest.TestCase):

    def _test_graph(self, graph, expected):
        # Check *graph* two ways.  The sequence of get_ready() batches (each
        # batch marked done before the next call) must equal *expected*,
        # and static_order() must yield the same nodes flattened.

        def static_order_with_groups(ts):
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                for node in nodes:
                    ts.done(node)
                yield nodes

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(static_order_with_groups(ts)), list(expected))

        ts = functools.TopologicalSorter(graph)
        self.assertEqual(list(ts.static_order()), list(chain(*expected)))

    def _assert_cycle(self, graph, cycle):
        # Build *graph* with add() and check that prepare() raises a
        # CycleError whose reported node sequence contains *cycle*.  The
        # reported sequence is doubled so the substring test succeeds
        # whichever node the reported cycle starts from.
        ts = functools.TopologicalSorter()
        for node, dependson in graph.items():
            ts.add(node, *dependson)
        try:
            ts.prepare()
        except functools.CycleError as e:
            msg, seq = e.args
            self.assertIn(' '.join(map(str, cycle)),
                          ' '.join(map(str, seq * 2)))
        else:
            # A bare ``raise`` here would fail with an unrelated
            # "No active exception to reraise" RuntimeError.
            self.fail(f"CycleError not raised for graph {graph!r}")

    def test_simple_cases(self):
        self._test_graph(
            {2: {11},
             9: {11, 8},
             10: {11, 3},
             11: {7, 5},
             8: {7, 3}},
            [(3, 5, 7), (11, 8), (2, 10, 9)]
        )

        self._test_graph({1: {}}, [(1,)])

        self._test_graph({x: {x+1} for x in range(10)},
                         [(x,) for x in range(10, -1, -1)])

        self._test_graph({2: {3}, 3: {4}, 4: {5}, 5: {1},
                          11: {12}, 12: {13}, 13: {14}, 14: {15}},
                         [(1, 15), (5, 14), (4, 13), (3, 12), (2, 11)])

        self._test_graph({
            0: [1, 2],
            1: [3],
            2: [5, 6],
            3: [4],
            4: [9],
            5: [3],
            6: [7],
            7: [8],
            8: [4],
            9: []
        },
            [(9,), (4,), (3, 8), (1, 5, 7), (6,), (2,), (0,)]
        )

        self._test_graph({
            0: [1, 2],
            1: [],
            2: [3],
            3: []
        },
            [(1, 3), (2,), (0,)]
        )

        self._test_graph({
            0: [1, 2],
            1: [],
            2: [3],
            3: [],
            4: [5],
            5: [6],
            6: []
        },
            [(1, 3, 6), (2, 5), (0, 4)]
        )

    def test_no_dependencies(self):
        self._test_graph(
            {1: {2},
             3: {4},
             5: {6}},
            [(2, 4, 6), (1, 3, 5)]
        )

        self._test_graph(
            {1: set(),
             3: set(),
             5: set()},
            [(1, 3, 5)]
        )

    def test_the_node_multiple_times(self):
        # Test same node multiple times in dependencies
        self._test_graph({1: {2}, 3: {4}, 0: [2, 4, 4, 4, 4, 4]},
                         [(2, 4), (1, 3, 0)])

        # Test adding the same dependency multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 2)
        ts.add(1, 2)
        self.assertEqual([*ts.static_order()], [2, 1])

    def test_graph_with_iterables(self):
        dependson = (2*x + 1 for x in range(5))
        ts = functools.TopologicalSorter({0: dependson})
        self.assertEqual(list(ts.static_order()), [1, 3, 5, 7, 9, 0])

    def test_add_dependencies_for_same_node_incrementally(self):
        # Test same node multiple times
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.add(1, 3)
        ts.add(1, 4)
        ts.add(1, 5)

        ts2 = functools.TopologicalSorter({1: {2, 3, 4, 5}})
        self.assertEqual([*ts.static_order()], [*ts2.static_order()])

    def test_empty(self):
        self._test_graph({}, [])

    def test_cycle(self):
        # Self cycle
        self._assert_cycle({1: {1}}, [1, 1])
        # Simple cycle
        self._assert_cycle({1: {2}, 2: {1}}, [1, 2, 1])
        # Indirect cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}}, [1, 3, 2, 1])
        # not all elements involved in a cycle
        self._assert_cycle({1: {2}, 2: {3}, 3: {1}, 5: {4}, 4: {6}},
                           [1, 3, 2, 1])
        # Multiple cycles
        self._assert_cycle({1: {2}, 2: {1}, 3: {4}, 4: {5}, 6: {7}, 7: {6}},
                           [1, 2, 1])
        # Cycle in the middle of the graph
        self._assert_cycle({1: {2}, 2: {3}, 3: {2, 4}, 4: {5}}, [3, 2])

    def test_calls_before_prepare(self):
        ts = functools.TopologicalSorter()

        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.get_ready()
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.done(3)
        with self.assertRaisesRegex(ValueError, r"prepare\(\) must be called first"):
            ts.is_active()

    def test_prepare_multiple_times(self):
        ts = functools.TopologicalSorter()
        ts.prepare()
        with self.assertRaisesRegex(ValueError, r"cannot prepare\(\) more than once"):
            ts.prepare()

    def test_invalid_nodes_in_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3, 4)
        ts.prepare()
        ts.get_ready()

        with self.assertRaisesRegex(ValueError, "node 2 was not passed out"):
            ts.done(2)
        with self.assertRaisesRegex(ValueError, r"node 24 was not added using add\(\)"):
            ts.done(24)

    def test_done(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2, 3, 4)
        ts.add(2, 3)
        ts.prepare()

        self.assertEqual(ts.get_ready(), (3, 4))
        # If we don't mark anything as done, get_ready() returns nothing
        self.assertEqual(ts.get_ready(), ())
        ts.done(3)
        # Now 2 becomes available as 3 is done
        self.assertEqual(ts.get_ready(), (2,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(4)
        ts.done(2)
        # Only 1 is missing
        self.assertEqual(ts.get_ready(), (1,))
        self.assertEqual(ts.get_ready(), ())
        ts.done(1)
        self.assertEqual(ts.get_ready(), ())
        self.assertFalse(ts.is_active())

    def test_is_active(self):
        ts = functools.TopologicalSorter()
        ts.add(1, 2)
        ts.prepare()

        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (2,))
        self.assertTrue(ts.is_active())
        ts.done(2)
        self.assertTrue(ts.is_active())
        self.assertEqual(ts.get_ready(), (1,))
        self.assertTrue(ts.is_active())
        ts.done(1)
        self.assertFalse(ts.is_active())

    def test_not_hashable_nodes(self):
        ts = functools.TopologicalSorter()
        self.assertRaises(TypeError, ts.add, dict(), 1)
        self.assertRaises(TypeError, ts.add, 1, dict())
        self.assertRaises(TypeError, ts.add, dict(), dict())

    def test_order_of_insertion_does_not_matter_between_groups(self):
        def get_groups(ts):
            ts.prepare()
            while ts.is_active():
                nodes = ts.get_ready()
                ts.done(*nodes)
                yield set(nodes)

        ts = functools.TopologicalSorter()
        ts.add(3, 2, 1)
        ts.add(1, 0)
        ts.add(4, 5)
        ts.add(6, 7)
        ts.add(4, 7)

        ts2 = functools.TopologicalSorter()
        ts2.add(1, 0)
        ts2.add(3, 2, 1)
        ts2.add(4, 7)
        ts2.add(6, 7)
        ts2.add(4, 5)

        self.assertEqual(list(get_groups(ts)), list(get_groups(ts2)))

    def test_static_order_does_not_change_with_the_hash_seed(self):
        def check_order_with_hash_seed(seed):
            # Return the subprocess's stdout (bytes) for the given seed.
            code = """if 1:
                import functools
                ts = functools.TopologicalSorter()
                ts.add('blech', 'bluch', 'hola')
                ts.add('abcd', 'blech', 'bluch', 'a', 'b')
                ts.add('a', 'a string', 'something', 'b')
                ts.add('bluch', 'hola', 'abcde', 'a', 'b')
                print(list(ts.static_order()))
                """
            env = os.environ.copy()
            # signal to assert_python not to do a copy
            # of os.environ on its own
            env['__cleanenv'] = True
            env['PYTHONHASHSEED'] = str(seed)
            # assert_python_ok() returns (rc, stdout, stderr).  Returning the
            # whole tuple made the emptiness checks below vacuous, because a
            # tuple never equals "".
            rc, out, err = assert_python_ok('-c', code, **env)
            return out

        run1 = check_order_with_hash_seed(1234)
        run2 = check_order_with_hash_seed(31415)

        self.assertNotEqual(run1, b"")
        self.assertNotEqual(run2, b"")
        self.assertEqual(run1, run2)
1433
1434
class TestLRU:
    """Shared lru_cache tests, mixed into one TestCase per implementation.

    Concrete subclasses must provide:
      * ``module`` -- the functools module under test;
      * ``cached_func`` -- a 1-tuple holding a module-level cached function;
      * ``cached_meth`` / ``cached_staticmeth`` -- a cached method and a
        cached staticmethod defined in the subclass body.
    Tests must go through ``self.module`` (not the global ``functools``)
    so that every implementation is actually exercised.
    """

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 distinct argument pairs compete for 20 cache slots: over 1000
        # random calls hits should outnumber misses and the cache should fill.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache: calling __wrapped__ must not touch stats.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (evicting 7) and
        # the final 7 (no longer cached) -- four in total.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_no_args(self):
        # lru_cache applied directly (no call) uses the default maxsize.
        @self.module.lru_cache
        def square(x):
            return x ** 2

        self.assertEqual(list(map(square, [10, 20, 10])),
                         [100, 400, 100])
        self.assertEqual(square.cache_info().hits, 1)
        self.assertEqual(square.cache_info().misses, 2)
        self.assertEqual(square.cache_info().maxsize, 128)
        self.assertEqual(square.cache_info().currsize, 2)

    def test_lru_bug_35780(self):
        # C version of the lru_cache was not checking to see if
        # the user function call has already modified the cache
        # (this arises in recursive calls and in multi-threading).
        # This caused the cache to have orphan links not referenced
        # by the cache dictionary.

        once = True                 # Modified by f(x) below

        @self.module.lru_cache(maxsize=10)
        def f(x):
            nonlocal once
            rv = f'.{x}.'
            if x == 20 and once:
                once = False
                rv = f(x)
            return rv

        # Fill the cache
        for x in range(15):
            self.assertEqual(f(x), f'.{x}.')
        self.assertEqual(f.cache_info().currsize, 10)

        # Make a recursive call and make sure the cache remains full
        self.assertEqual(f(20), '.20.')
        self.assertEqual(f.cache_info().currsize, 10)

    def test_lru_bug_36650(self):
        # C version of lru_cache was treating a call with an empty **kwargs
        # dictionary as being distinct from a call with no keywords at all.
        # This did not result in an incorrect answer, but it did trigger
        # an unexpected cache miss.

        @self.module.lru_cache()
        def f(x):
            pass

        f(0)
        f(0, **{})
        self.assertEqual(f.cache_info().hits, 1)

    def test_lru_hash_only_once(self):
        # To protect against weird reentrancy bugs and to improve
        # efficiency when faced with slow __hash__ methods, the
        # LRU cache guarantees that it will call __hash__ only
        # once per use as an argument to the cached function.

        @self.module.lru_cache(maxsize=1)
        def f(x, y):
            return x * 3 + y

        # Simulate the integer 5
        mock_int = unittest.mock.Mock()
        mock_int.__mul__ = unittest.mock.Mock(return_value=15)
        mock_int.__hash__ = unittest.mock.Mock(return_value=999)

        # Add to cache:  One use as an argument gives one call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 1)
        self.assertEqual(f.cache_info(), (0, 1, 1, 1))

        # Cache hit: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 1, 1, 1))

        # Cache eviction: No use as an argument gives no additional call
        self.assertEqual(f(6, 2), 20)
        self.assertEqual(mock_int.__hash__.call_count, 2)
        self.assertEqual(f.cache_info(), (1, 2, 1, 1))

        # Cache miss: One use as an argument gives one additional call
        self.assertEqual(f(mock_int, 1), 16)
        self.assertEqual(mock_int.__hash__.call_count, 3)
        self.assertEqual(f.cache_info(), (1, 3, 1, 1))

    def test_lru_reentrancy_with_len(self):
        # Test to make sure the LRU cache code isn't thrown-off by
        # caching the built-in len() function.  Since len() can be
        # cached, we shouldn't use it inside the lru code itself.
        old_len = builtins.len
        try:
            builtins.len = self.module.lru_cache(4)(len)
            for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
                self.assertEqual(len('abcdefghijklmn'[:i]), i)
        finally:
            builtins.len = old_len

    def test_lru_star_arg_handling(self):
        # Test regression that arose in ea064ff3c10f
        # Use self.module so both implementations are covered.
        @self.module.lru_cache()
        def f(*args):
            return args

        self.assertEqual(f(1, 2), (1, 2))
        self.assertEqual(f((1, 2)), ((1, 2),))

    def test_lru_type_error(self):
        # Regression test for issue #28653.
        # lru_cache was leaking when one of the arguments
        # wasn't cacheable.
        # Use self.module so both implementations are covered.

        @self.module.lru_cache(maxsize=None)
        def infinite_cache(o):
            pass

        @self.module.lru_cache(maxsize=10)
        def limited_cache(o):
            pass

        with self.assertRaises(TypeError):
            infinite_cache([])

        with self.assertRaises(TypeError):
            limited_cache([])

    def test_lru_with_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        # A negative maxsize is reported as 0 and nothing is ever cached,
        # so both passes below miss on every call.
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for _ in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # typed=True keys 3 and 3.0 separately, for positional and keyword use.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_kwargs_order(self):
        # PEP 468: Preserving Keyword Argument Order
        @self.module.lru_cache(maxsize=10)
        def f(**kwargs):
            return list(kwargs.items())
        self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
        self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
        self.assertEqual(f.cache_info(),
            self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))

    def test_lru_cache_decoration(self):
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    def test_lru_cache_threaded(self):
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval forces frequent thread switches so the
        # cache is hammered concurrently; it is restored in ``finally``.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure-Python implementation these counters
                # are not always exactly equal to the expected values; the
                # reason is not understood, so only upper bounds are checked.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for _ in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n threads were inside f(i) at once, so each counts as
                # a miss while only one cache entry is added per round.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    def test_lru_cache_threaded3(self):
        @self.module.lru_cache(maxsize=2)
        def f(x):
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        with support.start_threads(threads):
            pass

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_lru_method(self):
        # The cache lives on the class-level function, so it is shared by
        # all instances; equal instances (a and b) share cache entries.
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        # Cached callables pickle by reference, so a round trip must
        # return the very same object.
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)

    def test_lru_cache_parameters(self):
        @self.module.lru_cache(maxsize=2)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), {'maxsize': 2, "typed": False})

        @self.module.lru_cache(maxsize=1000, typed=True)
        def f():
            return 1
        self.assertEqual(f.cache_parameters(), {'maxsize': 1000, "typed": True})

    def test_lru_cache_weakrefable(self):
        @self.module.lru_cache
        def test_function(x):
            return x

        class A:
            @self.module.lru_cache
            def test_method(self, x):
                return (self, x)

            @staticmethod
            @self.module.lru_cache
            def test_staticmethod(x):
                return (self, x)

        refs = [weakref.ref(test_function),
                weakref.ref(A.test_method),
                weakref.ref(A.test_staticmethod)]

        for ref in refs:
            self.assertIsNotNone(ref())

        # Once the last strong references are gone the wrappers must be
        # collectable, i.e. nothing else keeps them alive.
        del A
        del test_function
        gc.collect()

        for ref in refs:
            self.assertIsNone(ref())
1971
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001972
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Module-level function cached by the pure-Python lru_cache.

    Defined at module level so it can be pickled and copied by reference
    in the shared LRU tests.
    """
    tripled = 3 * x
    return tripled + y
1976
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Module-level function cached by the C-accelerated lru_cache.

    Defined at module level so it can be pickled and copied by reference
    in the shared LRU tests.
    """
    tripled = 3 * x
    return tripled + y
1980
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001981
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared LRU tests against the pure-Python functools."""

    module = py_functools
    # Held in a 1-tuple and read back as ``cached_func[0]`` by the tests.
    # NOTE(review): presumably this keeps the cached function from being
    # treated as a method on attribute lookup -- confirm.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1994
1995
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared LRU tests against the C-accelerated functools."""

    module = c_functools
    # Held in a 1-tuple and read back as ``cached_func[0]`` by the tests.
    # NOTE(review): presumably this keeps the cached function from being
    # treated as a method on attribute lookup -- confirm.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03002008
Raymond Hettinger03923422013-03-04 02:52:50 -05002009
Łukasz Langa6f692512013-06-05 12:20:24 +02002010class TestSingleDispatch(unittest.TestCase):
2011 def test_simple_overloads(self):
2012 @functools.singledispatch
2013 def g(obj):
2014 return "base"
2015 def g_int(i):
2016 return "integer"
2017 g.register(int, g_int)
2018 self.assertEqual(g("str"), "base")
2019 self.assertEqual(g(1), "integer")
2020 self.assertEqual(g([1,2,3]), "base")
2021
2022 def test_mro(self):
2023 @functools.singledispatch
2024 def g(obj):
2025 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002026 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02002027 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002028 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002029 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002030 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02002031 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002032 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02002033 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02002034 def g_A(a):
2035 return "A"
2036 def g_B(b):
2037 return "B"
2038 g.register(A, g_A)
2039 g.register(B, g_B)
2040 self.assertEqual(g(A()), "A")
2041 self.assertEqual(g(B()), "B")
2042 self.assertEqual(g(C()), "A")
2043 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02002044
2045 def test_register_decorator(self):
2046 @functools.singledispatch
2047 def g(obj):
2048 return "base"
2049 @g.register(int)
2050 def g_int(i):
2051 return "int %s" % (i,)
2052 self.assertEqual(g(""), "base")
2053 self.assertEqual(g(12), "int 12")
2054 self.assertIs(g.dispatch(int), g_int)
2055 self.assertIs(g.dispatch(object), g.dispatch(str))
2056 # Note: in the assert above this is not g.
2057 # @singledispatch returns the wrapper.
2058
2059 def test_wrapping_attributes(self):
2060 @functools.singledispatch
2061 def g(obj):
2062 "Simple test"
2063 return "Test"
2064 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02002065 if sys.flags.optimize < 2:
2066 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02002067
2068 @unittest.skipUnless(decimal, 'requires _decimal')
2069 @support.cpython_only
2070 def test_c_classes(self):
2071 @functools.singledispatch
2072 def g(obj):
2073 return "base"
2074 @g.register(decimal.DecimalException)
2075 def _(obj):
2076 return obj.args
2077 subn = decimal.Subnormal("Exponent < Emin")
2078 rnd = decimal.Rounded("Number got rounded")
2079 self.assertEqual(g(subn), ("Exponent < Emin",))
2080 self.assertEqual(g(rnd), ("Number got rounded",))
2081 @g.register(decimal.Subnormal)
2082 def _(obj):
2083 return "Too small to care."
2084 self.assertEqual(g(subn), "Too small to care.")
2085 self.assertEqual(g(rnd), ("Number got rounded",))
2086
2087 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02002088 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002089 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002090 mro = functools._compose_mro
2091 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
2092 for haystack in permutations(bases):
2093 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07002094 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
2095 c.Collection, c.Sized, c.Iterable,
2096 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002097 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02002098 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002099 m = mro(collections.ChainMap, haystack)
2100 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002101 c.Collection, c.Sized, c.Iterable,
2102 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002103
2104 # If there's a generic function with implementations registered for
2105 # both Sized and Container, passing a defaultdict to it results in an
2106 # ambiguous dispatch which will cause a RuntimeError (see
2107 # test_mro_conflicts).
2108 bases = [c.Container, c.Sized, str]
2109 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002110 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
2111 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
2112 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02002113
2114 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00002115 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02002116 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002117 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002118 pass
2119 c.MutableSequence.register(D)
2120 bases = [c.MutableSequence, c.MutableMapping]
2121 for haystack in permutations(bases):
2122 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07002123 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002124 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002125 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02002126 object])
2127
2128 # Container and Callable are registered on different base classes and
2129 # a generic function supporting both should always pick the Callable
2130 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002131 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002132 def __call__(self):
2133 pass
2134 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
2135 for haystack in permutations(bases):
2136 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002137 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07002138 c.Collection, c.Sized, c.Iterable,
2139 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02002140
2141 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002142 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002143 d = {"a": "b"}
2144 l = [1, 2, 3]
2145 s = {object(), None}
2146 f = frozenset(s)
2147 t = (1, 2, 3)
2148 @functools.singledispatch
2149 def g(obj):
2150 return "base"
2151 self.assertEqual(g(d), "base")
2152 self.assertEqual(g(l), "base")
2153 self.assertEqual(g(s), "base")
2154 self.assertEqual(g(f), "base")
2155 self.assertEqual(g(t), "base")
2156 g.register(c.Sized, lambda obj: "sized")
2157 self.assertEqual(g(d), "sized")
2158 self.assertEqual(g(l), "sized")
2159 self.assertEqual(g(s), "sized")
2160 self.assertEqual(g(f), "sized")
2161 self.assertEqual(g(t), "sized")
2162 g.register(c.MutableMapping, lambda obj: "mutablemapping")
2163 self.assertEqual(g(d), "mutablemapping")
2164 self.assertEqual(g(l), "sized")
2165 self.assertEqual(g(s), "sized")
2166 self.assertEqual(g(f), "sized")
2167 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002168 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02002169 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
2170 self.assertEqual(g(l), "sized")
2171 self.assertEqual(g(s), "sized")
2172 self.assertEqual(g(f), "sized")
2173 self.assertEqual(g(t), "sized")
2174 g.register(c.MutableSequence, lambda obj: "mutablesequence")
2175 self.assertEqual(g(d), "mutablemapping")
2176 self.assertEqual(g(l), "mutablesequence")
2177 self.assertEqual(g(s), "sized")
2178 self.assertEqual(g(f), "sized")
2179 self.assertEqual(g(t), "sized")
2180 g.register(c.MutableSet, lambda obj: "mutableset")
2181 self.assertEqual(g(d), "mutablemapping")
2182 self.assertEqual(g(l), "mutablesequence")
2183 self.assertEqual(g(s), "mutableset")
2184 self.assertEqual(g(f), "sized")
2185 self.assertEqual(g(t), "sized")
2186 g.register(c.Mapping, lambda obj: "mapping")
2187 self.assertEqual(g(d), "mutablemapping") # not specific enough
2188 self.assertEqual(g(l), "mutablesequence")
2189 self.assertEqual(g(s), "mutableset")
2190 self.assertEqual(g(f), "sized")
2191 self.assertEqual(g(t), "sized")
2192 g.register(c.Sequence, lambda obj: "sequence")
2193 self.assertEqual(g(d), "mutablemapping")
2194 self.assertEqual(g(l), "mutablesequence")
2195 self.assertEqual(g(s), "mutableset")
2196 self.assertEqual(g(f), "sized")
2197 self.assertEqual(g(t), "sequence")
2198 g.register(c.Set, lambda obj: "set")
2199 self.assertEqual(g(d), "mutablemapping")
2200 self.assertEqual(g(l), "mutablesequence")
2201 self.assertEqual(g(s), "mutableset")
2202 self.assertEqual(g(f), "set")
2203 self.assertEqual(g(t), "sequence")
2204 g.register(dict, lambda obj: "dict")
2205 self.assertEqual(g(d), "dict")
2206 self.assertEqual(g(l), "mutablesequence")
2207 self.assertEqual(g(s), "mutableset")
2208 self.assertEqual(g(f), "set")
2209 self.assertEqual(g(t), "sequence")
2210 g.register(list, lambda obj: "list")
2211 self.assertEqual(g(d), "dict")
2212 self.assertEqual(g(l), "list")
2213 self.assertEqual(g(s), "mutableset")
2214 self.assertEqual(g(f), "set")
2215 self.assertEqual(g(t), "sequence")
2216 g.register(set, lambda obj: "concrete-set")
2217 self.assertEqual(g(d), "dict")
2218 self.assertEqual(g(l), "list")
2219 self.assertEqual(g(s), "concrete-set")
2220 self.assertEqual(g(f), "set")
2221 self.assertEqual(g(t), "sequence")
2222 g.register(frozenset, lambda obj: "frozen-set")
2223 self.assertEqual(g(d), "dict")
2224 self.assertEqual(g(l), "list")
2225 self.assertEqual(g(s), "concrete-set")
2226 self.assertEqual(g(f), "frozen-set")
2227 self.assertEqual(g(t), "sequence")
2228 g.register(tuple, lambda obj: "tuple")
2229 self.assertEqual(g(d), "dict")
2230 self.assertEqual(g(l), "list")
2231 self.assertEqual(g(s), "concrete-set")
2232 self.assertEqual(g(f), "frozen-set")
2233 self.assertEqual(g(t), "tuple")
2234
Łukasz Langa3720c772013-07-01 16:00:38 +02002235 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002236 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02002237 mro = functools._c3_mro
2238 class A(object):
2239 pass
2240 class B(A):
2241 def __len__(self):
2242 return 0 # implies Sized
2243 @c.Container.register
2244 class C(object):
2245 pass
2246 class D(object):
2247 pass # unrelated
2248 class X(D, C, B):
2249 def __call__(self):
2250 pass # implies Callable
2251 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
2252 for abcs in permutations([c.Sized, c.Callable, c.Container]):
2253 self.assertEqual(mro(X, abcs=abcs), expected)
2254 # unrelated ABCs don't appear in the resulting MRO
2255 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
2256 self.assertEqual(mro(X, abcs=many_abcs), expected)
2257
Yury Selivanov77a8cd62015-08-18 14:20:00 -04002258 def test_false_meta(self):
2259 # see issue23572
2260 class MetaA(type):
2261 def __len__(self):
2262 return 0
2263 class A(metaclass=MetaA):
2264 pass
2265 class AA(A):
2266 pass
2267 @functools.singledispatch
2268 def fun(a):
2269 return 'base A'
2270 @fun.register(A)
2271 def _(a):
2272 return 'fun A'
2273 aa = AA()
2274 self.assertEqual(fun(aa), 'fun A')
2275
Łukasz Langa6f692512013-06-05 12:20:24 +02002276 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002277 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002278 @functools.singledispatch
2279 def g(arg):
2280 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02002281 class O(c.Sized):
2282 def __len__(self):
2283 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002284 o = O()
2285 self.assertEqual(g(o), "base")
2286 g.register(c.Iterable, lambda arg: "iterable")
2287 g.register(c.Container, lambda arg: "container")
2288 g.register(c.Sized, lambda arg: "sized")
2289 g.register(c.Set, lambda arg: "set")
2290 self.assertEqual(g(o), "sized")
2291 c.Iterable.register(O)
2292 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
2293 c.Container.register(O)
2294 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02002295 c.Set.register(O)
2296 self.assertEqual(g(o), "set") # because c.Set is a subclass of
2297 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02002298 class P:
2299 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02002300 p = P()
2301 self.assertEqual(g(p), "base")
2302 c.Iterable.register(P)
2303 self.assertEqual(g(p), "iterable")
2304 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02002305 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02002306 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002307 self.assertIn(
2308 str(re_one.exception),
2309 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2310 "or <class 'collections.abc.Iterable'>"),
2311 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
2312 "or <class 'collections.abc.Container'>")),
2313 )
Łukasz Langa6f692512013-06-05 12:20:24 +02002314 class Q(c.Sized):
2315 def __len__(self):
2316 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02002317 q = Q()
2318 self.assertEqual(g(q), "sized")
2319 c.Iterable.register(Q)
2320 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
2321 c.Set.register(Q)
2322 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02002323 # c.Sized and c.Iterable
2324 @functools.singledispatch
2325 def h(arg):
2326 return "base"
2327 @h.register(c.Sized)
2328 def _(arg):
2329 return "sized"
2330 @h.register(c.Container)
2331 def _(arg):
2332 return "container"
2333 # Even though Sized and Container are explicit bases of MutableMapping,
2334 # this ABC is implicitly registered on defaultdict which makes all of
2335 # MutableMapping's bases implicit as well from defaultdict's
2336 # perspective.
2337 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002338 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07002339 self.assertIn(
2340 str(re_two.exception),
2341 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2342 "or <class 'collections.abc.Sized'>"),
2343 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2344 "or <class 'collections.abc.Container'>")),
2345 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002346 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02002347 pass
2348 c.MutableSequence.register(R)
2349 @functools.singledispatch
2350 def i(arg):
2351 return "base"
2352 @i.register(c.MutableMapping)
2353 def _(arg):
2354 return "mapping"
2355 @i.register(c.MutableSequence)
2356 def _(arg):
2357 return "sequence"
2358 r = R()
2359 self.assertEqual(i(r), "sequence")
2360 class S:
2361 pass
2362 class T(S, c.Sized):
2363 def __len__(self):
2364 return 0
2365 t = T()
2366 self.assertEqual(h(t), "sized")
2367 c.Container.register(T)
2368 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2369 class U:
2370 def __len__(self):
2371 return 0
2372 u = U()
2373 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2374 # from the existence of __len__()
2375 c.Container.register(U)
2376 # There is no preference for registered versus inferred ABCs.
2377 with self.assertRaises(RuntimeError) as re_three:
2378 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002379 self.assertIn(
2380 str(re_three.exception),
2381 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2382 "or <class 'collections.abc.Sized'>"),
2383 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2384 "or <class 'collections.abc.Container'>")),
2385 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002386 class V(c.Sized, S):
2387 def __len__(self):
2388 return 0
2389 @functools.singledispatch
2390 def j(arg):
2391 return "base"
2392 @j.register(S)
2393 def _(arg):
2394 return "s"
2395 @j.register(c.Container)
2396 def _(arg):
2397 return "container"
2398 v = V()
2399 self.assertEqual(j(v), "s")
2400 c.Container.register(V)
2401 self.assertEqual(j(v), "container") # because it ends up right after
2402 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002403
def test_cache_invalidation(self):
    # While weakref.WeakKeyDictionary is swapped for a factory returning a
    # tracing dict, the singledispatch function created inside the `with`
    # stores its dispatch cache there.  That lets the test see which
    # lookups hit the cache (get_ops), which entries get (re)filled
    # (set_ops), and when registrations empty it.
    from collections import UserDict
    import weakref

    class TracingDict(UserDict):
        # UserDict that logs every successful read and every write.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []

        def __getitem__(self, key):
            # Look up first: a missing key raises KeyError without logging.
            value = self.data[key]
            self.get_ops.append(key)
            return value

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            self.data.clear()

    tracer = TracingDict()
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: tracer):
        cabc = collections.abc

        @functools.singledispatch
        def g(arg):
            return "base"

        a_dict = {}
        a_list = []
        # First calls miss the cache and fill it.
        self.assertEqual(len(tracer), 0)
        self.assertEqual(g(a_dict), "base")
        self.assertEqual(len(tracer), 1)
        self.assertEqual(tracer.get_ops, [])
        self.assertEqual(tracer.set_ops, [dict])
        self.assertEqual(tracer.data[dict], g.registry[object])
        self.assertEqual(g(a_list), "base")
        self.assertEqual(len(tracer), 2)
        self.assertEqual(tracer.get_ops, [])
        self.assertEqual(tracer.set_ops, [dict, list])
        self.assertEqual(tracer.data[dict], g.registry[object])
        self.assertEqual(tracer.data[list], g.registry[object])
        self.assertEqual(tracer.data[dict], tracer.data[list])
        # Repeat calls are served from the cache.
        self.assertEqual(g(a_list), "base")
        self.assertEqual(g(a_dict), "base")
        self.assertEqual(tracer.get_ops, [list, dict])
        self.assertEqual(tracer.set_ops, [dict, list])
        # Registering a concrete type empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(tracer.get_ops, [list, dict])
        self.assertEqual(len(tracer), 0)
        self.assertEqual(g(a_dict), "base")
        self.assertEqual(len(tracer), 1)
        self.assertEqual(tracer.get_ops, [list, dict])
        self.assertEqual(tracer.set_ops, [dict, list, dict])
        self.assertEqual(tracer.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(a_list), "list")
        self.assertEqual(len(tracer), 2)
        self.assertEqual(tracer.get_ops, [list, dict])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list])
        self.assertEqual(tracer.data[list],
                         functools._find_impl(list, g.registry))

        class X:
            pass
        # g has no ABC registrations yet, so this leaves the cache intact.
        cabc.MutableMapping.register(X)
        self.assertEqual(g(a_dict), "base")
        self.assertEqual(g(a_list), "list")
        self.assertEqual(tracer.get_ops, [list, dict, dict, list])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list])
        # Registering an ABC on g empties the cache too.
        g.register(cabc.Sized, lambda arg: "sized")
        self.assertEqual(len(tracer), 0)
        self.assertEqual(g(a_dict), "sized")
        self.assertEqual(len(tracer), 1)
        self.assertEqual(tracer.get_ops, [list, dict, dict, list])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(a_list), "list")
        self.assertEqual(len(tracer), 2)
        self.assertEqual(tracer.get_ops, [list, dict, dict, list])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(a_list), "list")
        self.assertEqual(g(a_dict), "sized")
        self.assertEqual(tracer.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() consults the same cache.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(tracer.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
        self.assertEqual(tracer.set_ops, [dict, list, dict, list, dict, list])
        # Now that g uses ABCs, a new ABC registration invalidates the
        # cache, but only on the next call: the stale entries are still
        # there right after register().
        cabc.MutableSet.register(X)
        self.assertEqual(len(tracer), 2)
        self.assertEqual(g(a_list), "list")
        self.assertEqual(len(tracer), 1)
        g.register(cabc.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(tracer), 0)
        self.assertEqual(g(a_dict), "mutablemapping")
        self.assertEqual(len(tracer), 1)
        self.assertEqual(g(a_list), "list")
        self.assertEqual(len(tracer), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(a_dict), "dict")
        self.assertEqual(g(a_list), "list")
        g._clear_cache()
        self.assertEqual(len(tracer), 0)

Łukasz Langae5697532017-12-11 13:56:31 -08002506 def test_annotations(self):
2507 @functools.singledispatch
2508 def i(arg):
2509 return "base"
2510 @i.register
2511 def _(arg: collections.abc.Mapping):
2512 return "mapping"
2513 @i.register
2514 def _(arg: "collections.abc.Sequence"):
2515 return "sequence"
2516 self.assertEqual(i(None), "base")
2517 self.assertEqual(i({"a": 1}), "mapping")
2518 self.assertEqual(i([1, 2, 3]), "sequence")
2519 self.assertEqual(i((1, 2, 3)), "sequence")
2520 self.assertEqual(i("str"), "sequence")
2521
2522 # Registering classes as callables doesn't work with annotations,
2523 # you need to pass the type explicitly.
2524 @i.register(str)
2525 class _:
2526 def __init__(self, arg):
2527 self.arg = arg
2528
2529 def __eq__(self, other):
2530 return self.arg == other
2531 self.assertEqual(i("str"), "str")
2532
Ethan Smithc6512752018-05-26 16:38:33 -04002533 def test_method_register(self):
2534 class A:
2535 @functools.singledispatchmethod
2536 def t(self, arg):
2537 self.arg = "base"
2538 @t.register(int)
2539 def _(self, arg):
2540 self.arg = "int"
2541 @t.register(str)
2542 def _(self, arg):
2543 self.arg = "str"
2544 a = A()
2545
2546 a.t(0)
2547 self.assertEqual(a.arg, "int")
2548 aa = A()
2549 self.assertFalse(hasattr(aa, 'arg'))
2550 a.t('')
2551 self.assertEqual(a.arg, "str")
2552 aa = A()
2553 self.assertFalse(hasattr(aa, 'arg'))
2554 a.t(0.0)
2555 self.assertEqual(a.arg, "base")
2556 aa = A()
2557 self.assertFalse(hasattr(aa, 'arg'))
2558
2559 def test_staticmethod_register(self):
2560 class A:
2561 @functools.singledispatchmethod
2562 @staticmethod
2563 def t(arg):
2564 return arg
2565 @t.register(int)
2566 @staticmethod
2567 def _(arg):
2568 return isinstance(arg, int)
2569 @t.register(str)
2570 @staticmethod
2571 def _(arg):
2572 return isinstance(arg, str)
2573 a = A()
2574
2575 self.assertTrue(A.t(0))
2576 self.assertTrue(A.t(''))
2577 self.assertEqual(A.t(0.0), 0.0)
2578
2579 def test_classmethod_register(self):
2580 class A:
2581 def __init__(self, arg):
2582 self.arg = arg
2583
2584 @functools.singledispatchmethod
2585 @classmethod
2586 def t(cls, arg):
2587 return cls("base")
2588 @t.register(int)
2589 @classmethod
2590 def _(cls, arg):
2591 return cls("int")
2592 @t.register(str)
2593 @classmethod
2594 def _(cls, arg):
2595 return cls("str")
2596
2597 self.assertEqual(A.t(0).arg, "int")
2598 self.assertEqual(A.t('').arg, "str")
2599 self.assertEqual(A.t(0.0).arg, "base")
2600
2601 def test_callable_register(self):
2602 class A:
2603 def __init__(self, arg):
2604 self.arg = arg
2605
2606 @functools.singledispatchmethod
2607 @classmethod
2608 def t(cls, arg):
2609 return cls("base")
2610
2611 @A.t.register(int)
2612 @classmethod
2613 def _(cls, arg):
2614 return cls("int")
2615 @A.t.register(str)
2616 @classmethod
2617 def _(cls, arg):
2618 return cls("str")
2619
2620 self.assertEqual(A.t(0).arg, "int")
2621 self.assertEqual(A.t('').arg, "str")
2622 self.assertEqual(A.t(0.0).arg, "base")
2623
2624 def test_abstractmethod_register(self):
2625 class Abstract(abc.ABCMeta):
2626
2627 @functools.singledispatchmethod
2628 @abc.abstractmethod
2629 def add(self, x, y):
2630 pass
2631
2632 self.assertTrue(Abstract.add.__isabstractmethod__)
2633
2634 def test_type_ann_register(self):
2635 class A:
2636 @functools.singledispatchmethod
2637 def t(self, arg):
2638 return "base"
2639 @t.register
2640 def _(self, arg: int):
2641 return "int"
2642 @t.register
2643 def _(self, arg: str):
2644 return "str"
2645 a = A()
2646
2647 self.assertEqual(a.t(0), "int")
2648 self.assertEqual(a.t(''), "str")
2649 self.assertEqual(a.t(0.0), "base")
2650
Łukasz Langae5697532017-12-11 13:56:31 -08002651 def test_invalid_registrations(self):
2652 msg_prefix = "Invalid first argument to `register()`: "
2653 msg_suffix = (
2654 ". Use either `@register(some_class)` or plain `@register` on an "
2655 "annotated function."
2656 )
2657 @functools.singledispatch
2658 def i(arg):
2659 return "base"
2660 with self.assertRaises(TypeError) as exc:
2661 @i.register(42)
2662 def _(arg):
2663 return "I annotated with a non-type"
2664 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2665 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2666 with self.assertRaises(TypeError) as exc:
2667 @i.register
2668 def _(arg):
2669 return "I forgot to annotate"
2670 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2671 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2672 ))
2673 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2674
Łukasz Langae5697532017-12-11 13:56:31 -08002675 with self.assertRaises(TypeError) as exc:
2676 @i.register
2677 def _(arg: typing.Iterable[str]):
2678 # At runtime, dispatching on generics is impossible.
2679 # When registering implementations with singledispatch, avoid
2680 # types from `typing`. Instead, annotate with regular types
2681 # or ABCs.
2682 return "I annotated with a generic collection"
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002683 self.assertTrue(str(exc.exception).startswith(
2684 "Invalid annotation for 'arg'."
Łukasz Langae5697532017-12-11 13:56:31 -08002685 ))
Lysandros Nikolaoud6738102019-05-20 00:11:21 +02002686 self.assertTrue(str(exc.exception).endswith(
2687 'typing.Iterable[str] is not a class.'
2688 ))
Łukasz Langae5697532017-12-11 13:56:31 -08002689
Dong-hee Na445f1b32018-07-10 16:26:36 +09002690 def test_invalid_positional_argument(self):
2691 @functools.singledispatch
2692 def f(*args):
2693 pass
2694 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002695 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002696 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002697
Carl Meyerd658dea2018-08-28 01:11:56 -06002698
class CachedCostItem:
    # Counter whose `cost` property is computed at most once per instance.
    # Computing it bumps _cost (class default 1) under a re-entrant lock,
    # so the first read yields 2 and later reads return the cached 2.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            bumped = self._cost + 1
            self._cost = bumped
            return bumped


class OptionallyCachedCostItem:
    # Exposes one getter twice: get_cost() bumps the counter on every call,
    # while cached_cost runs it once per instance and keeps that result.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        new_cost = self._cost + 1
        self._cost = new_cost
        return new_cost

    # The cached attribute's name differs from the wrapped function's name.
    cached_cost = py_functools.cached_property(get_cost)


class CachedCostItemWait:
    # Like CachedCostItem, but computing `cost` first blocks on
    # event.wait(1) (at most one second), so a caller controls when the
    # getter proceeds to increment the counter.

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            new_cost = self._cost + 1
            self._cost = new_cost
            return new_cost


class CachedCostItemWithSlots:
    # Instances have no __dict__, which cached_property needs for storage,
    # so reading `cost` must fail with TypeError before the getter runs.

    # A one-element tuple: ('_cost') would merely be the string '_cost'.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')


class TestCachedProperty(unittest.TestCase):
    # Exercises the pure-Python cached_property (py_functools) through the
    # CachedCostItem* helper classes.

    def test_cached(self):
        item = CachedCostItem()
        # The first read computes 2; the second must come from the cache
        # rather than incrementing again to 3.
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # The plain method keeps bumping the counter, while the cached
        # attribute keeps the value it saw on first access.
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        # Shrink the switch interval so the readers interleave as much as
        # possible; always restore the previous value afterwards.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            readers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(readers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        # A cached_property defined on a metaclass would have to store its
        # value in the class's __dict__, which rejects item assignment.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        # The TypeError from binding the second name surfaces as the
        # __context__ of the RuntimeError raised while creating the class.
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        # Each instance caches its own value; a's is not recomputed.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Attaching the descriptor after class creation skips __set_name__.
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        setattr(Foo, 'cp', cp)

        with self.assertRaisesRegex(
                TypeError,
                "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access yields the descriptor object itself.
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        # The descriptor exposes the wrapped getter's docstring.
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")


if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()