blob: a91c6348e709d73df9c81f47b059fa6d82962cbb [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Two separately imported copies of functools: one with the _functools C
# accelerator blocked (so the pure-Python implementations are used) and one
# with _functools freshly imported.  c_functools may be falsy when the C
# module cannot be imported, which is why later code guards on it.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh import of decimal backed by _decimal -- presumably used by tests
# further down this file (not referenced in the visible portion).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The entry that was present beforehand is reinstated when the with-block
    exits, whether it finishes normally or raises.  *name* must already be in
    sys.modules; otherwise the initial lookup raises KeyError.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always put the original module back, even on error.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo a call back: return its positional tuple and keyword dict."""
    captured = (args, kw)
    return captured
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for the partial *part*."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass; partial's state handling should normalize it to tuple."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        # Deliberately break the tuple contract: concatenation yields a list.
        return [*self, *other]
50
class MyDict(dict):
    """Plain dict subclass; partial's state handling should normalize it to dict."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Behaviour shared by every partial implementation under test.

    This is a mixin.  Concrete subclasses also inherit from unittest.TestCase
    and define ``partial`` (the partial type being tested) and
    ``AllowPickle`` (a context-manager class entered around pickling tests).
    """

    def _expected_repr_name(self):
        """Return the type name that repr() of a ``self.partial`` object uses.

        The stock C and pure-Python types both report 'functools.partial';
        any other type (a subclass) reports its own ``__name__``.
        """
        # c_functools is falsy when the C accelerator is unavailable, so it
        # must only be dereferenced when present; otherwise the pure-Python
        # test classes would fail here with AttributeError.
        stock_types = [py_functools.partial]
        if c_functools:
            stock_types.append(c_functools.partial)
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Implementations without reference counting need an explicit
        # collection before the proxy's referent is guaranteed to be gone.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned, so accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes the partial refer to itself, then breaks the cycle
        # in `finally` so the object can be released.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None keywords entry must be normalized to an empty dict.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; the call must still see a tuple.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence that is not a tuple must be rejected even though it
        # has the right length and plausible items.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager.

        Unlike the pure-Python test class, nothing in sys.modules is swapped
        here -- presumably because the regular functools module already
        exposes the C type.
        """
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, tb):
            # Never suppress an exception raised inside the with-block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        """While active, make sys.modules['functools'] the pure-Python copy.

        Presumably required so that pickle, which resolves classes by module
        and name, finds the pure-Python partial instead of the regular one.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc_value, tb):
            return self._swap.__exit__(exc_type, exc_value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial Python-level subclass of the C partial type."""

class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun the C partial tests with a Python subclass of the C type."""
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the inherited
    # test is switched off by shadowing it with a non-callable value.
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Rerun the pure-Python partial tests with a subclass of that type."""
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # A partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # A partialmethod wrapping a plain functools.partial.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethods wrapping other descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass* here.  Subclassing ABCMeta instead
        # would define a new metaclass rather than an abstract class, and the
        # abstract-method bookkeeping checked below would never run.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must register both the plain and the partial method as
        # abstract, which in turn makes the class non-instantiable.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
572
573
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper().

    TestWraps subclasses this class, overriding _default_update() and
    test_default_update() to build the wrapper with functools.wraps instead.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Assert that *wrapper* carries what update_wrapper() transfers from
        # *wrapped*: each name in *assigned* is the identical object, each
        # mapping named in *updated* contains wrapped's entries, and
        # wrapper.__wrapped__ points back at *wrapped*.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using update_wrapper() with its
        # default assigned/updated attribute sets.
        # NOTE: the function names 'f' and 'wrapper' and the annotation
        # strings are asserted on by the tests below -- do not rename them.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A decoy __wrapped__ in f.__dict__: update_wrapper must not copy
        # it over wrapper.__wrapped__ (see check_wrapper).
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # The wrapped function's annotations replace the wrapper's own.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples nothing but __wrapped__ changes.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly listed attributes are assigned/updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # An attribute that exists but has no update() method also fails.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper() scenarios through @functools.wraps."""

    def _default_update(self):
        # Return (wrapper, f), where wrapper was decorated with
        # functools.wraps(f) using the default assigned/updated tuples.
        # f carries a pre-existing (bogus) string __wrapped__ attribute.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expected = {
            '__name__': 'f',
            '__qualname__': f.__qualname__,
            'attr': 'This is also a test',
        }
        for attr, value in expected.items():
            self.assertEqual(getattr(wrapper, attr), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples nothing is copied over.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        no_attrs = ()

        @functools.wraps(f, no_attrs, no_attrs)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, no_attrs, no_attrs)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is merged.
        assign, update = ('attr',), ('dict_attr',)

        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the wrapper the dict that wraps() will update.
            func.dict_attr = {}
            return func

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
748
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of functools.reduce()."""
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily built sequence of i*i for 0 <= i < max; its length is
            # the number of squares computed so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.func
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        for initial in ((), (0,)):
            self.assertEqual(reduce(add, Squares(10), *initial), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong number or kind of arguments.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the (non-callable) function is never called.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty iterable without an initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, FailingIter())

        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    def test_iterator_usage(self):
        # reduce() over objects that only implement __getitem__ and signal
        # the end with IndexError, and over a dict (iterates its keys).
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
830
Łukasz Langa6f692512013-06-05 12:20:24 +0200831
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200832class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700833
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000834 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700835 def cmp1(x, y):
836 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100837 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700838 self.assertEqual(key(3), key(3))
839 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100840 self.assertGreaterEqual(key(3), key(3))
841
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700842 def cmp2(x, y):
843 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100844 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700845 self.assertEqual(key(4.0), key('4'))
846 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100847 self.assertLessEqual(key(2), key('35'))
848 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700849
850 def test_cmp_to_key_arguments(self):
851 def cmp1(x, y):
852 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100853 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700854 self.assertEqual(key(obj=3), key(obj=3))
855 self.assertGreater(key(obj=3), key(obj=1))
856 with self.assertRaises((TypeError, AttributeError)):
857 key(3) > 1 # rhs is not a K object
858 with self.assertRaises((TypeError, AttributeError)):
859 1 < key(3) # lhs is not a K object
860 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100861 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700862 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200863 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100864 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700865 with self.assertRaises(TypeError):
866 key() # too few args
867 with self.assertRaises(TypeError):
868 key(None, None) # too many args
869
870 def test_bad_cmp(self):
871 def cmp1(x, y):
872 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100873 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700874 with self.assertRaises(ZeroDivisionError):
875 key(3) > key(1)
876
877 class BadCmp:
878 def __lt__(self, other):
879 raise ZeroDivisionError
880 def cmp1(x, y):
881 return BadCmp()
882 with self.assertRaises(ZeroDivisionError):
883 key(3) > key(1)
884
885 def test_obj_field(self):
886 def cmp1(x, y):
887 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100888 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700889 self.assertEqual(key(50).obj, 50)
890
891 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000892 def mycmp(x, y):
893 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100894 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000895 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000896
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700897 def test_sort_int_str(self):
898 def mycmp(x, y):
899 x, y = int(x), int(y)
900 return (x > y) - (x < y)
901 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100902 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700903 self.assertEqual([int(value) for value in values],
904 [0, 1, 1, 2, 3, 4, 5, 7, 10])
905
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000906 def test_hash(self):
907 def mycmp(x, y):
908 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100909 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000910 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700911 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300912 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000913
Łukasz Langa6f692512013-06-05 12:20:24 +0200914
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""
    if c_functools:
        # Stored directly: unlike the pure-Python variant, no staticmethod
        # wrapper is used for the C implementation.
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919
Łukasz Langa6f692512013-06-05 12:20:24 +0200920
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""
    # staticmethod keeps the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
923
Łukasz Langa6f692512013-06-05 12:20:24 +0200924
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000925class TestTotalOrdering(unittest.TestCase):
926
927 def test_total_ordering_lt(self):
928 @functools.total_ordering
929 class A:
930 def __init__(self, value):
931 self.value = value
932 def __lt__(self, other):
933 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000934 def __eq__(self, other):
935 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000936 self.assertTrue(A(1) < A(2))
937 self.assertTrue(A(2) > A(1))
938 self.assertTrue(A(1) <= A(2))
939 self.assertTrue(A(2) >= A(1))
940 self.assertTrue(A(2) <= A(2))
941 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000942 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000943
944 def test_total_ordering_le(self):
945 @functools.total_ordering
946 class A:
947 def __init__(self, value):
948 self.value = value
949 def __le__(self, other):
950 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000951 def __eq__(self, other):
952 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000953 self.assertTrue(A(1) < A(2))
954 self.assertTrue(A(2) > A(1))
955 self.assertTrue(A(1) <= A(2))
956 self.assertTrue(A(2) >= A(1))
957 self.assertTrue(A(2) <= A(2))
958 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000959 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000960
961 def test_total_ordering_gt(self):
962 @functools.total_ordering
963 class A:
964 def __init__(self, value):
965 self.value = value
966 def __gt__(self, other):
967 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000968 def __eq__(self, other):
969 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000970 self.assertTrue(A(1) < A(2))
971 self.assertTrue(A(2) > A(1))
972 self.assertTrue(A(1) <= A(2))
973 self.assertTrue(A(2) >= A(1))
974 self.assertTrue(A(2) <= A(2))
975 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000976 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000977
978 def test_total_ordering_ge(self):
979 @functools.total_ordering
980 class A:
981 def __init__(self, value):
982 self.value = value
983 def __ge__(self, other):
984 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000985 def __eq__(self, other):
986 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000987 self.assertTrue(A(1) < A(2))
988 self.assertTrue(A(2) > A(1))
989 self.assertTrue(A(1) <= A(2))
990 self.assertTrue(A(2) >= A(1))
991 self.assertTrue(A(2) <= A(2))
992 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000993 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000994
995 def test_total_ordering_no_overwrite(self):
996 # new methods should not overwrite existing
997 @functools.total_ordering
998 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000999 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001000 self.assertTrue(A(1) < A(2))
1001 self.assertTrue(A(2) > A(1))
1002 self.assertTrue(A(1) <= A(2))
1003 self.assertTrue(A(2) >= A(1))
1004 self.assertTrue(A(2) <= A(2))
1005 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001006
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001007 def test_no_operations_defined(self):
1008 with self.assertRaises(ValueError):
1009 @functools.total_ordering
1010 class A:
1011 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001012
Nick Coghlanf05d9812013-10-02 00:02:03 +10001013 def test_type_error_when_not_implemented(self):
1014 # bug 10042; ensure stack overflow does not occur
1015 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001016 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001017 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001018 def __init__(self, value):
1019 self.value = value
1020 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001021 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001022 return self.value == other.value
1023 return False
1024 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001025 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001026 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001027 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001028
Nick Coghlanf05d9812013-10-02 00:02:03 +10001029 @functools.total_ordering
1030 class ImplementsGreaterThan:
1031 def __init__(self, value):
1032 self.value = value
1033 def __eq__(self, other):
1034 if isinstance(other, ImplementsGreaterThan):
1035 return self.value == other.value
1036 return False
1037 def __gt__(self, other):
1038 if isinstance(other, ImplementsGreaterThan):
1039 return self.value > other.value
1040 return NotImplemented
1041
1042 @functools.total_ordering
1043 class ImplementsLessThanEqualTo:
1044 def __init__(self, value):
1045 self.value = value
1046 def __eq__(self, other):
1047 if isinstance(other, ImplementsLessThanEqualTo):
1048 return self.value == other.value
1049 return False
1050 def __le__(self, other):
1051 if isinstance(other, ImplementsLessThanEqualTo):
1052 return self.value <= other.value
1053 return NotImplemented
1054
1055 @functools.total_ordering
1056 class ImplementsGreaterThanEqualTo:
1057 def __init__(self, value):
1058 self.value = value
1059 def __eq__(self, other):
1060 if isinstance(other, ImplementsGreaterThanEqualTo):
1061 return self.value == other.value
1062 return False
1063 def __ge__(self, other):
1064 if isinstance(other, ImplementsGreaterThanEqualTo):
1065 return self.value >= other.value
1066 return NotImplemented
1067
1068 @functools.total_ordering
1069 class ComparatorNotImplemented:
1070 def __init__(self, value):
1071 self.value = value
1072 def __eq__(self, other):
1073 if isinstance(other, ComparatorNotImplemented):
1074 return self.value == other.value
1075 return False
1076 def __lt__(self, other):
1077 return NotImplemented
1078
1079 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1080 ImplementsLessThan(-1) < 1
1081
1082 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1083 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1084
1085 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1086 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1087
1088 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1089 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1090
1091 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1092 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1093
1094 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1095 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1096
1097 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1098 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1099
1100 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1101 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1102
1103 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1104 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1105
1106 with self.subTest("GE when equal"):
1107 a = ComparatorNotImplemented(8)
1108 b = ComparatorNotImplemented(8)
1109 self.assertEqual(a, b)
1110 with self.assertRaises(TypeError):
1111 a >= b
1112
1113 with self.subTest("LE when equal"):
1114 a = ComparatorNotImplemented(9)
1115 b = ComparatorNotImplemented(9)
1116 self.assertEqual(a, b)
1117 with self.assertRaises(TypeError):
1118 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001119
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001120 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001121 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001122 for name in '__lt__', '__gt__', '__le__', '__ge__':
1123 with self.subTest(method=name, proto=proto):
1124 method = getattr(Orderable_LT, name)
1125 method_copy = pickle.loads(pickle.dumps(method, proto))
1126 self.assertIs(method_copy, method)
1127
@functools.total_ordering
class Orderable_LT:
    """Module-level class that defines only __eq__ and __lt__.

    total_ordering supplies the remaining comparisons; the class is used by
    TestTotalOrdering.test_pickle.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1136
1137
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001138class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001139
1140 def test_lru(self):
1141 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001142 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001143 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001144 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001145 self.assertEqual(maxsize, 20)
1146 self.assertEqual(currsize, 0)
1147 self.assertEqual(hits, 0)
1148 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001149
1150 domain = range(5)
1151 for i in range(1000):
1152 x, y = choice(domain), choice(domain)
1153 actual = f(x, y)
1154 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001155 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001156 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001157 self.assertTrue(hits > misses)
1158 self.assertEqual(hits + misses, 1000)
1159 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001160
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001161 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001162 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001163 self.assertEqual(hits, 0)
1164 self.assertEqual(misses, 0)
1165 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001166 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001167 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001168 self.assertEqual(hits, 0)
1169 self.assertEqual(misses, 1)
1170 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001171
Nick Coghlan98876832010-08-17 06:17:18 +00001172 # Test bypassing the cache
1173 self.assertIs(f.__wrapped__, orig)
1174 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001175 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001176 self.assertEqual(hits, 0)
1177 self.assertEqual(misses, 1)
1178 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001179
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001180 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001181 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001182 def f():
1183 nonlocal f_cnt
1184 f_cnt += 1
1185 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001186 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001187 f_cnt = 0
1188 for i in range(5):
1189 self.assertEqual(f(), 20)
1190 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001191 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001192 self.assertEqual(hits, 0)
1193 self.assertEqual(misses, 5)
1194 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001195
1196 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001197 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001198 def f():
1199 nonlocal f_cnt
1200 f_cnt += 1
1201 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001202 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001203 f_cnt = 0
1204 for i in range(5):
1205 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001206 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001207 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001208 self.assertEqual(hits, 4)
1209 self.assertEqual(misses, 1)
1210 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001211
Raymond Hettingerf3098282010-08-15 03:30:45 +00001212 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001213 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001214 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001215 nonlocal f_cnt
1216 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001217 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001218 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001219 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001220 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1221 # * * * *
1222 self.assertEqual(f(x), x*10)
1223 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001224 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(hits, 12)
1226 self.assertEqual(misses, 4)
1227 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001228
Miss Islington (bot)b2b023c2019-01-26 00:23:40 -08001229 def test_lru_bug_35780(self):
1230 # C version of the lru_cache was not checking to see if
1231 # the user function call has already modified the cache
1232 # (this arises in recursive calls and in multi-threading).
1233 # This cause the cache to have orphan links not referenced
1234 # by the cache dictionary.
1235
1236 once = True # Modified by f(x) below
1237
1238 @self.module.lru_cache(maxsize=10)
1239 def f(x):
1240 nonlocal once
1241 rv = f'.{x}.'
1242 if x == 20 and once:
1243 once = False
1244 rv = f(x)
1245 return rv
1246
1247 # Fill the cache
1248 for x in range(15):
1249 self.assertEqual(f(x), f'.{x}.')
1250 self.assertEqual(f.cache_info().currsize, 10)
1251
1252 # Make a recursive call and make sure the cache remains full
1253 self.assertEqual(f(20), '.20.')
1254 self.assertEqual(f.cache_info().currsize, 10)
1255
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001256 def test_lru_hash_only_once(self):
1257 # To protect against weird reentrancy bugs and to improve
1258 # efficiency when faced with slow __hash__ methods, the
1259 # LRU cache guarantees that it will only call __hash__
1260 # only once per use as an argument to the cached function.
1261
1262 @self.module.lru_cache(maxsize=1)
1263 def f(x, y):
1264 return x * 3 + y
1265
1266 # Simulate the integer 5
1267 mock_int = unittest.mock.Mock()
1268 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1269 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1270
1271 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001272 self.assertEqual(f(mock_int, 1), 16)
1273 self.assertEqual(mock_int.__hash__.call_count, 1)
1274 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001275
1276 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001277 self.assertEqual(f(mock_int, 1), 16)
1278 self.assertEqual(mock_int.__hash__.call_count, 2)
1279 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001280
Ville Skyttä49b27342017-08-03 09:00:59 +03001281 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001282 self.assertEqual(f(6, 2), 20)
1283 self.assertEqual(mock_int.__hash__.call_count, 2)
1284 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001285
1286 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001287 self.assertEqual(f(mock_int, 1), 16)
1288 self.assertEqual(mock_int.__hash__.call_count, 3)
1289 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001290
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001291 def test_lru_reentrancy_with_len(self):
1292 # Test to make sure the LRU cache code isn't thrown-off by
1293 # caching the built-in len() function. Since len() can be
1294 # cached, we shouldn't use it inside the lru code itself.
1295 old_len = builtins.len
1296 try:
1297 builtins.len = self.module.lru_cache(4)(len)
1298 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1299 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1300 finally:
1301 builtins.len = old_len
1302
Raymond Hettinger605a4472017-01-09 07:50:19 -08001303 def test_lru_star_arg_handling(self):
1304 # Test regression that arose in ea064ff3c10f
1305 @functools.lru_cache()
1306 def f(*args):
1307 return args
1308
1309 self.assertEqual(f(1, 2), (1, 2))
1310 self.assertEqual(f((1, 2)), ((1, 2),))
1311
Yury Selivanov46a02db2016-11-09 18:55:45 -05001312 def test_lru_type_error(self):
1313 # Regression test for issue #28653.
1314 # lru_cache was leaking when one of the arguments
1315 # wasn't cacheable.
1316
1317 @functools.lru_cache(maxsize=None)
1318 def infinite_cache(o):
1319 pass
1320
1321 @functools.lru_cache(maxsize=10)
1322 def limited_cache(o):
1323 pass
1324
1325 with self.assertRaises(TypeError):
1326 infinite_cache([])
1327
1328 with self.assertRaises(TypeError):
1329 limited_cache([])
1330
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001331 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001332 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001333 def fib(n):
1334 if n < 2:
1335 return n
1336 return fib(n-1) + fib(n-2)
1337 self.assertEqual([fib(n) for n in range(16)],
1338 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1339 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001340 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001341 fib.cache_clear()
1342 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001343 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1344
1345 def test_lru_with_maxsize_negative(self):
1346 @self.module.lru_cache(maxsize=-10)
1347 def eq(n):
1348 return n
1349 for i in (0, 1):
1350 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1351 self.assertEqual(eq.cache_info(),
Miss Islington (bot)b2b023c2019-01-26 00:23:40 -08001352 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001353
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001354 def test_lru_with_exceptions(self):
1355 # Verify that user_function exceptions get passed through without
1356 # creating a hard-to-read chained exception.
1357 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001358 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001359 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001360 def func(i):
1361 return 'abc'[i]
1362 self.assertEqual(func(0), 'a')
1363 with self.assertRaises(IndexError) as cm:
1364 func(15)
1365 self.assertIsNone(cm.exception.__context__)
1366 # Verify that the previous exception did not result in a cached entry
1367 with self.assertRaises(IndexError):
1368 func(15)
1369
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001370 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001371 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001372 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001373 def square(x):
1374 return x * x
1375 self.assertEqual(square(3), 9)
1376 self.assertEqual(type(square(3)), type(9))
1377 self.assertEqual(square(3.0), 9.0)
1378 self.assertEqual(type(square(3.0)), type(9.0))
1379 self.assertEqual(square(x=3), 9)
1380 self.assertEqual(type(square(x=3)), type(9))
1381 self.assertEqual(square(x=3.0), 9.0)
1382 self.assertEqual(type(square(x=3.0)), type(9.0))
1383 self.assertEqual(square.cache_info().hits, 4)
1384 self.assertEqual(square.cache_info().misses, 4)
1385
Antoine Pitroub5b37142012-11-13 21:35:40 +01001386 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001387 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001388 def fib(n):
1389 if n < 2:
1390 return n
1391 return fib(n=n-1) + fib(n=n-2)
1392 self.assertEqual(
1393 [fib(n=number) for number in range(16)],
1394 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1395 )
1396 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001397 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001398 fib.cache_clear()
1399 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001400 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001401
1402 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001403 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001404 def fib(n):
1405 if n < 2:
1406 return n
1407 return fib(n=n-1) + fib(n=n-2)
1408 self.assertEqual([fib(n=number) for number in range(16)],
1409 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1410 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001411 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001412 fib.cache_clear()
1413 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001414 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1415
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001416 def test_kwargs_order(self):
1417 # PEP 468: Preserving Keyword Argument Order
1418 @self.module.lru_cache(maxsize=10)
1419 def f(**kwargs):
1420 return list(kwargs.items())
1421 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1422 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1423 self.assertEqual(f.cache_info(),
1424 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1425
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001426 def test_lru_cache_decoration(self):
1427 def f(zomg: 'zomg_annotation'):
1428 """f doc string"""
1429 return 42
1430 g = self.module.lru_cache()(f)
1431 for attr in self.module.WRAPPER_ASSIGNMENTS:
1432 self.assertEqual(getattr(g, attr), getattr(f, attr))
1433
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001434 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001435 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436 def orig(x, y):
1437 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001438 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001439 hits, misses, maxsize, currsize = f.cache_info()
1440 self.assertEqual(currsize, 0)
1441
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001442 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001443 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001444 start.wait(10)
1445 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001446 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001447
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001448 def clear():
1449 start.wait(10)
1450 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001451 f.cache_clear()
1452
1453 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001454 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001455 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001456 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001457 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001458 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001459 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001460 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001461
1462 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001463 if self.module is py_functools:
1464 # XXX: Why can be not equal?
1465 self.assertLessEqual(misses, n)
1466 self.assertLessEqual(hits, m*n - misses)
1467 else:
1468 self.assertEqual(misses, n)
1469 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001470 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001471
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001472 # create n threads in order to fill cache and 1 to clear it
1473 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001474 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001475 for k in range(n)]
1476 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001477 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001478 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001479 finally:
1480 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001481
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001482 def test_lru_cache_threaded2(self):
1483 # Simultaneous call with the same arguments
1484 n, m = 5, 7
1485 start = threading.Barrier(n+1)
1486 pause = threading.Barrier(n+1)
1487 stop = threading.Barrier(n+1)
1488 @self.module.lru_cache(maxsize=m*n)
1489 def f(x):
1490 pause.wait(10)
1491 return 3 * x
1492 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1493 def test():
1494 for i in range(m):
1495 start.wait(10)
1496 self.assertEqual(f(i), 3 * i)
1497 stop.wait(10)
1498 threads = [threading.Thread(target=test) for k in range(n)]
1499 with support.start_threads(threads):
1500 for i in range(m):
1501 start.wait(10)
1502 stop.reset()
1503 pause.wait(10)
1504 start.reset()
1505 stop.wait(10)
1506 pause.reset()
1507 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1508
Serhiy Storchaka67796522017-01-12 18:34:33 +02001509 def test_lru_cache_threaded3(self):
1510 @self.module.lru_cache(maxsize=2)
1511 def f(x):
1512 time.sleep(.01)
1513 return 3 * x
1514 def test(i, x):
1515 with self.subTest(thread=i):
1516 self.assertEqual(f(x), 3 * x, i)
1517 threads = [threading.Thread(target=test, args=(i, v))
1518 for i, v in enumerate([1, 2, 2, 3, 2])]
1519 with support.start_threads(threads):
1520 pass
1521
Raymond Hettinger03923422013-03-04 02:52:50 -05001522 def test_need_for_rlock(self):
1523 # This will deadlock on an LRU cache that uses a regular lock
1524
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001525 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001526 def test_func(x):
1527 'Used to demonstrate a reentrant lru_cache call within a single thread'
1528 return x
1529
1530 class DoubleEq:
1531 'Demonstrate a reentrant lru_cache call within a single thread'
1532 def __init__(self, x):
1533 self.x = x
1534 def __hash__(self):
1535 return self.x
1536 def __eq__(self, other):
1537 if self.x == 2:
1538 test_func(DoubleEq(1))
1539 return self.x == other.x
1540
1541 test_func(DoubleEq(1)) # Load the cache
1542 test_func(DoubleEq(2)) # Load the cache
1543 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1544 DoubleEq(2)) # Verify the correct return value
1545
Raymond Hettinger4d588972014-08-12 12:44:52 -07001546 def test_early_detection_of_bad_call(self):
1547 # Issue #22184
1548 with self.assertRaises(TypeError):
1549 @functools.lru_cache
1550 def f():
1551 pass
1552
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001553 def test_lru_method(self):
1554 class X(int):
1555 f_cnt = 0
1556 @self.module.lru_cache(2)
1557 def f(self, x):
1558 self.f_cnt += 1
1559 return x*10+self
1560 a = X(5)
1561 b = X(5)
1562 c = X(7)
1563 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1564
1565 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1566 self.assertEqual(a.f(x), x*10 + 5)
1567 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1568 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1569
1570 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1571 self.assertEqual(b.f(x), x*10 + 5)
1572 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1573 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1574
1575 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1576 self.assertEqual(c.f(x), x*10 + 7)
1577 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1578 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1579
1580 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1581 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1582 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1583
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001584 def test_pickle(self):
1585 cls = self.__class__
1586 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1587 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1588 with self.subTest(proto=proto, func=f):
1589 f_copy = pickle.loads(pickle.dumps(f, proto))
1590 self.assertIs(f_copy, f)
1591
1592 def test_copy(self):
1593 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001594 def orig(x, y):
1595 return 3 * x + y
1596 part = self.module.partial(orig, 2)
1597 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1598 self.module.lru_cache(2)(part))
1599 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001600 with self.subTest(func=f):
1601 f_copy = copy.copy(f)
1602 self.assertIs(f_copy, f)
1603
1604 def test_deepcopy(self):
1605 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001606 def orig(x, y):
1607 return 3 * x + y
1608 part = self.module.partial(orig, 2)
1609 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1610 self.module.lru_cache(2)(part))
1611 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001612 with self.subTest(func=f):
1613 f_copy = copy.deepcopy(f)
1614 self.assertIs(f_copy, f)
1615
1616
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level function cached by the pure-Python lru_cache; the
    # pickle/copy tests reach it through TestLRUPy.cached_func.
    # Presumably module-level so pickle can look it up by name.
    return 3 * x + y
1620
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level function cached by the C-accelerated lru_cache; the
    # pickle/copy tests reach it through TestLRUC.cached_func.
    # Presumably module-level so pickle can look it up by name.
    return 3 * x + y
1624
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001625
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # Stored in a 1-tuple and read back as cached_func[0]; presumably this
    # keeps the cached function from being treated as a method.
    cached_func = (py_cached_func,)

    # Cached instance method used by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # Cached staticmethod used by the pickle/copy tests.
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1638
1639
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # Stored in a 1-tuple and read back as cached_func[0]; presumably this
    # keeps the cached function from being treated as a method.
    cached_func = (c_cached_func,)

    # Cached instance method used by the pickle/copy tests.
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # Cached staticmethod used by the pickle/copy tests.
    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001652
Raymond Hettinger03923422013-03-04 02:52:50 -05001653
Łukasz Langa6f692512013-06-05 12:20:24 +02001654class TestSingleDispatch(unittest.TestCase):
1655 def test_simple_overloads(self):
1656 @functools.singledispatch
1657 def g(obj):
1658 return "base"
1659 def g_int(i):
1660 return "integer"
1661 g.register(int, g_int)
1662 self.assertEqual(g("str"), "base")
1663 self.assertEqual(g(1), "integer")
1664 self.assertEqual(g([1,2,3]), "base")
1665
1666 def test_mro(self):
1667 @functools.singledispatch
1668 def g(obj):
1669 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001670 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001671 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001672 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001673 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001674 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001675 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001676 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001677 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001678 def g_A(a):
1679 return "A"
1680 def g_B(b):
1681 return "B"
1682 g.register(A, g_A)
1683 g.register(B, g_B)
1684 self.assertEqual(g(A()), "A")
1685 self.assertEqual(g(B()), "B")
1686 self.assertEqual(g(C()), "A")
1687 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001688
1689 def test_register_decorator(self):
1690 @functools.singledispatch
1691 def g(obj):
1692 return "base"
1693 @g.register(int)
1694 def g_int(i):
1695 return "int %s" % (i,)
1696 self.assertEqual(g(""), "base")
1697 self.assertEqual(g(12), "int 12")
1698 self.assertIs(g.dispatch(int), g_int)
1699 self.assertIs(g.dispatch(object), g.dispatch(str))
1700 # Note: in the assert above this is not g.
1701 # @singledispatch returns the wrapper.
1702
1703 def test_wrapping_attributes(self):
1704 @functools.singledispatch
1705 def g(obj):
1706 "Simple test"
1707 return "Test"
1708 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001709 if sys.flags.optimize < 2:
1710 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001711
1712 @unittest.skipUnless(decimal, 'requires _decimal')
1713 @support.cpython_only
1714 def test_c_classes(self):
1715 @functools.singledispatch
1716 def g(obj):
1717 return "base"
1718 @g.register(decimal.DecimalException)
1719 def _(obj):
1720 return obj.args
1721 subn = decimal.Subnormal("Exponent < Emin")
1722 rnd = decimal.Rounded("Number got rounded")
1723 self.assertEqual(g(subn), ("Exponent < Emin",))
1724 self.assertEqual(g(rnd), ("Number got rounded",))
1725 @g.register(decimal.Subnormal)
1726 def _(obj):
1727 return "Too small to care."
1728 self.assertEqual(g(subn), "Too small to care.")
1729 self.assertEqual(g(rnd), ("Number got rounded",))
1730
1731 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001732 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001733 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001734 mro = functools._compose_mro
1735 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1736 for haystack in permutations(bases):
1737 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001738 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1739 c.Collection, c.Sized, c.Iterable,
1740 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001741 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001742 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001743 m = mro(collections.ChainMap, haystack)
1744 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001745 c.Collection, c.Sized, c.Iterable,
1746 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001747
1748 # If there's a generic function with implementations registered for
1749 # both Sized and Container, passing a defaultdict to it results in an
1750 # ambiguous dispatch which will cause a RuntimeError (see
1751 # test_mro_conflicts).
1752 bases = [c.Container, c.Sized, str]
1753 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001754 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1755 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1756 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001757
1758 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001759 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001760 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001761 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001762 pass
1763 c.MutableSequence.register(D)
1764 bases = [c.MutableSequence, c.MutableMapping]
1765 for haystack in permutations(bases):
1766 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001767 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001768 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001769 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001770 object])
1771
1772 # Container and Callable are registered on different base classes and
1773 # a generic function supporting both should always pick the Callable
1774 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001775 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001776 def __call__(self):
1777 pass
1778 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1779 for haystack in permutations(bases):
1780 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001781 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001782 c.Collection, c.Sized, c.Iterable,
1783 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001784
1785 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001786 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001787 d = {"a": "b"}
1788 l = [1, 2, 3]
1789 s = {object(), None}
1790 f = frozenset(s)
1791 t = (1, 2, 3)
1792 @functools.singledispatch
1793 def g(obj):
1794 return "base"
1795 self.assertEqual(g(d), "base")
1796 self.assertEqual(g(l), "base")
1797 self.assertEqual(g(s), "base")
1798 self.assertEqual(g(f), "base")
1799 self.assertEqual(g(t), "base")
1800 g.register(c.Sized, lambda obj: "sized")
1801 self.assertEqual(g(d), "sized")
1802 self.assertEqual(g(l), "sized")
1803 self.assertEqual(g(s), "sized")
1804 self.assertEqual(g(f), "sized")
1805 self.assertEqual(g(t), "sized")
1806 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1807 self.assertEqual(g(d), "mutablemapping")
1808 self.assertEqual(g(l), "sized")
1809 self.assertEqual(g(s), "sized")
1810 self.assertEqual(g(f), "sized")
1811 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001812 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001813 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1814 self.assertEqual(g(l), "sized")
1815 self.assertEqual(g(s), "sized")
1816 self.assertEqual(g(f), "sized")
1817 self.assertEqual(g(t), "sized")
1818 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1819 self.assertEqual(g(d), "mutablemapping")
1820 self.assertEqual(g(l), "mutablesequence")
1821 self.assertEqual(g(s), "sized")
1822 self.assertEqual(g(f), "sized")
1823 self.assertEqual(g(t), "sized")
1824 g.register(c.MutableSet, lambda obj: "mutableset")
1825 self.assertEqual(g(d), "mutablemapping")
1826 self.assertEqual(g(l), "mutablesequence")
1827 self.assertEqual(g(s), "mutableset")
1828 self.assertEqual(g(f), "sized")
1829 self.assertEqual(g(t), "sized")
1830 g.register(c.Mapping, lambda obj: "mapping")
1831 self.assertEqual(g(d), "mutablemapping") # not specific enough
1832 self.assertEqual(g(l), "mutablesequence")
1833 self.assertEqual(g(s), "mutableset")
1834 self.assertEqual(g(f), "sized")
1835 self.assertEqual(g(t), "sized")
1836 g.register(c.Sequence, lambda obj: "sequence")
1837 self.assertEqual(g(d), "mutablemapping")
1838 self.assertEqual(g(l), "mutablesequence")
1839 self.assertEqual(g(s), "mutableset")
1840 self.assertEqual(g(f), "sized")
1841 self.assertEqual(g(t), "sequence")
1842 g.register(c.Set, lambda obj: "set")
1843 self.assertEqual(g(d), "mutablemapping")
1844 self.assertEqual(g(l), "mutablesequence")
1845 self.assertEqual(g(s), "mutableset")
1846 self.assertEqual(g(f), "set")
1847 self.assertEqual(g(t), "sequence")
1848 g.register(dict, lambda obj: "dict")
1849 self.assertEqual(g(d), "dict")
1850 self.assertEqual(g(l), "mutablesequence")
1851 self.assertEqual(g(s), "mutableset")
1852 self.assertEqual(g(f), "set")
1853 self.assertEqual(g(t), "sequence")
1854 g.register(list, lambda obj: "list")
1855 self.assertEqual(g(d), "dict")
1856 self.assertEqual(g(l), "list")
1857 self.assertEqual(g(s), "mutableset")
1858 self.assertEqual(g(f), "set")
1859 self.assertEqual(g(t), "sequence")
1860 g.register(set, lambda obj: "concrete-set")
1861 self.assertEqual(g(d), "dict")
1862 self.assertEqual(g(l), "list")
1863 self.assertEqual(g(s), "concrete-set")
1864 self.assertEqual(g(f), "set")
1865 self.assertEqual(g(t), "sequence")
1866 g.register(frozenset, lambda obj: "frozen-set")
1867 self.assertEqual(g(d), "dict")
1868 self.assertEqual(g(l), "list")
1869 self.assertEqual(g(s), "concrete-set")
1870 self.assertEqual(g(f), "frozen-set")
1871 self.assertEqual(g(t), "sequence")
1872 g.register(tuple, lambda obj: "tuple")
1873 self.assertEqual(g(d), "dict")
1874 self.assertEqual(g(l), "list")
1875 self.assertEqual(g(s), "concrete-set")
1876 self.assertEqual(g(f), "frozen-set")
1877 self.assertEqual(g(t), "tuple")
1878
Łukasz Langa3720c772013-07-01 16:00:38 +02001879 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001880 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001881 mro = functools._c3_mro
1882 class A(object):
1883 pass
1884 class B(A):
1885 def __len__(self):
1886 return 0 # implies Sized
1887 @c.Container.register
1888 class C(object):
1889 pass
1890 class D(object):
1891 pass # unrelated
1892 class X(D, C, B):
1893 def __call__(self):
1894 pass # implies Callable
1895 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1896 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1897 self.assertEqual(mro(X, abcs=abcs), expected)
1898 # unrelated ABCs don't appear in the resulting MRO
1899 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1900 self.assertEqual(mro(X, abcs=many_abcs), expected)
1901
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001902 def test_false_meta(self):
1903 # see issue23572
1904 class MetaA(type):
1905 def __len__(self):
1906 return 0
1907 class A(metaclass=MetaA):
1908 pass
1909 class AA(A):
1910 pass
1911 @functools.singledispatch
1912 def fun(a):
1913 return 'base A'
1914 @fun.register(A)
1915 def _(a):
1916 return 'fun A'
1917 aa = AA()
1918 self.assertEqual(fun(aa), 'fun A')
1919
Łukasz Langa6f692512013-06-05 12:20:24 +02001920 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001921 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001922 @functools.singledispatch
1923 def g(arg):
1924 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001925 class O(c.Sized):
1926 def __len__(self):
1927 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001928 o = O()
1929 self.assertEqual(g(o), "base")
1930 g.register(c.Iterable, lambda arg: "iterable")
1931 g.register(c.Container, lambda arg: "container")
1932 g.register(c.Sized, lambda arg: "sized")
1933 g.register(c.Set, lambda arg: "set")
1934 self.assertEqual(g(o), "sized")
1935 c.Iterable.register(O)
1936 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1937 c.Container.register(O)
1938 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001939 c.Set.register(O)
1940 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1941 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001942 class P:
1943 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001944 p = P()
1945 self.assertEqual(g(p), "base")
1946 c.Iterable.register(P)
1947 self.assertEqual(g(p), "iterable")
1948 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001949 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001950 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07001951 self.assertIn(
1952 str(re_one.exception),
1953 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1954 "or <class 'collections.abc.Iterable'>"),
1955 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1956 "or <class 'collections.abc.Container'>")),
1957 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001958 class Q(c.Sized):
1959 def __len__(self):
1960 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001961 q = Q()
1962 self.assertEqual(g(q), "sized")
1963 c.Iterable.register(Q)
1964 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1965 c.Set.register(Q)
1966 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001967 # c.Sized and c.Iterable
1968 @functools.singledispatch
1969 def h(arg):
1970 return "base"
1971 @h.register(c.Sized)
1972 def _(arg):
1973 return "sized"
1974 @h.register(c.Container)
1975 def _(arg):
1976 return "container"
1977 # Even though Sized and Container are explicit bases of MutableMapping,
1978 # this ABC is implicitly registered on defaultdict which makes all of
1979 # MutableMapping's bases implicit as well from defaultdict's
1980 # perspective.
1981 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001982 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07001983 self.assertIn(
1984 str(re_two.exception),
1985 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1986 "or <class 'collections.abc.Sized'>"),
1987 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1988 "or <class 'collections.abc.Container'>")),
1989 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001990 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001991 pass
1992 c.MutableSequence.register(R)
1993 @functools.singledispatch
1994 def i(arg):
1995 return "base"
1996 @i.register(c.MutableMapping)
1997 def _(arg):
1998 return "mapping"
1999 @i.register(c.MutableSequence)
2000 def _(arg):
2001 return "sequence"
2002 r = R()
2003 self.assertEqual(i(r), "sequence")
2004 class S:
2005 pass
2006 class T(S, c.Sized):
2007 def __len__(self):
2008 return 0
2009 t = T()
2010 self.assertEqual(h(t), "sized")
2011 c.Container.register(T)
2012 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
2013 class U:
2014 def __len__(self):
2015 return 0
2016 u = U()
2017 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
2018 # from the existence of __len__()
2019 c.Container.register(U)
2020 # There is no preference for registered versus inferred ABCs.
2021 with self.assertRaises(RuntimeError) as re_three:
2022 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002023 self.assertIn(
2024 str(re_three.exception),
2025 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2026 "or <class 'collections.abc.Sized'>"),
2027 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2028 "or <class 'collections.abc.Container'>")),
2029 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002030 class V(c.Sized, S):
2031 def __len__(self):
2032 return 0
2033 @functools.singledispatch
2034 def j(arg):
2035 return "base"
2036 @j.register(S)
2037 def _(arg):
2038 return "s"
2039 @j.register(c.Container)
2040 def _(arg):
2041 return "container"
2042 v = V()
2043 self.assertEqual(j(v), "s")
2044 c.Container.register(V)
2045 self.assertEqual(j(v), "container") # because it ends up right after
2046 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002047
2048 def test_cache_invalidation(self):
2049 from collections import UserDict
INADA Naoki9811e802017-09-30 16:13:02 +09002050 import weakref
2051
Łukasz Langa6f692512013-06-05 12:20:24 +02002052 class TracingDict(UserDict):
2053 def __init__(self, *args, **kwargs):
2054 super(TracingDict, self).__init__(*args, **kwargs)
2055 self.set_ops = []
2056 self.get_ops = []
2057 def __getitem__(self, key):
2058 result = self.data[key]
2059 self.get_ops.append(key)
2060 return result
2061 def __setitem__(self, key, value):
2062 self.set_ops.append(key)
2063 self.data[key] = value
2064 def clear(self):
2065 self.data.clear()
INADA Naoki9811e802017-09-30 16:13:02 +09002066
Łukasz Langa6f692512013-06-05 12:20:24 +02002067 td = TracingDict()
INADA Naoki9811e802017-09-30 16:13:02 +09002068 with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
2069 c = collections.abc
2070 @functools.singledispatch
2071 def g(arg):
2072 return "base"
2073 d = {}
2074 l = []
2075 self.assertEqual(len(td), 0)
2076 self.assertEqual(g(d), "base")
2077 self.assertEqual(len(td), 1)
2078 self.assertEqual(td.get_ops, [])
2079 self.assertEqual(td.set_ops, [dict])
2080 self.assertEqual(td.data[dict], g.registry[object])
2081 self.assertEqual(g(l), "base")
2082 self.assertEqual(len(td), 2)
2083 self.assertEqual(td.get_ops, [])
2084 self.assertEqual(td.set_ops, [dict, list])
2085 self.assertEqual(td.data[dict], g.registry[object])
2086 self.assertEqual(td.data[list], g.registry[object])
2087 self.assertEqual(td.data[dict], td.data[list])
2088 self.assertEqual(g(l), "base")
2089 self.assertEqual(g(d), "base")
2090 self.assertEqual(td.get_ops, [list, dict])
2091 self.assertEqual(td.set_ops, [dict, list])
2092 g.register(list, lambda arg: "list")
2093 self.assertEqual(td.get_ops, [list, dict])
2094 self.assertEqual(len(td), 0)
2095 self.assertEqual(g(d), "base")
2096 self.assertEqual(len(td), 1)
2097 self.assertEqual(td.get_ops, [list, dict])
2098 self.assertEqual(td.set_ops, [dict, list, dict])
2099 self.assertEqual(td.data[dict],
2100 functools._find_impl(dict, g.registry))
2101 self.assertEqual(g(l), "list")
2102 self.assertEqual(len(td), 2)
2103 self.assertEqual(td.get_ops, [list, dict])
2104 self.assertEqual(td.set_ops, [dict, list, dict, list])
2105 self.assertEqual(td.data[list],
2106 functools._find_impl(list, g.registry))
2107 class X:
2108 pass
2109 c.MutableMapping.register(X) # Will not invalidate the cache,
2110 # not using ABCs yet.
2111 self.assertEqual(g(d), "base")
2112 self.assertEqual(g(l), "list")
2113 self.assertEqual(td.get_ops, [list, dict, dict, list])
2114 self.assertEqual(td.set_ops, [dict, list, dict, list])
2115 g.register(c.Sized, lambda arg: "sized")
2116 self.assertEqual(len(td), 0)
2117 self.assertEqual(g(d), "sized")
2118 self.assertEqual(len(td), 1)
2119 self.assertEqual(td.get_ops, [list, dict, dict, list])
2120 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2121 self.assertEqual(g(l), "list")
2122 self.assertEqual(len(td), 2)
2123 self.assertEqual(td.get_ops, [list, dict, dict, list])
2124 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2125 self.assertEqual(g(l), "list")
2126 self.assertEqual(g(d), "sized")
2127 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2128 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2129 g.dispatch(list)
2130 g.dispatch(dict)
2131 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2132 list, dict])
2133 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2134 c.MutableSet.register(X) # Will invalidate the cache.
2135 self.assertEqual(len(td), 2) # Stale cache.
2136 self.assertEqual(g(l), "list")
2137 self.assertEqual(len(td), 1)
2138 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2139 self.assertEqual(len(td), 0)
2140 self.assertEqual(g(d), "mutablemapping")
2141 self.assertEqual(len(td), 1)
2142 self.assertEqual(g(l), "list")
2143 self.assertEqual(len(td), 2)
2144 g.register(dict, lambda arg: "dict")
2145 self.assertEqual(g(d), "dict")
2146 self.assertEqual(g(l), "list")
2147 g._clear_cache()
2148 self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002149
Łukasz Langae5697532017-12-11 13:56:31 -08002150 def test_annotations(self):
2151 @functools.singledispatch
2152 def i(arg):
2153 return "base"
2154 @i.register
2155 def _(arg: collections.abc.Mapping):
2156 return "mapping"
2157 @i.register
2158 def _(arg: "collections.abc.Sequence"):
2159 return "sequence"
2160 self.assertEqual(i(None), "base")
2161 self.assertEqual(i({"a": 1}), "mapping")
2162 self.assertEqual(i([1, 2, 3]), "sequence")
2163 self.assertEqual(i((1, 2, 3)), "sequence")
2164 self.assertEqual(i("str"), "sequence")
2165
2166 # Registering classes as callables doesn't work with annotations,
2167 # you need to pass the type explicitly.
2168 @i.register(str)
2169 class _:
2170 def __init__(self, arg):
2171 self.arg = arg
2172
2173 def __eq__(self, other):
2174 return self.arg == other
2175 self.assertEqual(i("str"), "str")
2176
2177 def test_invalid_registrations(self):
2178 msg_prefix = "Invalid first argument to `register()`: "
2179 msg_suffix = (
2180 ". Use either `@register(some_class)` or plain `@register` on an "
2181 "annotated function."
2182 )
2183 @functools.singledispatch
2184 def i(arg):
2185 return "base"
2186 with self.assertRaises(TypeError) as exc:
2187 @i.register(42)
2188 def _(arg):
2189 return "I annotated with a non-type"
2190 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2191 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2192 with self.assertRaises(TypeError) as exc:
2193 @i.register
2194 def _(arg):
2195 return "I forgot to annotate"
2196 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2197 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2198 ))
2199 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2200
2201 # FIXME: The following will only work after PEP 560 is implemented.
2202 return
2203
2204 with self.assertRaises(TypeError) as exc:
2205 @i.register
2206 def _(arg: typing.Iterable[str]):
2207 # At runtime, dispatching on generics is impossible.
2208 # When registering implementations with singledispatch, avoid
2209 # types from `typing`. Instead, annotate with regular types
2210 # or ABCs.
2211 return "I annotated with a generic collection"
2212 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2213 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2214 ))
2215 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2216
Miss Islington (bot)df9f6332018-07-10 00:48:57 -07002217 def test_invalid_positional_argument(self):
2218 @functools.singledispatch
2219 def f(*args):
2220 pass
2221 msg = 'f requires at least 1 positional argument'
Miss Islington (bot)892df9d2018-07-16 22:18:56 -07002222 with self.assertRaisesRegex(TypeError, msg):
Miss Islington (bot)df9f6332018-07-10 00:48:57 -07002223 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002224
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()