blob: a7625d6090398c2d2066af5e83f951763be07a7c [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Two separately imported copies of functools so that both implementations
# can be tested side by side:
#   * py_functools -- imported with the _functools accelerator blocked, so it
#     uses the pure-Python code;
#   * c_functools  -- imported with _functools freshly imported.
# c_functools is checked for truthiness further down (skipUnless /
# "if c_functools"), presumably because it is None when the C accelerator
# cannot be imported.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh decimal module backed by the C accelerator _decimal (presumably used
# by tests later in this file that are not visible here -- verify).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module that was registered under *name* is restored on exit, whether
    the ``with`` body completes normally or raises. A KeyError is raised up
    front if *name* is not currently present in ``sys.modules``.
    """
    registry = sys.modules
    saved, registry[name] = registry[name], replacement
    try:
        yield
    finally:
        registry[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo the call back: return the positional tuple and keyword dict."""
    captured = (args, kw)
    return captured
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass used to check that partial normalises args."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        return [*self, *other]
50
class MyDict(dict):
    """Plain dict subclass used to check that partial normalises keywords."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Tests shared by every ``partial`` implementation under test.

    This is a mixin, not a TestCase: each concrete subclass also inherits
    from ``unittest.TestCase`` and supplies two class attributes:

    * ``partial`` -- the partial type being exercised;
    * ``AllowPickle`` -- a context manager class entered around the tests
      that pickle partial objects.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial object is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertions report which half actually differed.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        support.gc_collect()  # don't rely on refcounting to free f
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial is expected to be flattened into one.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes f reference itself, then breaks the cycle in
        # ``finally`` so the object can be freed.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and __dict__ contents.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates args, keywords, __dict__ and their contents.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None namespace is stored as an empty __dict__.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None keywords are stored as an empty dict.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # Subclassed args/keywords must be converted to exact tuple/dict.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; the call must still produce a tuple.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial that is its own func cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside args survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside keywords survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence (not a tuple) of the right length must be rejected.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000379@unittest.skipUnless(c_functools, 'requires the C _functools module')
380class TestPartialC(TestPartial, unittest.TestCase):
381 if c_functools:
382 partial = c_functools.partial
383
384 class AllowPickle:
385 def __enter__(self):
386 return self
387 def __exit__(self, type, value, tb):
388 return False
389
390 def test_attributes_unwritable(self):
391 # attributes should not be writable
392 p = self.partial(capture, 1, 2, a=10, b=20)
393 self.assertRaises(AttributeError, setattr, p, 'func', map)
394 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
395 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
396
397 p = self.partial(hex)
398 try:
399 del p.__dict__
400 except TypeError:
401 pass
402 else:
403 self.fail('partial object allowed __dict__ to be deleted')
Łukasz Langa6f692512013-06-05 12:20:24 +0200404
Michael Seifert6c3d5272017-03-15 06:26:33 +0100405 def test_manually_adding_non_string_keyword(self):
406 p = self.partial(capture)
407 # Adding a non-string/unicode keyword to partial kwargs
408 p.keywords[1234] = 'value'
409 r = repr(p)
410 self.assertIn('1234', r)
411 self.assertIn("'value'", r)
412 with self.assertRaises(TypeError):
413 p()
414
415 def test_keystr_replaces_value(self):
416 p = self.partial(capture)
417
418 class MutatesYourDict(object):
419 def __str__(self):
420 p.keywords[self] = ['sth2']
421 return 'astr'
422
Mike53f7a7c2017-12-14 14:04:53 +0300423 # Replacing the value during key formatting should keep the original
Michael Seifert6c3d5272017-03-15 06:26:33 +0100424 # value alive (at least long enough).
425 p.keywords[MutatesYourDict()] = ['sth']
426 r = repr(p)
427 self.assertIn('astr', r)
428 self.assertIn("['sth']", r)
429
430
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200431class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000432 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000433
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000434 class AllowPickle:
435 def __init__(self):
436 self._cm = replaced_module("functools", py_functools)
437 def __enter__(self):
438 return self._cm.__enter__()
439 def __exit__(self, type, value, tb):
440 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200442if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000443 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200444 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000446class PyPartialSubclass(py_functools.partial):
447 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200449@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200450class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200451 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000452 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000453
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300454 # partial subclasses are not optimized for nested calls
455 test_nested_optimization = None
456
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000457class TestPartialPySubclass(TestPartialPy):
458 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000460class TestPartialMethod(unittest.TestCase):
461
462 class A(object):
463 nothing = functools.partialmethod(capture)
464 positional = functools.partialmethod(capture, 1)
465 keywords = functools.partialmethod(capture, a=2)
466 both = functools.partialmethod(capture, 3, b=4)
Serhiy Storchakaa37f3562019-04-01 10:59:24 +0300467 spec_keywords = functools.partialmethod(capture, self=1, func=2)
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000468
469 nested = functools.partialmethod(positional, 5)
470
471 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
472
473 static = functools.partialmethod(staticmethod(capture), 8)
474 cls = functools.partialmethod(classmethod(capture), d=9)
475
476 a = A()
477
478 def test_arg_combinations(self):
479 self.assertEqual(self.a.nothing(), ((self.a,), {}))
480 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
481 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
482 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
483
484 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
485 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
486 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
487 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
488
489 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
490 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
491 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
492 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
493
494 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
495 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
496 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
497 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
498
499 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
500
Serhiy Storchakaa37f3562019-04-01 10:59:24 +0300501 self.assertEqual(self.a.spec_keywords(), ((self.a,), {'self': 1, 'func': 2}))
502
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000503 def test_nested(self):
504 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
505 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
506 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
507 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
508
509 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
510
511 def test_over_partial(self):
512 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
513 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
514 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
515 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
516
517 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
518
519 def test_bound_method_introspection(self):
520 obj = self.a
521 self.assertIs(obj.both.__self__, obj)
522 self.assertIs(obj.nested.__self__, obj)
523 self.assertIs(obj.over_partial.__self__, obj)
524 self.assertIs(obj.cls.__self__, self.A)
525 self.assertIs(self.A.cls.__self__, self.A)
526
527 def test_unbound_method_retrieval(self):
528 obj = self.A
529 self.assertFalse(hasattr(obj.both, "__self__"))
530 self.assertFalse(hasattr(obj.nested, "__self__"))
531 self.assertFalse(hasattr(obj.over_partial, "__self__"))
532 self.assertFalse(hasattr(obj.static, "__self__"))
533 self.assertFalse(hasattr(self.a.static, "__self__"))
534
535 def test_descriptors(self):
536 for obj in [self.A, self.a]:
537 with self.subTest(obj=obj):
538 self.assertEqual(obj.static(), ((8,), {}))
539 self.assertEqual(obj.static(5), ((8, 5), {}))
540 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
541 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
542
543 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
544 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
545 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
546 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
547
548 def test_overriding_keywords(self):
549 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
550 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
551
552 def test_invalid_args(self):
553 with self.assertRaises(TypeError):
554 class B(object):
555 method = functools.partialmethod(None, 1)
Serhiy Storchakaa37f3562019-04-01 10:59:24 +0300556 with self.assertRaises(TypeError):
557 class B:
558 method = functools.partialmethod()
559 class B:
560 method = functools.partialmethod(func=capture, a=1)
561 b = B()
562 self.assertEqual(b.method(2, x=3), ((b, 2), {'a': 1, 'x': 3}))
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000563
564 def test_repr(self):
565 self.assertEqual(repr(vars(self.A)['both']),
566 'functools.partialmethod({}, 3, b=4)'.format(capture))
567
568 def test_abstract(self):
569 class Abstract(abc.ABCMeta):
570
571 @abc.abstractmethod
572 def add(self, x, y):
573 pass
574
575 add5 = functools.partialmethod(add, 5)
576
577 self.assertTrue(Abstract.add.__isabstractmethod__)
578 self.assertTrue(Abstract.add5.__isabstractmethod__)
579
580 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
581 self.assertFalse(getattr(func, '__isabstractmethod__', False))
582
583
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000584class TestUpdateWrapper(unittest.TestCase):
585
586 def check_wrapper(self, wrapper, wrapped,
587 assigned=functools.WRAPPER_ASSIGNMENTS,
588 updated=functools.WRAPPER_UPDATES):
589 # Check attributes were assigned
590 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000591 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000592 # Check attributes were updated
593 for name in updated:
594 wrapper_attr = getattr(wrapper, name)
595 wrapped_attr = getattr(wrapped, name)
596 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000597 if name == "__dict__" and key == "__wrapped__":
598 # __wrapped__ is overwritten by the update code
599 continue
600 self.assertIs(wrapped_attr[key], wrapper_attr[key])
601 # Check __wrapped__
602 self.assertIs(wrapper.__wrapped__, wrapped)
603
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000604
R. David Murray378c0cf2010-02-24 01:46:21 +0000605 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000606 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000607 """This is a test"""
608 pass
609 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000610 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000611 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000612 pass
613 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000614 return wrapper, f
615
616 def test_default_update(self):
617 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000618 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000619 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000620 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600621 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000622 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000623 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
624 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000625
R. David Murray378c0cf2010-02-24 01:46:21 +0000626 @unittest.skipIf(sys.flags.optimize >= 2,
627 "Docstrings are omitted with -O2 and above")
628 def test_default_update_doc(self):
629 wrapper, f = self._default_update()
630 self.assertEqual(wrapper.__doc__, 'This is a test')
631
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 def test_no_update(self):
633 def f():
634 """This is a test"""
635 pass
636 f.attr = 'This is also a test'
637 def wrapper():
638 pass
639 functools.update_wrapper(wrapper, f, (), ())
640 self.check_wrapper(wrapper, f, (), ())
641 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600642 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000643 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000644 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000645 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000646
647 def test_selective_update(self):
648 def f():
649 pass
650 f.attr = 'This is a different test'
651 f.dict_attr = dict(a=1, b=2, c=3)
652 def wrapper():
653 pass
654 wrapper.dict_attr = {}
655 assign = ('attr',)
656 update = ('dict_attr',)
657 functools.update_wrapper(wrapper, f, assign, update)
658 self.check_wrapper(wrapper, f, assign, update)
659 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600660 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000661 self.assertEqual(wrapper.__doc__, None)
662 self.assertEqual(wrapper.attr, 'This is a different test')
663 self.assertEqual(wrapper.dict_attr, f.dict_attr)
664
Nick Coghlan98876832010-08-17 06:17:18 +0000665 def test_missing_attributes(self):
666 def f():
667 pass
668 def wrapper():
669 pass
670 wrapper.dict_attr = {}
671 assign = ('attr',)
672 update = ('dict_attr',)
673 # Missing attributes on wrapped object are ignored
674 functools.update_wrapper(wrapper, f, assign, update)
675 self.assertNotIn('attr', wrapper.__dict__)
676 self.assertEqual(wrapper.dict_attr, {})
677 # Wrapper must have expected attributes for updating
678 del wrapper.dict_attr
679 with self.assertRaises(AttributeError):
680 functools.update_wrapper(wrapper, f, assign, update)
681 wrapper.dict_attr = 1
682 with self.assertRaises(AttributeError):
683 functools.update_wrapper(wrapper, f, assign, update)
684
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200685 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000686 @unittest.skipIf(sys.flags.optimize >= 2,
687 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000688 def test_builtin_update(self):
689 # Test for bug #1576241
690 def wrapper():
691 pass
692 functools.update_wrapper(wrapper, max)
693 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000694 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000695 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000696
Łukasz Langa6f692512013-06-05 12:20:24 +0200697
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() scenarios through @functools.wraps.

    Shared assertions such as ``check_wrapper`` come from TestUpdateWrapper;
    this class overrides the scenarios that construct the wrapper.
    """

    def _default_update(self):
        """Decorate a sample function with ``@functools.wraps(f)``.

        Returns:
            A ``(wrapper, f)`` pair: the decorated wrapper and its source.
        """
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately misleading __wrapped__ on the source function; the
        # wrapper should end up referring to f itself (presumably verified
        # by check_wrapper).
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        # Metadata is taken from the wrapped function.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        # Nothing was copied, so the wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # "updated" attributes are merged into an existing mapping, so
            # give the wrapper one before @wraps runs.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
758
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C-accelerated functools.reduce()."""

    # Only bound when the C module imported; otherwise the decorator above
    # skips the whole class.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence-protocol object (no __iter__) that computes squares
            # lazily; len() reports how many have been computed so far.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                self.sofar.extend(k * k for k in range(len(self.sofar), i + 1))
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Ordinary folds, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong arity, non-iterable sequence, or non-callable function.
        for bad_args in [(), (42, 42), (42, 42, 42), (42, (42, 42))]:
            self.assertRaises(TypeError, reduce, *bad_args)

        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        # arg 2 must not be empty sequence with no initial value; nor may
        # it be a non-iterable.
        for bad_seq in ([], "", (), object()):
            self.assertRaises(TypeError, reduce, add, bad_seq)

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, FailingIter())

        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
840
Łukasz Langa6f692512013-06-05 12:20:24 +0200841
class TestCmpToKey:
    """Implementation-independent tests for cmp_to_key().

    Concrete subclasses mix in unittest.TestCase and supply ``cmp_to_key``.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        # The comparison function may coerce mixed argument types.
        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # Exceptions raised while comparing keys must propagate.

        # Case 1: the comparison function itself raises.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # Case 2: the comparison function succeeds, but comparing its
        # result against 0 raises.  The key must be rebuilt from the new
        # function; reusing the key from case 1 would only re-test case 1.
        # ``<`` is used because cmp_to_key's K.__lt__ evaluates
        # ``mycmp(a, b) < 0``, which is what reaches BadCmp.__lt__.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        # The wrapped value is exposed as the ``obj`` attribute.
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # Key objects are deliberately unhashable.
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923
Łukasz Langa6f692512013-06-05 12:20:24 +0200924
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    # Guarded so the class body still executes when _functools is missing;
    # the skip decorator then prevents the tests from running.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
928 cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100929
Łukasz Langa6f692512013-06-05 12:20:24 +0200930
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""

    # staticmethod keeps the plain Python function from binding to ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
933
Łukasz Langa6f692512013-06-05 12:20:24 +0200934
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _assert_total_order(self, cls):
        """Check all four ordering operators on fresh instances of *cls*."""
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._assert_total_order(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThan)
                        and self.value == other.value)
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThan)
                        and self.value == other.value)
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThanEqualTo)
                        and self.value == other.value)
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThanEqualTo)
                        and self.value == other.value)
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ComparatorNotImplemented)
                        and self.value == other.value)
            def __lt__(self, other):
                return NotImplemented

        # Comparisons across unrelated types must give up with TypeError
        # rather than recursing between the derived methods.
        mismatched = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)),
            ("GT > GE",
             lambda: ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)),
            ("GE >= LE",
             lambda: ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)),
        ]
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even equal instances cannot be ordered when the root comparison
        # always returns NotImplemented.
        equal_cases = [
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        ]
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # The generated methods pickle by reference and round-trip to the
        # very same objects.
        generated = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in generated:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(restored, method)
1137
@functools.total_ordering
class Orderable_LT:
    """Module-level value wrapper ordered via total_ordering.

    Only ``__lt__`` and ``__eq__`` are written by hand; the decorator derives
    the remaining rich comparisons.  Being module-level lets TestTotalOrdering
    pickle its generated methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1146
1147
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001148class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001149
1150 def test_lru(self):
1151 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001152 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001153 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001154 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001155 self.assertEqual(maxsize, 20)
1156 self.assertEqual(currsize, 0)
1157 self.assertEqual(hits, 0)
1158 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001159
1160 domain = range(5)
1161 for i in range(1000):
1162 x, y = choice(domain), choice(domain)
1163 actual = f(x, y)
1164 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001165 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001166 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001167 self.assertTrue(hits > misses)
1168 self.assertEqual(hits + misses, 1000)
1169 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001170
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001171 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001172 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001173 self.assertEqual(hits, 0)
1174 self.assertEqual(misses, 0)
1175 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001176 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001177 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001178 self.assertEqual(hits, 0)
1179 self.assertEqual(misses, 1)
1180 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001181
Nick Coghlan98876832010-08-17 06:17:18 +00001182 # Test bypassing the cache
1183 self.assertIs(f.__wrapped__, orig)
1184 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001185 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001186 self.assertEqual(hits, 0)
1187 self.assertEqual(misses, 1)
1188 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001189
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001190 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001191 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001192 def f():
1193 nonlocal f_cnt
1194 f_cnt += 1
1195 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001196 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001197 f_cnt = 0
1198 for i in range(5):
1199 self.assertEqual(f(), 20)
1200 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001201 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001202 self.assertEqual(hits, 0)
1203 self.assertEqual(misses, 5)
1204 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205
1206 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001207 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001208 def f():
1209 nonlocal f_cnt
1210 f_cnt += 1
1211 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001212 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001213 f_cnt = 0
1214 for i in range(5):
1215 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001216 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001217 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001218 self.assertEqual(hits, 4)
1219 self.assertEqual(misses, 1)
1220 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001221
Raymond Hettingerf3098282010-08-15 03:30:45 +00001222 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001223 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001224 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001225 nonlocal f_cnt
1226 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001227 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001228 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001229 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001230 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1231 # * * * *
1232 self.assertEqual(f(x), x*10)
1233 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001234 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001235 self.assertEqual(hits, 12)
1236 self.assertEqual(misses, 4)
1237 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001238
Miss Islington (bot)b2b023c2019-01-26 00:23:40 -08001239 def test_lru_bug_35780(self):
1240 # C version of the lru_cache was not checking to see if
1241 # the user function call has already modified the cache
1242 # (this arises in recursive calls and in multi-threading).
1243 # This cause the cache to have orphan links not referenced
1244 # by the cache dictionary.
1245
1246 once = True # Modified by f(x) below
1247
1248 @self.module.lru_cache(maxsize=10)
1249 def f(x):
1250 nonlocal once
1251 rv = f'.{x}.'
1252 if x == 20 and once:
1253 once = False
1254 rv = f(x)
1255 return rv
1256
1257 # Fill the cache
1258 for x in range(15):
1259 self.assertEqual(f(x), f'.{x}.')
1260 self.assertEqual(f.cache_info().currsize, 10)
1261
1262 # Make a recursive call and make sure the cache remains full
1263 self.assertEqual(f(20), '.20.')
1264 self.assertEqual(f.cache_info().currsize, 10)
1265
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001266 def test_lru_hash_only_once(self):
1267 # To protect against weird reentrancy bugs and to improve
1268 # efficiency when faced with slow __hash__ methods, the
1269 # LRU cache guarantees that it will only call __hash__
1270 # only once per use as an argument to the cached function.
1271
1272 @self.module.lru_cache(maxsize=1)
1273 def f(x, y):
1274 return x * 3 + y
1275
1276 # Simulate the integer 5
1277 mock_int = unittest.mock.Mock()
1278 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1279 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1280
1281 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001282 self.assertEqual(f(mock_int, 1), 16)
1283 self.assertEqual(mock_int.__hash__.call_count, 1)
1284 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001285
1286 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001287 self.assertEqual(f(mock_int, 1), 16)
1288 self.assertEqual(mock_int.__hash__.call_count, 2)
1289 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001290
Ville Skyttä49b27342017-08-03 09:00:59 +03001291 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001292 self.assertEqual(f(6, 2), 20)
1293 self.assertEqual(mock_int.__hash__.call_count, 2)
1294 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001295
1296 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001297 self.assertEqual(f(mock_int, 1), 16)
1298 self.assertEqual(mock_int.__hash__.call_count, 3)
1299 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001300
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001301 def test_lru_reentrancy_with_len(self):
1302 # Test to make sure the LRU cache code isn't thrown-off by
1303 # caching the built-in len() function. Since len() can be
1304 # cached, we shouldn't use it inside the lru code itself.
1305 old_len = builtins.len
1306 try:
1307 builtins.len = self.module.lru_cache(4)(len)
1308 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1309 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1310 finally:
1311 builtins.len = old_len
1312
Raymond Hettinger605a4472017-01-09 07:50:19 -08001313 def test_lru_star_arg_handling(self):
1314 # Test regression that arose in ea064ff3c10f
1315 @functools.lru_cache()
1316 def f(*args):
1317 return args
1318
1319 self.assertEqual(f(1, 2), (1, 2))
1320 self.assertEqual(f((1, 2)), ((1, 2),))
1321
Yury Selivanov46a02db2016-11-09 18:55:45 -05001322 def test_lru_type_error(self):
1323 # Regression test for issue #28653.
1324 # lru_cache was leaking when one of the arguments
1325 # wasn't cacheable.
1326
1327 @functools.lru_cache(maxsize=None)
1328 def infinite_cache(o):
1329 pass
1330
1331 @functools.lru_cache(maxsize=10)
1332 def limited_cache(o):
1333 pass
1334
1335 with self.assertRaises(TypeError):
1336 infinite_cache([])
1337
1338 with self.assertRaises(TypeError):
1339 limited_cache([])
1340
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001341 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001342 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001343 def fib(n):
1344 if n < 2:
1345 return n
1346 return fib(n-1) + fib(n-2)
1347 self.assertEqual([fib(n) for n in range(16)],
1348 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1349 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001350 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001351 fib.cache_clear()
1352 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001353 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1354
1355 def test_lru_with_maxsize_negative(self):
1356 @self.module.lru_cache(maxsize=-10)
1357 def eq(n):
1358 return n
1359 for i in (0, 1):
1360 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1361 self.assertEqual(eq.cache_info(),
Miss Islington (bot)b2b023c2019-01-26 00:23:40 -08001362 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001363
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001364 def test_lru_with_exceptions(self):
1365 # Verify that user_function exceptions get passed through without
1366 # creating a hard-to-read chained exception.
1367 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001368 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001369 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001370 def func(i):
1371 return 'abc'[i]
1372 self.assertEqual(func(0), 'a')
1373 with self.assertRaises(IndexError) as cm:
1374 func(15)
1375 self.assertIsNone(cm.exception.__context__)
1376 # Verify that the previous exception did not result in a cached entry
1377 with self.assertRaises(IndexError):
1378 func(15)
1379
    def test_lru_with_types(self):
        # With typed=True, int and float arguments (3 vs 3.0) get separate
        # cache entries, for both an unbounded and a bounded cache.
        # Positional and keyword forms (3 vs x=3) are also distinct keys:
        # each of the four distinct calls below is made twice (a miss, then
        # a hit), giving 4 hits and 4 misses per decorated function.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            # The result retrieved from the cache keeps its own type.
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)
1395
Antoine Pitroub5b37142012-11-13 21:35:40 +01001396 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001397 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001398 def fib(n):
1399 if n < 2:
1400 return n
1401 return fib(n=n-1) + fib(n=n-2)
1402 self.assertEqual(
1403 [fib(n=number) for number in range(16)],
1404 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1405 )
1406 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001407 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001408 fib.cache_clear()
1409 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001410 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001411
1412 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001413 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001414 def fib(n):
1415 if n < 2:
1416 return n
1417 return fib(n=n-1) + fib(n=n-2)
1418 self.assertEqual([fib(n=number) for number in range(16)],
1419 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1420 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001421 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001422 fib.cache_clear()
1423 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001424 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1425
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001426 def test_kwargs_order(self):
1427 # PEP 468: Preserving Keyword Argument Order
1428 @self.module.lru_cache(maxsize=10)
1429 def f(**kwargs):
1430 return list(kwargs.items())
1431 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1432 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1433 self.assertEqual(f.cache_info(),
1434 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1435
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436 def test_lru_cache_decoration(self):
1437 def f(zomg: 'zomg_annotation'):
1438 """f doc string"""
1439 return 42
1440 g = self.module.lru_cache()(f)
1441 for attr in self.module.WRAPPER_ASSIGNMENTS:
1442 self.assertEqual(getattr(g, attr), getattr(f, attr))
1443
    def test_lru_cache_threaded(self):
        # Hammer one bounded cache from several threads: first n threads
        # that each repeatedly call f with their own key, then the same plus
        # one thread that keeps clearing the cache.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # maxsize=n*m is large enough that no entry is ever evicted.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # All worker threads block on this event so they start together.
        start = threading.Event()
        def full(k):
            # Thread k only ever uses key (k, 0): 1 expected miss, m-1 hits.
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # A tiny switch interval makes thread switches (and thus races
        # inside the cache) far more frequent; restored in ``finally``.
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure-Python implementation the counters are
                # only checked as upper bounds; it is not understood why
                # they can differ from the exact values.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            # (this phase only checks that results stay correct).
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001491
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # n worker threads plus the main thread take part in every barrier.
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Blocks until all workers are inside f, so none of them can
            # store a result before the others have already missed.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # Release one round, let every worker reach f, then wait for
                # them to finish; each barrier is reset for the next round.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # Per round: n simultaneous misses, no hits, one stored entry.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1518
    def test_lru_cache_threaded3(self):
        # Concurrent calls on a tiny cache (maxsize=2) with repeated and
        # evicting keys; each thread only checks that it gets the right value.
        @self.module.lru_cache(maxsize=2)
        def f(x):
            # The sleep widens the window in which the calls overlap.
            time.sleep(.01)
            return 3 * x
        def test(i, x):
            with self.subTest(thread=i):
                self.assertEqual(f(x), 3 * x, i)
        threads = [threading.Thread(target=test, args=(i, v))
                   for i, v in enumerate([1, 2, 2, 3, 2])]
        # Entering the context starts the threads; leaving it joins them.
        with support.start_threads(threads):
            pass
1531
    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock
        # (the cache is re-entered from the same thread during a lookup).

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                # Comparing a DoubleEq(2) key calls back into the cache.
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
1555
Raymond Hettinger4d588972014-08-12 12:44:52 -07001556 def test_early_detection_of_bad_call(self):
1557 # Issue #22184
1558 with self.assertRaises(TypeError):
1559 @functools.lru_cache
1560 def f():
1561 pass
1562
    def test_lru_method(self):
        # lru_cache applied to a method: one cache (size 2) lives on X.f and
        # is shared by all instances, so its counters accumulate across a, b
        # and c. f_cnt counts real executions per instance, i.e. misses.
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        # Accessing the cache through a bound method reports the same stats.
        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())
1593
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001594 def test_pickle(self):
1595 cls = self.__class__
1596 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1597 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1598 with self.subTest(proto=proto, func=f):
1599 f_copy = pickle.loads(pickle.dumps(f, proto))
1600 self.assertIs(f_copy, f)
1601
1602 def test_copy(self):
1603 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001604 def orig(x, y):
1605 return 3 * x + y
1606 part = self.module.partial(orig, 2)
1607 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1608 self.module.lru_cache(2)(part))
1609 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001610 with self.subTest(func=f):
1611 f_copy = copy.copy(f)
1612 self.assertIs(f_copy, f)
1613
1614 def test_deepcopy(self):
1615 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001616 def orig(x, y):
1617 return 3 * x + y
1618 part = self.module.partial(orig, 2)
1619 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1620 self.module.lru_cache(2)(part))
1621 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001622 with self.subTest(func=f):
1623 f_copy = copy.deepcopy(f)
1624 self.assertIs(f_copy, f)
1625
1626
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Cached with the pure-Python lru_cache; exposed to the LRU tests through
    # TestLRUPy.cached_func. Defined at module level, presumably so that the
    # pickle/copy tests can locate it by its qualified name.
    return 3 * x + y
1630
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Cached with the C-accelerated lru_cache; exposed to the LRU tests
    # through TestLRUC.cached_func. Defined at module level, presumably so
    # that the pickle/copy tests can locate it by its qualified name.
    return 3 * x + y
1634
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001635
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against the pure-Python functools.
    module = py_functools
    # Kept inside a one-element tuple, presumably so it is never treated
    # as a method of this class; the tests read it as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1648
1649
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against the C-accelerated functools.
    module = c_functools
    # Kept inside a one-element tuple, presumably so it is never treated
    # as a method of this class; the tests read it as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001662
Raymond Hettinger03923422013-03-04 02:52:50 -05001663
Łukasz Langa6f692512013-06-05 12:20:24 +02001664class TestSingleDispatch(unittest.TestCase):
1665 def test_simple_overloads(self):
1666 @functools.singledispatch
1667 def g(obj):
1668 return "base"
1669 def g_int(i):
1670 return "integer"
1671 g.register(int, g_int)
1672 self.assertEqual(g("str"), "base")
1673 self.assertEqual(g(1), "integer")
1674 self.assertEqual(g([1,2,3]), "base")
1675
1676 def test_mro(self):
1677 @functools.singledispatch
1678 def g(obj):
1679 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001680 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001681 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001682 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001683 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001684 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001685 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001686 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001687 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001688 def g_A(a):
1689 return "A"
1690 def g_B(b):
1691 return "B"
1692 g.register(A, g_A)
1693 g.register(B, g_B)
1694 self.assertEqual(g(A()), "A")
1695 self.assertEqual(g(B()), "B")
1696 self.assertEqual(g(C()), "A")
1697 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001698
    def test_register_decorator(self):
        # register(cls) used as a decorator registers the function and
        # returns it unchanged, so dispatch(int) is g_int itself.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: the base implementation returned by dispatch(object) above is
        # the original undecorated function, not the name g, because
        # @singledispatch rebinds g to the dispatching wrapper.
1712
    def test_wrapping_attributes(self):
        # The singledispatch wrapper exposes the wrapped function's name and
        # docstring.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Under -OO (optimize >= 2) docstrings are stripped, so only check
        # __doc__ when they are kept.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001721
    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes from the C-accelerated decimal module.
        # A later, more specific registration (Subnormal) takes over only for
        # its own instances; Rounded keeps using the DecimalException one.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))
1740
1741 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001742 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001743 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001744 mro = functools._compose_mro
1745 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1746 for haystack in permutations(bases):
1747 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001748 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1749 c.Collection, c.Sized, c.Iterable,
1750 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001751 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001752 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001753 m = mro(collections.ChainMap, haystack)
1754 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001755 c.Collection, c.Sized, c.Iterable,
1756 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001757
1758 # If there's a generic function with implementations registered for
1759 # both Sized and Container, passing a defaultdict to it results in an
1760 # ambiguous dispatch which will cause a RuntimeError (see
1761 # test_mro_conflicts).
1762 bases = [c.Container, c.Sized, str]
1763 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001764 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1765 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1766 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001767
1768 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001769 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001770 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001771 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001772 pass
1773 c.MutableSequence.register(D)
1774 bases = [c.MutableSequence, c.MutableMapping]
1775 for haystack in permutations(bases):
1776 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001777 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001778 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001779 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001780 object])
1781
1782 # Container and Callable are registered on different base classes and
1783 # a generic function supporting both should always pick the Callable
1784 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001785 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001786 def __call__(self):
1787 pass
1788 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1789 for haystack in permutations(bases):
1790 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001791 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001792 c.Collection, c.Sized, c.Iterable,
1793 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001794
    def test_register_abc(self):
        # Register implementations from the most generic ABC down to concrete
        # types, one at a time. After each registration all five sample
        # objects are re-dispatched: a new registration only takes over for
        # objects it matches more specifically than what was registered so far.
        c = collections.abc
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        # ChainMap is unrelated to all five samples, so nothing changes.
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        # Mapping is less specific than the already registered MutableMapping.
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        # From here on, concrete classes override the ABC registrations.
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")
1888
    def test_c3_abc(self):
        # _c3_mro inserts ABCs that a class satisfies (explicitly registered
        # or implied by special methods) into its C3 linearization, right
        # after the class that introduces them. The result must not depend on
        # the order in which the candidate ABCs are given.
        c = collections.abc
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)
1911
    def test_false_meta(self):
        # see issue23572
        # MetaA.__len__ returning 0 makes the classes A and AA themselves
        # falsy. Dispatch must still find the A implementation for an AA
        # instance, i.e. a falsy class must not be treated as "no class".
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')
1929
    def test_mro_conflicts(self):
        # How dispatch resolves (or refuses to resolve) between ABCs that are
        # explicit bases, explicitly registered, or implied by methods.
        # Each ABC registration below is permanent, so the order matters.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no explicit ABC bases: once it is registered on two
        # unrelated ABCs that both have implementations, dispatch is
        # ambiguous and raises RuntimeError (either order in the message).
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on the class (MutableSequence on R) wins
        # over one inherited implicitly through defaultdict (MutableMapping).
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002057
    def test_cache_invalidation(self):
        # Trace exactly when singledispatch reads, writes and clears its
        # dispatch cache. weakref.WeakKeyDictionary is swapped for a factory
        # returning the recording dict ``td``, so the cache created by
        # @functools.singledispatch inside the ``with`` block is ``td``.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Records the keys of successful reads (get_ops) and of all
            # writes (set_ops); a missing key raises before being recorded.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First dispatch per type: a cache write, no successful read.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated dispatch: served from the cache (reads only).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # register() empties the cache without reading it.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # g.dispatch() goes through the same cache as calling g.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # With an ABC now registered on g, a new ABC registration makes
            # the cache stale; it is only cleared on the next dispatch.
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002159
Łukasz Langae5697532017-12-11 13:56:31 -08002160 def test_annotations(self):
2161 @functools.singledispatch
2162 def i(arg):
2163 return "base"
2164 @i.register
2165 def _(arg: collections.abc.Mapping):
2166 return "mapping"
2167 @i.register
2168 def _(arg: "collections.abc.Sequence"):
2169 return "sequence"
2170 self.assertEqual(i(None), "base")
2171 self.assertEqual(i({"a": 1}), "mapping")
2172 self.assertEqual(i([1, 2, 3]), "sequence")
2173 self.assertEqual(i((1, 2, 3)), "sequence")
2174 self.assertEqual(i("str"), "sequence")
2175
2176 # Registering classes as callables doesn't work with annotations,
2177 # you need to pass the type explicitly.
2178 @i.register(str)
2179 class _:
2180 def __init__(self, arg):
2181 self.arg = arg
2182
2183 def __eq__(self, other):
2184 return self.arg == other
2185 self.assertEqual(i("str"), "str")
2186
def test_invalid_registrations(self):
    """Check that ``register()`` rejects arguments it cannot dispatch on.

    A non-type first argument and an unannotated function must each raise
    TypeError. The message must start with ``msg_prefix`` followed by a
    description of the offending object and end with ``msg_suffix``.
    """
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )

    def assert_register_message(error, expected_start):
        # The text must begin with msg_prefix + expected_start and end
        # with msg_suffix.
        message = str(error)
        self.assertTrue(message.startswith(msg_prefix + expected_start))
        self.assertTrue(message.endswith(msg_suffix))

    @functools.singledispatch
    def dispatcher(arg):
        return "base"

    # A plain value (42) is neither a class nor an annotated function.
    with self.assertRaises(TypeError) as ctx:
        @dispatcher.register(42)
        def _(arg):
            return "I annotated with a non-type"
    assert_register_message(ctx.exception, "42")

    # Bare @register needs an annotation to take the dispatch type from.
    # The nested function must keep the name `_` and live directly in this
    # method, because its qualified name is part of the expected message.
    with self.assertRaises(TypeError) as ctx:
        @dispatcher.register
        def _(arg):
            return "I forgot to annotate"
    assert_register_message(
        ctx.exception,
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._",
    )

    # FIXME: The following will only work after PEP 560 is implemented.
    # Until then the early return below keeps the generic-annotation check
    # from running.
    return

    with self.assertRaises(TypeError) as ctx:
        @dispatcher.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    assert_register_message(
        ctx.exception,
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._",
    )
2226
Miss Islington (bot)df9f6332018-07-10 00:48:57 -07002227 def test_invalid_positional_argument(self):
2228 @functools.singledispatch
2229 def f(*args):
2230 pass
2231 msg = 'f requires at least 1 positional argument'
Miss Islington (bot)892df9d2018-07-16 22:18:56 -07002232 with self.assertRaisesRegex(TypeError, msg):
Miss Islington (bot)df9f6332018-07-10 00:48:57 -07002233 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002234
# Run the whole test module when this file is executed as a script.
if '__main__' == __name__:
    unittest.main()