blob: ffbd0fcf2d80f2dc158054f84e018c1e068da28c [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
Antoine Pitroub5b37142012-11-13 21:35:40 +010021py_functools = support.import_fresh_module('functools', blocked=['_functools'])
22c_functools = support.import_fresh_module('functools', fresh=['_functools'])
23
Łukasz Langa6f692512013-06-05 12:20:24 +020024decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module that was registered under *name* on entry is put back on
    exit, whether the ``with`` body finishes normally or raises.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the swap.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo the call back: return ``(positional_args_tuple, keyword_dict)``."""
    result = (args, kw)
    return result
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Summarize a partial object as ``(func, args, keywords, __dict__)``."""
    func = part.func
    args = part.args
    keywords = part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass, used to check that partial stores exact tuples."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` deliberately returns a list.

    Lets tests check that partial does not rely on ``args + other`` of a
    tuple subclass producing a tuple.
    """

    def __add__(self, other):
        return [*self, *other]
50
class MyDict(dict):
    """Plain dict subclass, used to check that partial stores an exact dict."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Tests shared by every ``partial`` implementation under test.

    This is a mixin, not a TestCase. Concrete subclasses also inherit from
    ``unittest.TestCase`` and must provide two attributes:
    ``partial``, the callable type being exercised, and ``AllowPickle``,
    a context-manager class entered around the pickling tests.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertEqual calls report which value differed.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    # The class body executes even when the class is skipped, so only look
    # up c_functools.partial when the C module was actually importable.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager.

        Unlike the pure-Python variant, the C implementation does not swap
        modules around the pickling tests.
        """
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
428 self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that installs py_functools as ``functools``.

        NOTE(review): presumably needed so that pickle, which locates
        classes by module and qualified name, resolves ``functools.partial``
        to the pure-Python class while the pickling tests run. Confirm.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
if c_functools:
    # Only defined when the C accelerator module could be imported.
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C ``partial`` type."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python ``partial`` class."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a trivial subclass of the C type."""

    # Guarded because the class body runs even when the class is skipped.
    if c_functools:
        partial = CPartialSubclass

    # Subclasses of partial do not get the nested-call flattening
    # optimization, so that inherited test is disabled here.
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod bound through a class and an instance."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the *metaclass*. Subclassing ABCMeta would make
        # Abstract a metaclass rather than an abstract class, and ABC
        # bookkeeping would never see the partialmethod.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # A partialmethod wrapping an abstract method must itself be
        # collected as abstract by ABCMeta.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
572
573
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    TestWraps subclasses this class and overrides ``_default_update`` so
    the same checks run through functools.wraps.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each attribute named in *assigned* must be the identical object on
        both functions. Each mapping named in *updated* must contain every
        entry of the wrapped function's mapping. ``__wrapped__`` must point
        at *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))

        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # __wrapped__ is overwritten by the update code, so the
                # wrapped function's own value is not expected to survive.
                if attr == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], target[key])

        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # Wrap an annotated, documented function that carries an extra
        # attribute and a bogus __wrapped__; hand back (wrapper, wrapped).
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Attributes absent from the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})

        # ...but the wrapper must own every attribute it is asked to update,
        # and it has to be something with an ``update`` method.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper() scenarios through @functools.wraps."""

    def _default_update(self):
        """Build and return ``(wrapper, f)`` using wraps() with its defaults."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus __wrapped__ on the wrapped function; presumably
        # check_wrapper() (defined outside this view) confirms the wrapper's
        # own __wrapped__ still refers to f.  TODO confirm.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        # Default wraps() copies the name/qualname and f's __dict__ entries.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        # Empty assigned/updated tuples: wrapper keeps its own name, qualname
        # and doc, and f's extra attribute is not copied.
        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Give the wrapper the dict that wraps() will merge into.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
748
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000749
madman-bobe25d5fc2018-10-25 15:02:10 +0100750class TestReduce:
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000751 def test_reduce(self):
752 class Squares:
753 def __init__(self, max):
754 self.max = max
755 self.sofar = []
756
757 def __len__(self):
758 return len(self.sofar)
759
760 def __getitem__(self, i):
761 if not 0 <= i < self.max: raise IndexError
762 n = len(self.sofar)
763 while n <= i:
764 self.sofar.append(n*n)
765 n += 1
766 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000767 def add(x, y):
768 return x + y
madman-bobe25d5fc2018-10-25 15:02:10 +0100769 self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000770 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100771 self.reduce(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000772 ['a','c','d','w']
773 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100774 self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000775 self.assertEqual(
madman-bobe25d5fc2018-10-25 15:02:10 +0100776 self.reduce(lambda x, y: x*y, range(2,21), 1),
Guido van Rossume2a383d2007-01-15 16:59:06 +0000777 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000778 )
madman-bobe25d5fc2018-10-25 15:02:10 +0100779 self.assertEqual(self.reduce(add, Squares(10)), 285)
780 self.assertEqual(self.reduce(add, Squares(10), 0), 285)
781 self.assertEqual(self.reduce(add, Squares(0), 0), 0)
782 self.assertRaises(TypeError, self.reduce)
783 self.assertRaises(TypeError, self.reduce, 42, 42)
784 self.assertRaises(TypeError, self.reduce, 42, 42, 42)
785 self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item
786 self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item
787 self.assertRaises(TypeError, self.reduce, 42, (42, 42))
788 self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value
789 self.assertRaises(TypeError, self.reduce, add, "")
790 self.assertRaises(TypeError, self.reduce, add, ())
791 self.assertRaises(TypeError, self.reduce, add, object())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000792
793 class TestFailingIter:
794 def __iter__(self):
795 raise RuntimeError
madman-bobe25d5fc2018-10-25 15:02:10 +0100796 self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter())
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000797
madman-bobe25d5fc2018-10-25 15:02:10 +0100798 self.assertEqual(self.reduce(add, [], None), None)
799 self.assertEqual(self.reduce(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000800
801 class BadSeq:
802 def __getitem__(self, index):
803 raise ValueError
madman-bobe25d5fc2018-10-25 15:02:10 +0100804 self.assertRaises(ValueError, self.reduce, 42, BadSeq())
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000805
806 # Test reduce()'s use of iterators.
807 def test_iterator_usage(self):
808 class SequenceClass:
809 def __init__(self, n):
810 self.n = n
811 def __getitem__(self, i):
812 if 0 <= i < self.n:
813 return i
814 else:
815 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000816
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000817 from operator import add
madman-bobe25d5fc2018-10-25 15:02:10 +0100818 self.assertEqual(self.reduce(add, SequenceClass(5)), 10)
819 self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52)
820 self.assertRaises(TypeError, self.reduce, add, SequenceClass(0))
821 self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42)
822 self.assertEqual(self.reduce(add, SequenceClass(1)), 0)
823 self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000824
825 d = {"one": 1, "two": 2, "three": 3}
madman-bobe25d5fc2018-10-25 15:02:10 +0100826 self.assertEqual(self.reduce(add, d), "".join(d.keys()))
827
828
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the TestReduce suite against the C accelerator's reduce()."""

    # Guarded so the class body still executes when the C module is
    # unavailable; the decorator then skips the whole class.
    if c_functools:
        reduce = c_functools.reduce
833
834
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the TestReduce suite against the pure-Python reduce()."""

    # staticmethod() keeps the plain Python function from binding to the
    # test instance when looked up as self.reduce.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000837
Łukasz Langa6f692512013-06-05 12:20:24 +0200838
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200839class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700840
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000841 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700842 def cmp1(x, y):
843 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100844 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700845 self.assertEqual(key(3), key(3))
846 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100847 self.assertGreaterEqual(key(3), key(3))
848
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700849 def cmp2(x, y):
850 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100851 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700852 self.assertEqual(key(4.0), key('4'))
853 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100854 self.assertLessEqual(key(2), key('35'))
855 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700856
857 def test_cmp_to_key_arguments(self):
858 def cmp1(x, y):
859 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100860 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861 self.assertEqual(key(obj=3), key(obj=3))
862 self.assertGreater(key(obj=3), key(obj=1))
863 with self.assertRaises((TypeError, AttributeError)):
864 key(3) > 1 # rhs is not a K object
865 with self.assertRaises((TypeError, AttributeError)):
866 1 < key(3) # lhs is not a K object
867 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700869 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200870 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100871 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700872 with self.assertRaises(TypeError):
873 key() # too few args
874 with self.assertRaises(TypeError):
875 key(None, None) # too many args
876
877 def test_bad_cmp(self):
878 def cmp1(x, y):
879 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100880 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700881 with self.assertRaises(ZeroDivisionError):
882 key(3) > key(1)
883
884 class BadCmp:
885 def __lt__(self, other):
886 raise ZeroDivisionError
887 def cmp1(x, y):
888 return BadCmp()
889 with self.assertRaises(ZeroDivisionError):
890 key(3) > key(1)
891
892 def test_obj_field(self):
893 def cmp1(x, y):
894 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 self.assertEqual(key(50).obj, 50)
897
898 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000899 def mycmp(x, y):
900 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100901 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000902 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000903
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700904 def test_sort_int_str(self):
905 def mycmp(x, y):
906 x, y = int(x), int(y)
907 return (x > y) - (x < y)
908 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100909 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700910 self.assertEqual([int(value) for value in values],
911 [0, 1, 1, 2, 3, 4, 5, 7, 10])
912
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000913 def test_hash(self):
914 def mycmp(x, y):
915 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100916 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000917 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700918 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300919 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000920
Łukasz Langa6f692512013-06-05 12:20:24 +0200921
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey suite against the C accelerator."""

    # Guarded so the class body still executes when the C module is
    # unavailable; the decorator then skips the whole class.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100926
Łukasz Langa6f692512013-06-05 12:20:24 +0200927
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the TestCmpToKey suite against the pure-Python implementation."""

    # staticmethod() keeps the plain Python function from binding to the
    # test instance when looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
930
Łukasz Langa6f692512013-06-05 12:20:24 +0200931
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000932class TestTotalOrdering(unittest.TestCase):
933
934 def test_total_ordering_lt(self):
935 @functools.total_ordering
936 class A:
937 def __init__(self, value):
938 self.value = value
939 def __lt__(self, other):
940 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000941 def __eq__(self, other):
942 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000943 self.assertTrue(A(1) < A(2))
944 self.assertTrue(A(2) > A(1))
945 self.assertTrue(A(1) <= A(2))
946 self.assertTrue(A(2) >= A(1))
947 self.assertTrue(A(2) <= A(2))
948 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000949 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000950
951 def test_total_ordering_le(self):
952 @functools.total_ordering
953 class A:
954 def __init__(self, value):
955 self.value = value
956 def __le__(self, other):
957 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000958 def __eq__(self, other):
959 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000960 self.assertTrue(A(1) < A(2))
961 self.assertTrue(A(2) > A(1))
962 self.assertTrue(A(1) <= A(2))
963 self.assertTrue(A(2) >= A(1))
964 self.assertTrue(A(2) <= A(2))
965 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000966 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000967
968 def test_total_ordering_gt(self):
969 @functools.total_ordering
970 class A:
971 def __init__(self, value):
972 self.value = value
973 def __gt__(self, other):
974 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000975 def __eq__(self, other):
976 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000977 self.assertTrue(A(1) < A(2))
978 self.assertTrue(A(2) > A(1))
979 self.assertTrue(A(1) <= A(2))
980 self.assertTrue(A(2) >= A(1))
981 self.assertTrue(A(2) <= A(2))
982 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000983 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000984
985 def test_total_ordering_ge(self):
986 @functools.total_ordering
987 class A:
988 def __init__(self, value):
989 self.value = value
990 def __ge__(self, other):
991 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000992 def __eq__(self, other):
993 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000994 self.assertTrue(A(1) < A(2))
995 self.assertTrue(A(2) > A(1))
996 self.assertTrue(A(1) <= A(2))
997 self.assertTrue(A(2) >= A(1))
998 self.assertTrue(A(2) <= A(2))
999 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +10001000 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +00001001
1002 def test_total_ordering_no_overwrite(self):
1003 # new methods should not overwrite existing
1004 @functools.total_ordering
1005 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001006 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001007 self.assertTrue(A(1) < A(2))
1008 self.assertTrue(A(2) > A(1))
1009 self.assertTrue(A(1) <= A(2))
1010 self.assertTrue(A(2) >= A(1))
1011 self.assertTrue(A(2) <= A(2))
1012 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001013
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001014 def test_no_operations_defined(self):
1015 with self.assertRaises(ValueError):
1016 @functools.total_ordering
1017 class A:
1018 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001019
Nick Coghlanf05d9812013-10-02 00:02:03 +10001020 def test_type_error_when_not_implemented(self):
1021 # bug 10042; ensure stack overflow does not occur
1022 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001023 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001024 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001025 def __init__(self, value):
1026 self.value = value
1027 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001028 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001029 return self.value == other.value
1030 return False
1031 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001032 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001033 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001034 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001035
Nick Coghlanf05d9812013-10-02 00:02:03 +10001036 @functools.total_ordering
1037 class ImplementsGreaterThan:
1038 def __init__(self, value):
1039 self.value = value
1040 def __eq__(self, other):
1041 if isinstance(other, ImplementsGreaterThan):
1042 return self.value == other.value
1043 return False
1044 def __gt__(self, other):
1045 if isinstance(other, ImplementsGreaterThan):
1046 return self.value > other.value
1047 return NotImplemented
1048
1049 @functools.total_ordering
1050 class ImplementsLessThanEqualTo:
1051 def __init__(self, value):
1052 self.value = value
1053 def __eq__(self, other):
1054 if isinstance(other, ImplementsLessThanEqualTo):
1055 return self.value == other.value
1056 return False
1057 def __le__(self, other):
1058 if isinstance(other, ImplementsLessThanEqualTo):
1059 return self.value <= other.value
1060 return NotImplemented
1061
1062 @functools.total_ordering
1063 class ImplementsGreaterThanEqualTo:
1064 def __init__(self, value):
1065 self.value = value
1066 def __eq__(self, other):
1067 if isinstance(other, ImplementsGreaterThanEqualTo):
1068 return self.value == other.value
1069 return False
1070 def __ge__(self, other):
1071 if isinstance(other, ImplementsGreaterThanEqualTo):
1072 return self.value >= other.value
1073 return NotImplemented
1074
1075 @functools.total_ordering
1076 class ComparatorNotImplemented:
1077 def __init__(self, value):
1078 self.value = value
1079 def __eq__(self, other):
1080 if isinstance(other, ComparatorNotImplemented):
1081 return self.value == other.value
1082 return False
1083 def __lt__(self, other):
1084 return NotImplemented
1085
1086 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1087 ImplementsLessThan(-1) < 1
1088
1089 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1090 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1091
1092 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1093 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1094
1095 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1096 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1097
1098 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1099 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1100
1101 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1102 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1103
1104 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1105 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1106
1107 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1108 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1109
1110 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1111 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1112
1113 with self.subTest("GE when equal"):
1114 a = ComparatorNotImplemented(8)
1115 b = ComparatorNotImplemented(8)
1116 self.assertEqual(a, b)
1117 with self.assertRaises(TypeError):
1118 a >= b
1119
1120 with self.subTest("LE when equal"):
1121 a = ComparatorNotImplemented(9)
1122 b = ComparatorNotImplemented(9)
1123 self.assertEqual(a, b)
1124 with self.assertRaises(TypeError):
1125 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001126
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001127 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001128 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001129 for name in '__lt__', '__gt__', '__le__', '__ge__':
1130 with self.subTest(method=name, proto=proto):
1131 method = getattr(Orderable_LT, name)
1132 method_copy = pickle.loads(pickle.dumps(method, proto))
1133 self.assertIs(method_copy, method)
1134
@functools.total_ordering
class Orderable_LT:
    """Module-level total_ordering fixture defining only __eq__ and __lt__.

    It lives at module level so TestTotalOrdering.test_pickle can pickle
    its methods by qualified name.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1143
1144
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001145class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001146
1147 def test_lru(self):
1148 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001149 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001150 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001151 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001152 self.assertEqual(maxsize, 20)
1153 self.assertEqual(currsize, 0)
1154 self.assertEqual(hits, 0)
1155 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001156
1157 domain = range(5)
1158 for i in range(1000):
1159 x, y = choice(domain), choice(domain)
1160 actual = f(x, y)
1161 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001162 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001163 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001164 self.assertTrue(hits > misses)
1165 self.assertEqual(hits + misses, 1000)
1166 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001167
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001168 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001169 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001170 self.assertEqual(hits, 0)
1171 self.assertEqual(misses, 0)
1172 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001173 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001174 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001175 self.assertEqual(hits, 0)
1176 self.assertEqual(misses, 1)
1177 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001178
Nick Coghlan98876832010-08-17 06:17:18 +00001179 # Test bypassing the cache
1180 self.assertIs(f.__wrapped__, orig)
1181 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001182 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001183 self.assertEqual(hits, 0)
1184 self.assertEqual(misses, 1)
1185 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001186
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001187 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001188 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001189 def f():
1190 nonlocal f_cnt
1191 f_cnt += 1
1192 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001194 f_cnt = 0
1195 for i in range(5):
1196 self.assertEqual(f(), 20)
1197 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001198 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001199 self.assertEqual(hits, 0)
1200 self.assertEqual(misses, 5)
1201 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001202
1203 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001204 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 def f():
1206 nonlocal f_cnt
1207 f_cnt += 1
1208 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001209 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001210 f_cnt = 0
1211 for i in range(5):
1212 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001213 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001214 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001215 self.assertEqual(hits, 4)
1216 self.assertEqual(misses, 1)
1217 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001218
Raymond Hettingerf3098282010-08-15 03:30:45 +00001219 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001220 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001221 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001222 nonlocal f_cnt
1223 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001224 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001226 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001227 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1228 # * * * *
1229 self.assertEqual(f(x), x*10)
1230 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001231 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001232 self.assertEqual(hits, 12)
1233 self.assertEqual(misses, 4)
1234 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001235
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001236 def test_lru_hash_only_once(self):
1237 # To protect against weird reentrancy bugs and to improve
1238 # efficiency when faced with slow __hash__ methods, the
1239 # LRU cache guarantees that it will only call __hash__
1240 # only once per use as an argument to the cached function.
1241
1242 @self.module.lru_cache(maxsize=1)
1243 def f(x, y):
1244 return x * 3 + y
1245
1246 # Simulate the integer 5
1247 mock_int = unittest.mock.Mock()
1248 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1249 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1250
1251 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001252 self.assertEqual(f(mock_int, 1), 16)
1253 self.assertEqual(mock_int.__hash__.call_count, 1)
1254 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001255
1256 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001257 self.assertEqual(f(mock_int, 1), 16)
1258 self.assertEqual(mock_int.__hash__.call_count, 2)
1259 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001260
Ville Skyttä49b27342017-08-03 09:00:59 +03001261 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001262 self.assertEqual(f(6, 2), 20)
1263 self.assertEqual(mock_int.__hash__.call_count, 2)
1264 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001265
1266 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001267 self.assertEqual(f(mock_int, 1), 16)
1268 self.assertEqual(mock_int.__hash__.call_count, 3)
1269 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001270
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001271 def test_lru_reentrancy_with_len(self):
1272 # Test to make sure the LRU cache code isn't thrown-off by
1273 # caching the built-in len() function. Since len() can be
1274 # cached, we shouldn't use it inside the lru code itself.
1275 old_len = builtins.len
1276 try:
1277 builtins.len = self.module.lru_cache(4)(len)
1278 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1279 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1280 finally:
1281 builtins.len = old_len
1282
Raymond Hettinger605a4472017-01-09 07:50:19 -08001283 def test_lru_star_arg_handling(self):
1284 # Test regression that arose in ea064ff3c10f
1285 @functools.lru_cache()
1286 def f(*args):
1287 return args
1288
1289 self.assertEqual(f(1, 2), (1, 2))
1290 self.assertEqual(f((1, 2)), ((1, 2),))
1291
Yury Selivanov46a02db2016-11-09 18:55:45 -05001292 def test_lru_type_error(self):
1293 # Regression test for issue #28653.
1294 # lru_cache was leaking when one of the arguments
1295 # wasn't cacheable.
1296
1297 @functools.lru_cache(maxsize=None)
1298 def infinite_cache(o):
1299 pass
1300
1301 @functools.lru_cache(maxsize=10)
1302 def limited_cache(o):
1303 pass
1304
1305 with self.assertRaises(TypeError):
1306 infinite_cache([])
1307
1308 with self.assertRaises(TypeError):
1309 limited_cache([])
1310
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001311 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001312 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001313 def fib(n):
1314 if n < 2:
1315 return n
1316 return fib(n-1) + fib(n-2)
1317 self.assertEqual([fib(n) for n in range(16)],
1318 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1319 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001320 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001321 fib.cache_clear()
1322 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001323 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1324
1325 def test_lru_with_maxsize_negative(self):
1326 @self.module.lru_cache(maxsize=-10)
1327 def eq(n):
1328 return n
1329 for i in (0, 1):
1330 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1331 self.assertEqual(eq.cache_info(),
1332 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001333
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001334 def test_lru_with_exceptions(self):
1335 # Verify that user_function exceptions get passed through without
1336 # creating a hard-to-read chained exception.
1337 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001338 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001339 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001340 def func(i):
1341 return 'abc'[i]
1342 self.assertEqual(func(0), 'a')
1343 with self.assertRaises(IndexError) as cm:
1344 func(15)
1345 self.assertIsNone(cm.exception.__context__)
1346 # Verify that the previous exception did not result in a cached entry
1347 with self.assertRaises(IndexError):
1348 func(15)
1349
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001350 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001351 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001352 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001353 def square(x):
1354 return x * x
1355 self.assertEqual(square(3), 9)
1356 self.assertEqual(type(square(3)), type(9))
1357 self.assertEqual(square(3.0), 9.0)
1358 self.assertEqual(type(square(3.0)), type(9.0))
1359 self.assertEqual(square(x=3), 9)
1360 self.assertEqual(type(square(x=3)), type(9))
1361 self.assertEqual(square(x=3.0), 9.0)
1362 self.assertEqual(type(square(x=3.0)), type(9.0))
1363 self.assertEqual(square.cache_info().hits, 4)
1364 self.assertEqual(square.cache_info().misses, 4)
1365
Antoine Pitroub5b37142012-11-13 21:35:40 +01001366 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001367 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001368 def fib(n):
1369 if n < 2:
1370 return n
1371 return fib(n=n-1) + fib(n=n-2)
1372 self.assertEqual(
1373 [fib(n=number) for number in range(16)],
1374 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1375 )
1376 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001377 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001378 fib.cache_clear()
1379 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001380 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001381
1382 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001383 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001384 def fib(n):
1385 if n < 2:
1386 return n
1387 return fib(n=n-1) + fib(n=n-2)
1388 self.assertEqual([fib(n=number) for number in range(16)],
1389 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1390 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001391 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001392 fib.cache_clear()
1393 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001394 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1395
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001396 def test_kwargs_order(self):
1397 # PEP 468: Preserving Keyword Argument Order
1398 @self.module.lru_cache(maxsize=10)
1399 def f(**kwargs):
1400 return list(kwargs.items())
1401 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1402 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1403 self.assertEqual(f.cache_info(),
1404 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1405
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001406 def test_lru_cache_decoration(self):
1407 def f(zomg: 'zomg_annotation'):
1408 """f doc string"""
1409 return 42
1410 g = self.module.lru_cache()(f)
1411 for attr in self.module.WRAPPER_ASSIGNMENTS:
1412 self.assertEqual(getattr(g, attr), getattr(f, attr))
1413
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001414 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001415 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001416 def orig(x, y):
1417 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001418 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001419 hits, misses, maxsize, currsize = f.cache_info()
1420 self.assertEqual(currsize, 0)
1421
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001422 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001423 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001424 start.wait(10)
1425 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001426 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001427
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001428 def clear():
1429 start.wait(10)
1430 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001431 f.cache_clear()
1432
1433 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001434 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001435 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001436 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001437 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001438 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001439 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001440 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001441
1442 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001443 if self.module is py_functools:
1444 # XXX: Why can be not equal?
1445 self.assertLessEqual(misses, n)
1446 self.assertLessEqual(hits, m*n - misses)
1447 else:
1448 self.assertEqual(misses, n)
1449 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001450 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001451
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001452 # create n threads in order to fill cache and 1 to clear it
1453 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001454 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001455 for k in range(n)]
1456 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001457 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001458 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001459 finally:
1460 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001461
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001462 def test_lru_cache_threaded2(self):
1463 # Simultaneous call with the same arguments
1464 n, m = 5, 7
1465 start = threading.Barrier(n+1)
1466 pause = threading.Barrier(n+1)
1467 stop = threading.Barrier(n+1)
1468 @self.module.lru_cache(maxsize=m*n)
1469 def f(x):
1470 pause.wait(10)
1471 return 3 * x
1472 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1473 def test():
1474 for i in range(m):
1475 start.wait(10)
1476 self.assertEqual(f(i), 3 * i)
1477 stop.wait(10)
1478 threads = [threading.Thread(target=test) for k in range(n)]
1479 with support.start_threads(threads):
1480 for i in range(m):
1481 start.wait(10)
1482 stop.reset()
1483 pause.wait(10)
1484 start.reset()
1485 stop.wait(10)
1486 pause.reset()
1487 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1488
Serhiy Storchaka67796522017-01-12 18:34:33 +02001489 def test_lru_cache_threaded3(self):
1490 @self.module.lru_cache(maxsize=2)
1491 def f(x):
1492 time.sleep(.01)
1493 return 3 * x
1494 def test(i, x):
1495 with self.subTest(thread=i):
1496 self.assertEqual(f(x), 3 * x, i)
1497 threads = [threading.Thread(target=test, args=(i, v))
1498 for i, v in enumerate([1, 2, 2, 3, 2])]
1499 with support.start_threads(threads):
1500 pass
1501
Raymond Hettinger03923422013-03-04 02:52:50 -05001502 def test_need_for_rlock(self):
1503 # This will deadlock on an LRU cache that uses a regular lock
1504
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001505 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001506 def test_func(x):
1507 'Used to demonstrate a reentrant lru_cache call within a single thread'
1508 return x
1509
1510 class DoubleEq:
1511 'Demonstrate a reentrant lru_cache call within a single thread'
1512 def __init__(self, x):
1513 self.x = x
1514 def __hash__(self):
1515 return self.x
1516 def __eq__(self, other):
1517 if self.x == 2:
1518 test_func(DoubleEq(1))
1519 return self.x == other.x
1520
1521 test_func(DoubleEq(1)) # Load the cache
1522 test_func(DoubleEq(2)) # Load the cache
1523 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1524 DoubleEq(2)) # Verify the correct return value
1525
Raymond Hettinger4d588972014-08-12 12:44:52 -07001526 def test_early_detection_of_bad_call(self):
1527 # Issue #22184
1528 with self.assertRaises(TypeError):
1529 @functools.lru_cache
1530 def f():
1531 pass
1532
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001533 def test_lru_method(self):
1534 class X(int):
1535 f_cnt = 0
1536 @self.module.lru_cache(2)
1537 def f(self, x):
1538 self.f_cnt += 1
1539 return x*10+self
1540 a = X(5)
1541 b = X(5)
1542 c = X(7)
1543 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1544
1545 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1546 self.assertEqual(a.f(x), x*10 + 5)
1547 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1548 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1549
1550 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1551 self.assertEqual(b.f(x), x*10 + 5)
1552 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1553 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1554
1555 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1556 self.assertEqual(c.f(x), x*10 + 7)
1557 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1558 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1559
1560 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1561 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1562 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1563
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001564 def test_pickle(self):
1565 cls = self.__class__
1566 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1567 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1568 with self.subTest(proto=proto, func=f):
1569 f_copy = pickle.loads(pickle.dumps(f, proto))
1570 self.assertIs(f_copy, f)
1571
1572 def test_copy(self):
1573 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001574 def orig(x, y):
1575 return 3 * x + y
1576 part = self.module.partial(orig, 2)
1577 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1578 self.module.lru_cache(2)(part))
1579 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001580 with self.subTest(func=f):
1581 f_copy = copy.copy(f)
1582 self.assertIs(f_copy, f)
1583
1584 def test_deepcopy(self):
1585 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001586 def orig(x, y):
1587 return 3 * x + y
1588 part = self.module.partial(orig, 2)
1589 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1590 self.module.lru_cache(2)(part))
1591 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001592 with self.subTest(func=f):
1593 f_copy = copy.deepcopy(f)
1594 self.assertIs(f_copy, f)
1595
1596
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level function cached by the pure-Python lru_cache; TestLRUPy
    # exposes it as ``cached_func`` for the pickle and copy tests.
    tripled = 3 * x
    return tripled + y
1600
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level function cached by the C-accelerated lru_cache; TestLRUC
    # exposes it as ``cached_func`` for the pickle and copy tests.
    tripled = 3 * x
    return tripled + y
1604
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001605
class TestLRUPy(TestLRU, unittest.TestCase):
    # Run the shared TestLRU tests against the pure-Python functools module.
    module = py_functools
    # Kept in a 1-tuple, presumably so the function is never bound as a
    # method when looked up on the test class or its instances.
    cached_func = (py_cached_func,)

    # Cached callables looked up on the class by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1618
1619
class TestLRUC(TestLRU, unittest.TestCase):
    # Run the shared TestLRU tests against the C-accelerated functools module.
    module = c_functools
    # Kept in a 1-tuple, presumably so the function is never bound as a
    # method when looked up on the test class or its instances.
    cached_func = (c_cached_func,)

    # Cached callables looked up on the class by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1631 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001632
Raymond Hettinger03923422013-03-04 02:52:50 -05001633
Łukasz Langa6f692512013-06-05 12:20:24 +02001634class TestSingleDispatch(unittest.TestCase):
1635 def test_simple_overloads(self):
1636 @functools.singledispatch
1637 def g(obj):
1638 return "base"
1639 def g_int(i):
1640 return "integer"
1641 g.register(int, g_int)
1642 self.assertEqual(g("str"), "base")
1643 self.assertEqual(g(1), "integer")
1644 self.assertEqual(g([1,2,3]), "base")
1645
1646 def test_mro(self):
1647 @functools.singledispatch
1648 def g(obj):
1649 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001650 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001651 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001652 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001653 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001654 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001655 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001656 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001657 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001658 def g_A(a):
1659 return "A"
1660 def g_B(b):
1661 return "B"
1662 g.register(A, g_A)
1663 g.register(B, g_B)
1664 self.assertEqual(g(A()), "A")
1665 self.assertEqual(g(B()), "B")
1666 self.assertEqual(g(C()), "A")
1667 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001668
1669 def test_register_decorator(self):
1670 @functools.singledispatch
1671 def g(obj):
1672 return "base"
1673 @g.register(int)
1674 def g_int(i):
1675 return "int %s" % (i,)
1676 self.assertEqual(g(""), "base")
1677 self.assertEqual(g(12), "int 12")
1678 self.assertIs(g.dispatch(int), g_int)
1679 self.assertIs(g.dispatch(object), g.dispatch(str))
1680 # Note: in the assert above this is not g.
1681 # @singledispatch returns the wrapper.
1682
1683 def test_wrapping_attributes(self):
1684 @functools.singledispatch
1685 def g(obj):
1686 "Simple test"
1687 return "Test"
1688 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001689 if sys.flags.optimize < 2:
1690 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001691
1692 @unittest.skipUnless(decimal, 'requires _decimal')
1693 @support.cpython_only
1694 def test_c_classes(self):
1695 @functools.singledispatch
1696 def g(obj):
1697 return "base"
1698 @g.register(decimal.DecimalException)
1699 def _(obj):
1700 return obj.args
1701 subn = decimal.Subnormal("Exponent < Emin")
1702 rnd = decimal.Rounded("Number got rounded")
1703 self.assertEqual(g(subn), ("Exponent < Emin",))
1704 self.assertEqual(g(rnd), ("Number got rounded",))
1705 @g.register(decimal.Subnormal)
1706 def _(obj):
1707 return "Too small to care."
1708 self.assertEqual(g(subn), "Too small to care.")
1709 self.assertEqual(g(rnd), ("Number got rounded",))
1710
1711 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001712 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001713 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001714 mro = functools._compose_mro
1715 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1716 for haystack in permutations(bases):
1717 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001718 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1719 c.Collection, c.Sized, c.Iterable,
1720 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001721 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001722 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001723 m = mro(collections.ChainMap, haystack)
1724 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001725 c.Collection, c.Sized, c.Iterable,
1726 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001727
1728 # If there's a generic function with implementations registered for
1729 # both Sized and Container, passing a defaultdict to it results in an
1730 # ambiguous dispatch which will cause a RuntimeError (see
1731 # test_mro_conflicts).
1732 bases = [c.Container, c.Sized, str]
1733 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001734 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1735 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1736 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001737
1738 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001739 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001740 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001741 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001742 pass
1743 c.MutableSequence.register(D)
1744 bases = [c.MutableSequence, c.MutableMapping]
1745 for haystack in permutations(bases):
1746 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001747 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001748 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001749 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001750 object])
1751
1752 # Container and Callable are registered on different base classes and
1753 # a generic function supporting both should always pick the Callable
1754 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001755 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001756 def __call__(self):
1757 pass
1758 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1759 for haystack in permutations(bases):
1760 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001761 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001762 c.Collection, c.Sized, c.Iterable,
1763 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001764
1765 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001766 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001767 d = {"a": "b"}
1768 l = [1, 2, 3]
1769 s = {object(), None}
1770 f = frozenset(s)
1771 t = (1, 2, 3)
1772 @functools.singledispatch
1773 def g(obj):
1774 return "base"
1775 self.assertEqual(g(d), "base")
1776 self.assertEqual(g(l), "base")
1777 self.assertEqual(g(s), "base")
1778 self.assertEqual(g(f), "base")
1779 self.assertEqual(g(t), "base")
1780 g.register(c.Sized, lambda obj: "sized")
1781 self.assertEqual(g(d), "sized")
1782 self.assertEqual(g(l), "sized")
1783 self.assertEqual(g(s), "sized")
1784 self.assertEqual(g(f), "sized")
1785 self.assertEqual(g(t), "sized")
1786 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1787 self.assertEqual(g(d), "mutablemapping")
1788 self.assertEqual(g(l), "sized")
1789 self.assertEqual(g(s), "sized")
1790 self.assertEqual(g(f), "sized")
1791 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001792 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001793 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1794 self.assertEqual(g(l), "sized")
1795 self.assertEqual(g(s), "sized")
1796 self.assertEqual(g(f), "sized")
1797 self.assertEqual(g(t), "sized")
1798 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1799 self.assertEqual(g(d), "mutablemapping")
1800 self.assertEqual(g(l), "mutablesequence")
1801 self.assertEqual(g(s), "sized")
1802 self.assertEqual(g(f), "sized")
1803 self.assertEqual(g(t), "sized")
1804 g.register(c.MutableSet, lambda obj: "mutableset")
1805 self.assertEqual(g(d), "mutablemapping")
1806 self.assertEqual(g(l), "mutablesequence")
1807 self.assertEqual(g(s), "mutableset")
1808 self.assertEqual(g(f), "sized")
1809 self.assertEqual(g(t), "sized")
1810 g.register(c.Mapping, lambda obj: "mapping")
1811 self.assertEqual(g(d), "mutablemapping") # not specific enough
1812 self.assertEqual(g(l), "mutablesequence")
1813 self.assertEqual(g(s), "mutableset")
1814 self.assertEqual(g(f), "sized")
1815 self.assertEqual(g(t), "sized")
1816 g.register(c.Sequence, lambda obj: "sequence")
1817 self.assertEqual(g(d), "mutablemapping")
1818 self.assertEqual(g(l), "mutablesequence")
1819 self.assertEqual(g(s), "mutableset")
1820 self.assertEqual(g(f), "sized")
1821 self.assertEqual(g(t), "sequence")
1822 g.register(c.Set, lambda obj: "set")
1823 self.assertEqual(g(d), "mutablemapping")
1824 self.assertEqual(g(l), "mutablesequence")
1825 self.assertEqual(g(s), "mutableset")
1826 self.assertEqual(g(f), "set")
1827 self.assertEqual(g(t), "sequence")
1828 g.register(dict, lambda obj: "dict")
1829 self.assertEqual(g(d), "dict")
1830 self.assertEqual(g(l), "mutablesequence")
1831 self.assertEqual(g(s), "mutableset")
1832 self.assertEqual(g(f), "set")
1833 self.assertEqual(g(t), "sequence")
1834 g.register(list, lambda obj: "list")
1835 self.assertEqual(g(d), "dict")
1836 self.assertEqual(g(l), "list")
1837 self.assertEqual(g(s), "mutableset")
1838 self.assertEqual(g(f), "set")
1839 self.assertEqual(g(t), "sequence")
1840 g.register(set, lambda obj: "concrete-set")
1841 self.assertEqual(g(d), "dict")
1842 self.assertEqual(g(l), "list")
1843 self.assertEqual(g(s), "concrete-set")
1844 self.assertEqual(g(f), "set")
1845 self.assertEqual(g(t), "sequence")
1846 g.register(frozenset, lambda obj: "frozen-set")
1847 self.assertEqual(g(d), "dict")
1848 self.assertEqual(g(l), "list")
1849 self.assertEqual(g(s), "concrete-set")
1850 self.assertEqual(g(f), "frozen-set")
1851 self.assertEqual(g(t), "sequence")
1852 g.register(tuple, lambda obj: "tuple")
1853 self.assertEqual(g(d), "dict")
1854 self.assertEqual(g(l), "list")
1855 self.assertEqual(g(s), "concrete-set")
1856 self.assertEqual(g(f), "frozen-set")
1857 self.assertEqual(g(t), "tuple")
1858
Łukasz Langa3720c772013-07-01 16:00:38 +02001859 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001860 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001861 mro = functools._c3_mro
1862 class A(object):
1863 pass
1864 class B(A):
1865 def __len__(self):
1866 return 0 # implies Sized
1867 @c.Container.register
1868 class C(object):
1869 pass
1870 class D(object):
1871 pass # unrelated
1872 class X(D, C, B):
1873 def __call__(self):
1874 pass # implies Callable
1875 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1876 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1877 self.assertEqual(mro(X, abcs=abcs), expected)
1878 # unrelated ABCs don't appear in the resulting MRO
1879 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1880 self.assertEqual(mro(X, abcs=many_abcs), expected)
1881
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001882 def test_false_meta(self):
1883 # see issue23572
1884 class MetaA(type):
1885 def __len__(self):
1886 return 0
1887 class A(metaclass=MetaA):
1888 pass
1889 class AA(A):
1890 pass
1891 @functools.singledispatch
1892 def fun(a):
1893 return 'base A'
1894 @fun.register(A)
1895 def _(a):
1896 return 'fun A'
1897 aa = AA()
1898 self.assertEqual(fun(aa), 'fun A')
1899
    def test_mro_conflicts(self):
        # Exercise dispatch when several registered ABCs apply to a class.
        # ABCs that appear explicitly in __mro__ win over ABCs that are only
        # registered or inferred, and a more derived ABC wins over its bases.
        # Two candidate ABCs with no preference between them make dispatch
        # raise RuntimeError.  Note: the ABC registrations below are global
        # side effects on collections.abc.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # P now has two unrelated registered ABCs, each with an implementation.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two candidates may be reported in either order.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        # MutableSequence is registered directly on R, so it takes precedence
        # over MutableMapping, which R only gets through defaultdict.
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002027
    def test_cache_invalidation(self):
        # Check when singledispatch's dispatch cache is filled, read and
        # emptied.  The cache is replaced by a TracingDict that logs the keys
        # of every successful read and every write.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # UserDict recording keys passed to __getitem__ / __setitem__.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                # Logged only after a successful lookup (a KeyError skips it).
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        # Presumably singledispatch builds its dispatch cache by calling
        # weakref.WeakKeyDictionary() when g is decorated; swapping that
        # factory makes td the cache, as the td.data checks below rely on.
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            # First call per type: no cache read is logged, one entry written.
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are answered from the cache (reads only).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering a new implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads the same cache that calling g does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            # The stale entries are dropped on the next call, which then
            # stores a single fresh entry.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002129
Łukasz Langae5697532017-12-11 13:56:31 -08002130 def test_annotations(self):
2131 @functools.singledispatch
2132 def i(arg):
2133 return "base"
2134 @i.register
2135 def _(arg: collections.abc.Mapping):
2136 return "mapping"
2137 @i.register
2138 def _(arg: "collections.abc.Sequence"):
2139 return "sequence"
2140 self.assertEqual(i(None), "base")
2141 self.assertEqual(i({"a": 1}), "mapping")
2142 self.assertEqual(i([1, 2, 3]), "sequence")
2143 self.assertEqual(i((1, 2, 3)), "sequence")
2144 self.assertEqual(i("str"), "sequence")
2145
2146 # Registering classes as callables doesn't work with annotations,
2147 # you need to pass the type explicitly.
2148 @i.register(str)
2149 class _:
2150 def __init__(self, arg):
2151 self.arg = arg
2152
2153 def __eq__(self, other):
2154 return self.arg == other
2155 self.assertEqual(i("str"), "str")
2156
Ethan Smithc6512752018-05-26 16:38:33 -04002157 def test_method_register(self):
2158 class A:
2159 @functools.singledispatchmethod
2160 def t(self, arg):
2161 self.arg = "base"
2162 @t.register(int)
2163 def _(self, arg):
2164 self.arg = "int"
2165 @t.register(str)
2166 def _(self, arg):
2167 self.arg = "str"
2168 a = A()
2169
2170 a.t(0)
2171 self.assertEqual(a.arg, "int")
2172 aa = A()
2173 self.assertFalse(hasattr(aa, 'arg'))
2174 a.t('')
2175 self.assertEqual(a.arg, "str")
2176 aa = A()
2177 self.assertFalse(hasattr(aa, 'arg'))
2178 a.t(0.0)
2179 self.assertEqual(a.arg, "base")
2180 aa = A()
2181 self.assertFalse(hasattr(aa, 'arg'))
2182
2183 def test_staticmethod_register(self):
2184 class A:
2185 @functools.singledispatchmethod
2186 @staticmethod
2187 def t(arg):
2188 return arg
2189 @t.register(int)
2190 @staticmethod
2191 def _(arg):
2192 return isinstance(arg, int)
2193 @t.register(str)
2194 @staticmethod
2195 def _(arg):
2196 return isinstance(arg, str)
2197 a = A()
2198
2199 self.assertTrue(A.t(0))
2200 self.assertTrue(A.t(''))
2201 self.assertEqual(A.t(0.0), 0.0)
2202
2203 def test_classmethod_register(self):
2204 class A:
2205 def __init__(self, arg):
2206 self.arg = arg
2207
2208 @functools.singledispatchmethod
2209 @classmethod
2210 def t(cls, arg):
2211 return cls("base")
2212 @t.register(int)
2213 @classmethod
2214 def _(cls, arg):
2215 return cls("int")
2216 @t.register(str)
2217 @classmethod
2218 def _(cls, arg):
2219 return cls("str")
2220
2221 self.assertEqual(A.t(0).arg, "int")
2222 self.assertEqual(A.t('').arg, "str")
2223 self.assertEqual(A.t(0.0).arg, "base")
2224
2225 def test_callable_register(self):
2226 class A:
2227 def __init__(self, arg):
2228 self.arg = arg
2229
2230 @functools.singledispatchmethod
2231 @classmethod
2232 def t(cls, arg):
2233 return cls("base")
2234
2235 @A.t.register(int)
2236 @classmethod
2237 def _(cls, arg):
2238 return cls("int")
2239 @A.t.register(str)
2240 @classmethod
2241 def _(cls, arg):
2242 return cls("str")
2243
2244 self.assertEqual(A.t(0).arg, "int")
2245 self.assertEqual(A.t('').arg, "str")
2246 self.assertEqual(A.t(0.0).arg, "base")
2247
2248 def test_abstractmethod_register(self):
2249 class Abstract(abc.ABCMeta):
2250
2251 @functools.singledispatchmethod
2252 @abc.abstractmethod
2253 def add(self, x, y):
2254 pass
2255
2256 self.assertTrue(Abstract.add.__isabstractmethod__)
2257
2258 def test_type_ann_register(self):
2259 class A:
2260 @functools.singledispatchmethod
2261 def t(self, arg):
2262 return "base"
2263 @t.register
2264 def _(self, arg: int):
2265 return "int"
2266 @t.register
2267 def _(self, arg: str):
2268 return "str"
2269 a = A()
2270
2271 self.assertEqual(a.t(0), "int")
2272 self.assertEqual(a.t(''), "str")
2273 self.assertEqual(a.t(0.0), "base")
2274
def test_invalid_registrations(self):
    # register() must reject arguments that are neither a class nor an
    # annotated function. The TypeError message names the bad argument
    # (prefix) and points at the two supported registration forms (suffix).
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )
    @functools.singledispatch
    def i(arg):
        return "base"
    # Case 1: a non-class value passed to register(...).
    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
    # Case 2: plain @register on a function with no annotations. The message
    # embeds the function's repr, so this check depends on the enclosing
    # class being named TestSingleDispatch.
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    self.assertTrue(str(exc.exception).startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))

    # FIXME: The following will only work after PEP 560 is implemented.
    # NOTE(review): everything after the bare `return` below is dead code.
    # PEP 560 was implemented in Python 3.7, so this FIXME looks stale.
    # Before reviving the check, confirm which exception type and message
    # register() raises for a typing generic annotation in the functools
    # version under test. The prefix/suffix asserted below are the ones for
    # unannotated functions and may not match.
    return

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    self.assertTrue(str(exc.exception).startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
2314
Dong-hee Na445f1b32018-07-10 16:26:36 +09002315 def test_invalid_positional_argument(self):
2316 @functools.singledispatch
2317 def f(*args):
2318 pass
2319 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002320 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002321 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002322
Carl Meyerd658dea2018-08-28 01:11:56 -06002323
class CachedCostItem:
    # Starting cost shared at class level; the first computation shadows it
    # with an instance attribute holding the incremented value.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        # Increment under the instance lock, then report the new value.
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2336
2337
class OptionallyCachedCostItem:
    # Exposes one getter twice: get_cost() recomputes on every call, while
    # cached_cost wraps the very same function in a cached_property stored
    # under a different attribute name.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    cached_cost = py_functools.cached_property(get_cost)
2346 cached_cost = py_functools.cached_property(get_cost)
2347
2348
class CachedCostItemWait:
    # Fixture whose cached getter blocks on an event before computing;
    # test_threaded uses it to have several threads enter the getter together.

    def __init__(self, event):
        self._cost = 1
        self.lock = py_functools.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Wait (at most one second) for the "go" signal before computing.
        self.event.wait(1)
        with self.lock:
            self._cost = self._cost + 1
            return self._cost
2362
2363
class CachedCostItemWithSlots:
    # Fixture with __slots__ and therefore no per-instance __dict__, so
    # cached_property has nowhere to store its result. test_object_with_slots
    # expects accessing ``cost`` to fail before the getter ever runs.

    # A real one-element tuple. The previous ``('_cost')`` was a plain string
    # in parentheses; Python happens to accept a lone string as a single slot
    # name, but the tuple states the intent unambiguously.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2373
2374
class TestCachedProperty(unittest.TestCase):
    # Tests for the pure-Python functools.cached_property (py_functools).

    def test_cached(self):
        item = CachedCostItem()
        # The getter runs once; every later read returns the cached 2, not 3.
        for _ in range(2):
            self.assertEqual(item.cost, 2)

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # The plain method recomputes on each call, while the cached
        # attribute keeps the first value it computed.
        steps = (
            (item.get_cost, 2),
            (lambda: item.cached_cost, 3),
            (item.get_cost, 4),
            (lambda: item.cached_cost, 3),
        )
        for read, expected in steps:
            self.assertEqual(read(), expected)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        # Switch threads as often as possible to make it likely that several
        # threads race into the getter at the same time.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Despite the race, only a single computation took effect.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        # A cached_property on a metaclass would have to store its value in
        # the class's own __dict__, which does not support item assignment.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        expected = "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property."
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        # The TypeError raised from __set_name__ is chained as the context of
        # the RuntimeError reported by class creation.
        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        # Each instance caches its own value independently.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Attaching the descriptor after class creation skips __set_name__,
        # so the property never learns the attribute name it lives under.
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        Foo.cp = cp

        expected = "Cannot use cached_property instance without calling __set_name__ on it."
        with self.assertRaisesRegex(TypeError, expected):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access returns the descriptor object itself.
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        # The getter's docstring is exposed on the descriptor.
        descriptor = CachedCostItem.cost
        self.assertEqual(descriptor.__doc__, "The cost of the item.")
2487
2488
if __name__ == "__main__":
    # Executed as a script rather than imported: run this module's tests.
    unittest.main()