# blob: 3acfb92be4a18ae531c8fa28a56a77fd1e87dc52 [file] [log] [blame]
import abc
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import time
import unittest
import unittest.mock
from weakref import proxy
import contextlib
try:
    import threading
except ImportError:
    # The tests below must cope with ``threading`` being None.
    threading = None

import functools

# Two separately imported copies of functools: one imported with the
# ``_functools`` module blocked and one imported with a fresh ``_functools``.
# The test classes below treat ``c_functools`` as possibly falsy.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

decimal = support.import_fresh_module('decimal', fresh=['_decimal'])

@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module object that was registered under *name* on entry is put
    back when the ``with`` block exits, whether it exits normally or by
    raising.

    Raises:
        KeyError: if *name* is not present in ``sys.modules`` on entry.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body of the ``with`` block raised.
        sys.modules[name] = original_module

def capture(*args, **kw):
    """Capture all positional and keyword arguments.

    Returns:
        A 2-tuple ``(args, kw)``: the positional arguments as a tuple and
        the keyword arguments as a dict.
    """
    return args, kw


def signature(part):
    """Return the signature of a partial object.

    Returns:
        The 4-tuple ``(part.func, part.args, part.keywords, part.__dict__)``.
    """
    return (part.func, part.args, part.keywords, part.__dict__)

class MyTuple(tuple):
    """Plain ``tuple`` subclass with no added behaviour."""

class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` returns a ``list`` instead of a tuple."""

    def __add__(self, other):
        return list(self) + list(other)

class MyDict(dict):
    """Plain ``dict`` subclass with no added behaviour."""


class TestPartial:
    """Test methods shared by every ``partial`` implementation under test.

    This class is a mixin: concrete subclasses also derive from
    ``unittest.TestCase`` and must provide two class attributes:

    * ``partial`` -- the partial type to exercise;
    * ``AllowPickle`` -- a context-manager class entered around the
      pickling tests.
    """

    def _repr_name(self):
        """Return the name the repr of a ``self.partial`` object starts with."""
        # c_functools may be falsy (other code in this file guards on it),
        # so only consult its ``partial`` attribute when it is usable.
        plain_partials = [py_functools.partial]
        if c_functools:
            plain_partials.append(c_functools.partial)
        if self.partial in plain_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # Either constructing or calling a partial of a non-callable must fail.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the differing values.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Do not depend on the object being finalized the instant its last
        # strong reference is dropped.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Accept either ordering of the two keyword arguments.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then resets the state so the self-reference is removed again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the contained objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy must not share any of the contained objects.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the following assertion is disabled in the original
        # source; the reason is not visible here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances must have been converted to the exact base types.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # Partial whose func is itself: pickling must hit RecursionError.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Partial that contains itself in args: the round trip must
            # preserve the self-reference.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Same, with the self-reference stored in keywords.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            """Sequence-like object of length 4 that is not a tuple."""
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``c_functools.partial``."""

    # Guarded so the class body can still be evaluated when c_functools
    # is falsy (the class is skipped in that case).
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """Context manager that does nothing on entry or exit."""
        def __enter__(self):
            return self
        def __exit__(self, exc_type, value, tb):
            # False: never suppress an exception raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                # Rebind this key's value while the repr is being built.
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)


class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that installs ``py_functools`` as ``functools``.

        While active, ``sys.modules["functools"]`` is ``py_functools``; the
        previous entry is restored on exit (see ``replaced_module``).
        """
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, exc_type, value, tb):
            return self._cm.__exit__(exc_type, value, tb)

# Only defined when c_functools is usable; it cannot be subclassed otherwise.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of ``c_functools.partial``."""

class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of ``py_functools.partial``."""

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the ``TestPartialC`` tests against ``CPartialSubclass``."""

    # CPartialSubclass only exists when c_functools is usable.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None

class TestPartialPySubclass(TestPartialPy):
    """Run the ``TestPartialPy`` tests against ``PyPartialSubclass``."""

    partial = PyPartialSubclass

class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod``."""

    class A(object):
        """Fixture class exposing ``capture`` through several partialmethods."""
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a functools.partial object
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethods wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    # Shared instance used by the tests below.
    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # Accessed through the class, the instance is passed explicitly.
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # Results must be the same via the class and via an instance.
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # vars() is used to get the partialmethod object itself rather than
        # the result of attribute access on the class.
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract derives from abc.ABCMeta rather than using
        # it as a metaclass; kept exactly as in the original source.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))


class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the expected state of *wrapped*.

        Args:
            wrapper: the function that was updated.
            wrapped: the function it was updated from.
            assigned: attribute names expected to be the same object on both.
            updated: names of mapping attributes whose entries from
                *wrapped* are expected to be present in *wrapper*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Build a ``(wrapper, wrapped)`` pair using the default arguments."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus value; check_wrapper expects it to be replaced.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated tuples: nothing may be copied across.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})


class TestWraps(TestUpdateWrapper):
    # Same scenarios as the parent class, driven through the
    # functools.wraps() decorator instead of update_wrapper().

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty 'assigned' and 'updated' tuples: nothing may be copied.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Applied first (innermost decorator) so that the wrapper owns
            # a dict_attr before wraps() runs its update step.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        # Names that were not requested keep the wrapper's own values.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        # Requested names mirror the wrapped function.
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
750
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # Checks reduce() as exported by c_functools. The attribute is only
    # bound when that module is available; otherwise the class is skipped.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of squares computed on demand and reachable only
            # through __getitem__ raising IndexError at the end.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n * n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.func

        # Ordinary folds, always with an explicit initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)

        # Folds over a sequence that only supports __getitem__.
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Bad call signatures and non-iterable arguments.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        # (the trailing object() is not iterable at all)
        for second_arg in ([], "", (), object()):
            self.assertRaises(TypeError, reduce, add, second_arg)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # With nothing to fold, the initial value comes back unchanged.
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Yields 0..n-1 through the __getitem__ protocol.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict folds over its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
832
Łukasz Langa6f692512013-06-05 12:20:24 +0200833
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200834class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700835
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000836 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700837 def cmp1(x, y):
838 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100839 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700840 self.assertEqual(key(3), key(3))
841 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100842 self.assertGreaterEqual(key(3), key(3))
843
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700844 def cmp2(x, y):
845 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100846 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700847 self.assertEqual(key(4.0), key('4'))
848 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100849 self.assertLessEqual(key(2), key('35'))
850 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700851
852 def test_cmp_to_key_arguments(self):
853 def cmp1(x, y):
854 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100855 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700856 self.assertEqual(key(obj=3), key(obj=3))
857 self.assertGreater(key(obj=3), key(obj=1))
858 with self.assertRaises((TypeError, AttributeError)):
859 key(3) > 1 # rhs is not a K object
860 with self.assertRaises((TypeError, AttributeError)):
861 1 < key(3) # lhs is not a K object
862 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100863 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200865 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100866 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700867 with self.assertRaises(TypeError):
868 key() # too few args
869 with self.assertRaises(TypeError):
870 key(None, None) # too many args
871
872 def test_bad_cmp(self):
873 def cmp1(x, y):
874 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100875 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 with self.assertRaises(ZeroDivisionError):
877 key(3) > key(1)
878
879 class BadCmp:
880 def __lt__(self, other):
881 raise ZeroDivisionError
882 def cmp1(x, y):
883 return BadCmp()
884 with self.assertRaises(ZeroDivisionError):
885 key(3) > key(1)
886
887 def test_obj_field(self):
888 def cmp1(x, y):
889 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100890 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700891 self.assertEqual(key(50).obj, 50)
892
893 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000894 def mycmp(x, y):
895 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100896 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000897 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000898
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700899 def test_sort_int_str(self):
900 def mycmp(x, y):
901 x, y = int(x), int(y)
902 return (x > y) - (x < y)
903 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100904 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700905 self.assertEqual([int(value) for value in values],
906 [0, 1, 1, 2, 3, 4, 5, 7, 10])
907
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000908 def test_hash(self):
909 def mycmp(x, y):
910 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100911 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000912 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700913 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300914 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000915
Łukasz Langa6f692512013-06-05 12:20:24 +0200916
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key checks against c_functools.
    # The attribute is bound only when c_functools is truthy; when it is
    # not, the decorator above skips the whole class anyway.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100921
Łukasz Langa6f692512013-06-05 12:20:24 +0200922
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key checks against py_functools.
    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
925
Łukasz Langa6f692512013-06-05 12:20:24 +0200926
class TestTotalOrdering(unittest.TestCase):

    def _check_ordering(self, cls):
        # Comparisons shared by all the "one operator defined" tests below.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __lt__(self, other):
                return self.value < other.value

        self._check_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __le__(self, other):
                return self.value <= other.value

        self._check_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __gt__(self, other):
                return self.value > other.value

        self._check_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __ge__(self, other):
                return self.value >= other.value

        self._check_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass

        self._check_ordering(A)

    def test_no_operations_defined(self):
        # Decorating a class that defines no ordering method is an error.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value

            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value

            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value

            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                return NotImplemented

        # Each comparison mixes operand types that do not know about each
        # other, so it has to end in TypeError.
        mismatched = (
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: ImplementsLessThanEqualTo(3)
                     <= ImplementsGreaterThanEqualTo(3)),
            ("GT > GE",
             lambda: ImplementsGreaterThan(4)
                     > ImplementsGreaterThanEqualTo(4)),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: ImplementsGreaterThanEqualTo(6)
                     >= ImplementsGreaterThan(6)),
            ("GE >= LE",
             lambda: ImplementsGreaterThanEqualTo(7)
                     >= ImplementsLessThanEqualTo(7)),
        )
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Operands that compare equal must still fail for ordering when
        # the only defined comparator returns NotImplemented.
        equal_operands = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for label, value, compare in equal_operands:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Every comparison method, whether written by hand or supplied by
        # total_ordering, must round-trip through pickle to the same object.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    round_tripped = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(round_tripped, method)
1129
@functools.total_ordering
class Orderable_LT:
    # Minimal ordered type: only __eq__ and __lt__ are written out here and
    # total_ordering supplies the remaining comparison methods.
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1138
1139
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001140class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001141
1142 def test_lru(self):
1143 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001144 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001145 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001146 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001147 self.assertEqual(maxsize, 20)
1148 self.assertEqual(currsize, 0)
1149 self.assertEqual(hits, 0)
1150 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001151
1152 domain = range(5)
1153 for i in range(1000):
1154 x, y = choice(domain), choice(domain)
1155 actual = f(x, y)
1156 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001157 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001158 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001159 self.assertTrue(hits > misses)
1160 self.assertEqual(hits + misses, 1000)
1161 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001162
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001163 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001164 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001165 self.assertEqual(hits, 0)
1166 self.assertEqual(misses, 0)
1167 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001168 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001169 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001170 self.assertEqual(hits, 0)
1171 self.assertEqual(misses, 1)
1172 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001173
Nick Coghlan98876832010-08-17 06:17:18 +00001174 # Test bypassing the cache
1175 self.assertIs(f.__wrapped__, orig)
1176 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001177 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001178 self.assertEqual(hits, 0)
1179 self.assertEqual(misses, 1)
1180 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001181
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001182 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001183 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001184 def f():
1185 nonlocal f_cnt
1186 f_cnt += 1
1187 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001188 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001189 f_cnt = 0
1190 for i in range(5):
1191 self.assertEqual(f(), 20)
1192 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001193 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001194 self.assertEqual(hits, 0)
1195 self.assertEqual(misses, 5)
1196 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001197
1198 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001199 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001200 def f():
1201 nonlocal f_cnt
1202 f_cnt += 1
1203 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001204 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001205 f_cnt = 0
1206 for i in range(5):
1207 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001208 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001209 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001210 self.assertEqual(hits, 4)
1211 self.assertEqual(misses, 1)
1212 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001213
Raymond Hettingerf3098282010-08-15 03:30:45 +00001214 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001215 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001216 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001217 nonlocal f_cnt
1218 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001219 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001220 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001221 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001222 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1223 # * * * *
1224 self.assertEqual(f(x), x*10)
1225 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001226 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001227 self.assertEqual(hits, 12)
1228 self.assertEqual(misses, 4)
1229 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001230
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001231 def test_lru_hash_only_once(self):
1232 # To protect against weird reentrancy bugs and to improve
1233 # efficiency when faced with slow __hash__ methods, the
1234 # LRU cache guarantees that it will only call __hash__
1235 # only once per use as an argument to the cached function.
1236
1237 @self.module.lru_cache(maxsize=1)
1238 def f(x, y):
1239 return x * 3 + y
1240
1241 # Simulate the integer 5
1242 mock_int = unittest.mock.Mock()
1243 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1244 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1245
1246 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001247 self.assertEqual(f(mock_int, 1), 16)
1248 self.assertEqual(mock_int.__hash__.call_count, 1)
1249 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001250
1251 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001252 self.assertEqual(f(mock_int, 1), 16)
1253 self.assertEqual(mock_int.__hash__.call_count, 2)
1254 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001255
Ville Skyttä49b27342017-08-03 09:00:59 +03001256 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001257 self.assertEqual(f(6, 2), 20)
1258 self.assertEqual(mock_int.__hash__.call_count, 2)
1259 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001260
1261 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001262 self.assertEqual(f(mock_int, 1), 16)
1263 self.assertEqual(mock_int.__hash__.call_count, 3)
1264 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001265
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001266 def test_lru_reentrancy_with_len(self):
1267 # Test to make sure the LRU cache code isn't thrown-off by
1268 # caching the built-in len() function. Since len() can be
1269 # cached, we shouldn't use it inside the lru code itself.
1270 old_len = builtins.len
1271 try:
1272 builtins.len = self.module.lru_cache(4)(len)
1273 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1274 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1275 finally:
1276 builtins.len = old_len
1277
Raymond Hettinger605a4472017-01-09 07:50:19 -08001278 def test_lru_star_arg_handling(self):
1279 # Test regression that arose in ea064ff3c10f
1280 @functools.lru_cache()
1281 def f(*args):
1282 return args
1283
1284 self.assertEqual(f(1, 2), (1, 2))
1285 self.assertEqual(f((1, 2)), ((1, 2),))
1286
Yury Selivanov46a02db2016-11-09 18:55:45 -05001287 def test_lru_type_error(self):
1288 # Regression test for issue #28653.
1289 # lru_cache was leaking when one of the arguments
1290 # wasn't cacheable.
1291
1292 @functools.lru_cache(maxsize=None)
1293 def infinite_cache(o):
1294 pass
1295
1296 @functools.lru_cache(maxsize=10)
1297 def limited_cache(o):
1298 pass
1299
1300 with self.assertRaises(TypeError):
1301 infinite_cache([])
1302
1303 with self.assertRaises(TypeError):
1304 limited_cache([])
1305
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001306 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001307 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001308 def fib(n):
1309 if n < 2:
1310 return n
1311 return fib(n-1) + fib(n-2)
1312 self.assertEqual([fib(n) for n in range(16)],
1313 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1314 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001315 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001316 fib.cache_clear()
1317 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001318 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1319
1320 def test_lru_with_maxsize_negative(self):
1321 @self.module.lru_cache(maxsize=-10)
1322 def eq(n):
1323 return n
1324 for i in (0, 1):
1325 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1326 self.assertEqual(eq.cache_info(),
1327 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001328
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001329 def test_lru_with_exceptions(self):
1330 # Verify that user_function exceptions get passed through without
1331 # creating a hard-to-read chained exception.
1332 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001333 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001334 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001335 def func(i):
1336 return 'abc'[i]
1337 self.assertEqual(func(0), 'a')
1338 with self.assertRaises(IndexError) as cm:
1339 func(15)
1340 self.assertIsNone(cm.exception.__context__)
1341 # Verify that the previous exception did not result in a cached entry
1342 with self.assertRaises(IndexError):
1343 func(15)
1344
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001345 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001346 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001347 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001348 def square(x):
1349 return x * x
1350 self.assertEqual(square(3), 9)
1351 self.assertEqual(type(square(3)), type(9))
1352 self.assertEqual(square(3.0), 9.0)
1353 self.assertEqual(type(square(3.0)), type(9.0))
1354 self.assertEqual(square(x=3), 9)
1355 self.assertEqual(type(square(x=3)), type(9))
1356 self.assertEqual(square(x=3.0), 9.0)
1357 self.assertEqual(type(square(x=3.0)), type(9.0))
1358 self.assertEqual(square.cache_info().hits, 4)
1359 self.assertEqual(square.cache_info().misses, 4)
1360
Antoine Pitroub5b37142012-11-13 21:35:40 +01001361 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001362 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001363 def fib(n):
1364 if n < 2:
1365 return n
1366 return fib(n=n-1) + fib(n=n-2)
1367 self.assertEqual(
1368 [fib(n=number) for number in range(16)],
1369 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1370 )
1371 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001372 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001373 fib.cache_clear()
1374 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001375 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001376
1377 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001378 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001379 def fib(n):
1380 if n < 2:
1381 return n
1382 return fib(n=n-1) + fib(n=n-2)
1383 self.assertEqual([fib(n=number) for number in range(16)],
1384 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1385 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001386 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001387 fib.cache_clear()
1388 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001389 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1390
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001391 def test_kwargs_order(self):
1392 # PEP 468: Preserving Keyword Argument Order
1393 @self.module.lru_cache(maxsize=10)
1394 def f(**kwargs):
1395 return list(kwargs.items())
1396 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1397 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1398 self.assertEqual(f.cache_info(),
1399 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1400
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001401 def test_lru_cache_decoration(self):
1402 def f(zomg: 'zomg_annotation'):
1403 """f doc string"""
1404 return 42
1405 g = self.module.lru_cache()(f)
1406 for attr in self.module.WRAPPER_ASSIGNMENTS:
1407 self.assertEqual(getattr(g, attr), getattr(f, attr))
1408
1409 @unittest.skipUnless(threading, 'This test requires threading.')
1410 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001411 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001412 def orig(x, y):
1413 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001414 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001415 hits, misses, maxsize, currsize = f.cache_info()
1416 self.assertEqual(currsize, 0)
1417
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001418 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001419 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001420 start.wait(10)
1421 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001422 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001423
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001424 def clear():
1425 start.wait(10)
1426 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001427 f.cache_clear()
1428
1429 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001430 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001431 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001432 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001433 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001434 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001435 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001436 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001437
1438 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001439 if self.module is py_functools:
1440 # XXX: Why can be not equal?
1441 self.assertLessEqual(misses, n)
1442 self.assertLessEqual(hits, m*n - misses)
1443 else:
1444 self.assertEqual(misses, n)
1445 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001446 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001447
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001448 # create n threads in order to fill cache and 1 to clear it
1449 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001450 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001451 for k in range(n)]
1452 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001453 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001454 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001455 finally:
1456 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001457
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001458 @unittest.skipUnless(threading, 'This test requires threading.')
1459 def test_lru_cache_threaded2(self):
1460 # Simultaneous call with the same arguments
1461 n, m = 5, 7
1462 start = threading.Barrier(n+1)
1463 pause = threading.Barrier(n+1)
1464 stop = threading.Barrier(n+1)
1465 @self.module.lru_cache(maxsize=m*n)
1466 def f(x):
1467 pause.wait(10)
1468 return 3 * x
1469 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1470 def test():
1471 for i in range(m):
1472 start.wait(10)
1473 self.assertEqual(f(i), 3 * i)
1474 stop.wait(10)
1475 threads = [threading.Thread(target=test) for k in range(n)]
1476 with support.start_threads(threads):
1477 for i in range(m):
1478 start.wait(10)
1479 stop.reset()
1480 pause.wait(10)
1481 start.reset()
1482 stop.wait(10)
1483 pause.reset()
1484 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1485
Serhiy Storchaka67796522017-01-12 18:34:33 +02001486 @unittest.skipUnless(threading, 'This test requires threading.')
1487 def test_lru_cache_threaded3(self):
1488 @self.module.lru_cache(maxsize=2)
1489 def f(x):
1490 time.sleep(.01)
1491 return 3 * x
1492 def test(i, x):
1493 with self.subTest(thread=i):
1494 self.assertEqual(f(x), 3 * x, i)
1495 threads = [threading.Thread(target=test, args=(i, v))
1496 for i, v in enumerate([1, 2, 2, 3, 2])]
1497 with support.start_threads(threads):
1498 pass
1499
Raymond Hettinger03923422013-03-04 02:52:50 -05001500 def test_need_for_rlock(self):
1501 # This will deadlock on an LRU cache that uses a regular lock
1502
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001503 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001504 def test_func(x):
1505 'Used to demonstrate a reentrant lru_cache call within a single thread'
1506 return x
1507
1508 class DoubleEq:
1509 'Demonstrate a reentrant lru_cache call within a single thread'
1510 def __init__(self, x):
1511 self.x = x
1512 def __hash__(self):
1513 return self.x
1514 def __eq__(self, other):
1515 if self.x == 2:
1516 test_func(DoubleEq(1))
1517 return self.x == other.x
1518
1519 test_func(DoubleEq(1)) # Load the cache
1520 test_func(DoubleEq(2)) # Load the cache
1521 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1522 DoubleEq(2)) # Verify the correct return value
1523
Raymond Hettinger4d588972014-08-12 12:44:52 -07001524 def test_early_detection_of_bad_call(self):
1525 # Issue #22184
1526 with self.assertRaises(TypeError):
1527 @functools.lru_cache
1528 def f():
1529 pass
1530
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001531 def test_lru_method(self):
1532 class X(int):
1533 f_cnt = 0
1534 @self.module.lru_cache(2)
1535 def f(self, x):
1536 self.f_cnt += 1
1537 return x*10+self
1538 a = X(5)
1539 b = X(5)
1540 c = X(7)
1541 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1542
1543 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1544 self.assertEqual(a.f(x), x*10 + 5)
1545 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1546 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1547
1548 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1549 self.assertEqual(b.f(x), x*10 + 5)
1550 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1551 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1552
1553 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1554 self.assertEqual(c.f(x), x*10 + 7)
1555 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1556 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1557
1558 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1559 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1560 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1561
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001562 def test_pickle(self):
1563 cls = self.__class__
1564 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1565 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1566 with self.subTest(proto=proto, func=f):
1567 f_copy = pickle.loads(pickle.dumps(f, proto))
1568 self.assertIs(f_copy, f)
1569
1570 def test_copy(self):
1571 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001572 def orig(x, y):
1573 return 3 * x + y
1574 part = self.module.partial(orig, 2)
1575 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1576 self.module.lru_cache(2)(part))
1577 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001578 with self.subTest(func=f):
1579 f_copy = copy.copy(f)
1580 self.assertIs(f_copy, f)
1581
1582 def test_deepcopy(self):
1583 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001584 def orig(x, y):
1585 return 3 * x + y
1586 part = self.module.partial(orig, 2)
1587 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1588 self.module.lru_cache(2)(part))
1589 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001590 with self.subTest(func=f):
1591 f_copy = copy.deepcopy(f)
1592 self.assertIs(f_copy, f)
1593
1594
# Module-level function cached with the pure-Python lru_cache; TestLRUPy
# exposes it through its cached_func attribute.
@py_functools.lru_cache()
def py_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1598
# Module-level function cached with the C-accelerated lru_cache; TestLRUC
# exposes it through its cached_func attribute.
@c_functools.lru_cache()
def c_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1602
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001603
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the TestLRU tests against the pure-Python functools module."""

    module = py_functools
    # Stored in a 1-tuple so that the function is not turned into a bound
    # method when it is looked up on the class.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1616
1617
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the TestLRU tests against the C-accelerated functools module."""

    module = c_functools
    # Stored in a 1-tuple so that the function is not turned into a bound
    # method when it is looked up on the class.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001630
Raymond Hettinger03923422013-03-04 02:52:50 -05001631
Łukasz Langa6f692512013-06-05 12:20:24 +02001632class TestSingleDispatch(unittest.TestCase):
1633 def test_simple_overloads(self):
1634 @functools.singledispatch
1635 def g(obj):
1636 return "base"
1637 def g_int(i):
1638 return "integer"
1639 g.register(int, g_int)
1640 self.assertEqual(g("str"), "base")
1641 self.assertEqual(g(1), "integer")
1642 self.assertEqual(g([1,2,3]), "base")
1643
1644 def test_mro(self):
1645 @functools.singledispatch
1646 def g(obj):
1647 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001648 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001649 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001650 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001651 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001652 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001653 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001654 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001655 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001656 def g_A(a):
1657 return "A"
1658 def g_B(b):
1659 return "B"
1660 g.register(A, g_A)
1661 g.register(B, g_B)
1662 self.assertEqual(g(A()), "A")
1663 self.assertEqual(g(B()), "B")
1664 self.assertEqual(g(C()), "A")
1665 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001666
1667 def test_register_decorator(self):
1668 @functools.singledispatch
1669 def g(obj):
1670 return "base"
1671 @g.register(int)
1672 def g_int(i):
1673 return "int %s" % (i,)
1674 self.assertEqual(g(""), "base")
1675 self.assertEqual(g(12), "int 12")
1676 self.assertIs(g.dispatch(int), g_int)
1677 self.assertIs(g.dispatch(object), g.dispatch(str))
1678 # Note: in the assert above this is not g.
1679 # @singledispatch returns the wrapper.
1680
1681 def test_wrapping_attributes(self):
1682 @functools.singledispatch
1683 def g(obj):
1684 "Simple test"
1685 return "Test"
1686 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001687 if sys.flags.optimize < 2:
1688 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001689
1690 @unittest.skipUnless(decimal, 'requires _decimal')
1691 @support.cpython_only
1692 def test_c_classes(self):
1693 @functools.singledispatch
1694 def g(obj):
1695 return "base"
1696 @g.register(decimal.DecimalException)
1697 def _(obj):
1698 return obj.args
1699 subn = decimal.Subnormal("Exponent < Emin")
1700 rnd = decimal.Rounded("Number got rounded")
1701 self.assertEqual(g(subn), ("Exponent < Emin",))
1702 self.assertEqual(g(rnd), ("Number got rounded",))
1703 @g.register(decimal.Subnormal)
1704 def _(obj):
1705 return "Too small to care."
1706 self.assertEqual(g(subn), "Too small to care.")
1707 self.assertEqual(g(rnd), ("Number got rounded",))
1708
1709 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001710 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001711 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001712 mro = functools._compose_mro
1713 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1714 for haystack in permutations(bases):
1715 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001716 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1717 c.Collection, c.Sized, c.Iterable,
1718 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001719 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001720 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001721 m = mro(collections.ChainMap, haystack)
1722 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001723 c.Collection, c.Sized, c.Iterable,
1724 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001725
1726 # If there's a generic function with implementations registered for
1727 # both Sized and Container, passing a defaultdict to it results in an
1728 # ambiguous dispatch which will cause a RuntimeError (see
1729 # test_mro_conflicts).
1730 bases = [c.Container, c.Sized, str]
1731 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001732 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1733 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1734 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001735
1736 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001737 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001738 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001739 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001740 pass
1741 c.MutableSequence.register(D)
1742 bases = [c.MutableSequence, c.MutableMapping]
1743 for haystack in permutations(bases):
1744 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001745 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001746 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001747 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001748 object])
1749
1750 # Container and Callable are registered on different base classes and
1751 # a generic function supporting both should always pick the Callable
1752 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001753 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001754 def __call__(self):
1755 pass
1756 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1757 for haystack in permutations(bases):
1758 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001759 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001760 c.Collection, c.Sized, c.Iterable,
1761 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001762
1763 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001764 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001765 d = {"a": "b"}
1766 l = [1, 2, 3]
1767 s = {object(), None}
1768 f = frozenset(s)
1769 t = (1, 2, 3)
1770 @functools.singledispatch
1771 def g(obj):
1772 return "base"
1773 self.assertEqual(g(d), "base")
1774 self.assertEqual(g(l), "base")
1775 self.assertEqual(g(s), "base")
1776 self.assertEqual(g(f), "base")
1777 self.assertEqual(g(t), "base")
1778 g.register(c.Sized, lambda obj: "sized")
1779 self.assertEqual(g(d), "sized")
1780 self.assertEqual(g(l), "sized")
1781 self.assertEqual(g(s), "sized")
1782 self.assertEqual(g(f), "sized")
1783 self.assertEqual(g(t), "sized")
1784 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1785 self.assertEqual(g(d), "mutablemapping")
1786 self.assertEqual(g(l), "sized")
1787 self.assertEqual(g(s), "sized")
1788 self.assertEqual(g(f), "sized")
1789 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001790 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001791 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1792 self.assertEqual(g(l), "sized")
1793 self.assertEqual(g(s), "sized")
1794 self.assertEqual(g(f), "sized")
1795 self.assertEqual(g(t), "sized")
1796 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1797 self.assertEqual(g(d), "mutablemapping")
1798 self.assertEqual(g(l), "mutablesequence")
1799 self.assertEqual(g(s), "sized")
1800 self.assertEqual(g(f), "sized")
1801 self.assertEqual(g(t), "sized")
1802 g.register(c.MutableSet, lambda obj: "mutableset")
1803 self.assertEqual(g(d), "mutablemapping")
1804 self.assertEqual(g(l), "mutablesequence")
1805 self.assertEqual(g(s), "mutableset")
1806 self.assertEqual(g(f), "sized")
1807 self.assertEqual(g(t), "sized")
1808 g.register(c.Mapping, lambda obj: "mapping")
1809 self.assertEqual(g(d), "mutablemapping") # not specific enough
1810 self.assertEqual(g(l), "mutablesequence")
1811 self.assertEqual(g(s), "mutableset")
1812 self.assertEqual(g(f), "sized")
1813 self.assertEqual(g(t), "sized")
1814 g.register(c.Sequence, lambda obj: "sequence")
1815 self.assertEqual(g(d), "mutablemapping")
1816 self.assertEqual(g(l), "mutablesequence")
1817 self.assertEqual(g(s), "mutableset")
1818 self.assertEqual(g(f), "sized")
1819 self.assertEqual(g(t), "sequence")
1820 g.register(c.Set, lambda obj: "set")
1821 self.assertEqual(g(d), "mutablemapping")
1822 self.assertEqual(g(l), "mutablesequence")
1823 self.assertEqual(g(s), "mutableset")
1824 self.assertEqual(g(f), "set")
1825 self.assertEqual(g(t), "sequence")
1826 g.register(dict, lambda obj: "dict")
1827 self.assertEqual(g(d), "dict")
1828 self.assertEqual(g(l), "mutablesequence")
1829 self.assertEqual(g(s), "mutableset")
1830 self.assertEqual(g(f), "set")
1831 self.assertEqual(g(t), "sequence")
1832 g.register(list, lambda obj: "list")
1833 self.assertEqual(g(d), "dict")
1834 self.assertEqual(g(l), "list")
1835 self.assertEqual(g(s), "mutableset")
1836 self.assertEqual(g(f), "set")
1837 self.assertEqual(g(t), "sequence")
1838 g.register(set, lambda obj: "concrete-set")
1839 self.assertEqual(g(d), "dict")
1840 self.assertEqual(g(l), "list")
1841 self.assertEqual(g(s), "concrete-set")
1842 self.assertEqual(g(f), "set")
1843 self.assertEqual(g(t), "sequence")
1844 g.register(frozenset, lambda obj: "frozen-set")
1845 self.assertEqual(g(d), "dict")
1846 self.assertEqual(g(l), "list")
1847 self.assertEqual(g(s), "concrete-set")
1848 self.assertEqual(g(f), "frozen-set")
1849 self.assertEqual(g(t), "sequence")
1850 g.register(tuple, lambda obj: "tuple")
1851 self.assertEqual(g(d), "dict")
1852 self.assertEqual(g(l), "list")
1853 self.assertEqual(g(s), "concrete-set")
1854 self.assertEqual(g(f), "frozen-set")
1855 self.assertEqual(g(t), "tuple")
1856
Łukasz Langa3720c772013-07-01 16:00:38 +02001857 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001858 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001859 mro = functools._c3_mro
1860 class A(object):
1861 pass
1862 class B(A):
1863 def __len__(self):
1864 return 0 # implies Sized
1865 @c.Container.register
1866 class C(object):
1867 pass
1868 class D(object):
1869 pass # unrelated
1870 class X(D, C, B):
1871 def __call__(self):
1872 pass # implies Callable
1873 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1874 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1875 self.assertEqual(mro(X, abcs=abcs), expected)
1876 # unrelated ABCs don't appear in the resulting MRO
1877 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1878 self.assertEqual(mro(X, abcs=many_abcs), expected)
1879
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001880 def test_false_meta(self):
1881 # see issue23572
1882 class MetaA(type):
1883 def __len__(self):
1884 return 0
1885 class A(metaclass=MetaA):
1886 pass
1887 class AA(A):
1888 pass
1889 @functools.singledispatch
1890 def fun(a):
1891 return 'base A'
1892 @fun.register(A)
1893 def _(a):
1894 return 'fun A'
1895 aa = AA()
1896 self.assertEqual(fun(aa), 'fun A')
1897
Łukasz Langa6f692512013-06-05 12:20:24 +02001898 def test_mro_conflicts(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001899 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001900 @functools.singledispatch
1901 def g(arg):
1902 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001903 class O(c.Sized):
1904 def __len__(self):
1905 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001906 o = O()
1907 self.assertEqual(g(o), "base")
1908 g.register(c.Iterable, lambda arg: "iterable")
1909 g.register(c.Container, lambda arg: "container")
1910 g.register(c.Sized, lambda arg: "sized")
1911 g.register(c.Set, lambda arg: "set")
1912 self.assertEqual(g(o), "sized")
1913 c.Iterable.register(O)
1914 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1915 c.Container.register(O)
1916 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001917 c.Set.register(O)
1918 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1919 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001920 class P:
1921 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001922 p = P()
1923 self.assertEqual(g(p), "base")
1924 c.Iterable.register(P)
1925 self.assertEqual(g(p), "iterable")
1926 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001927 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001928 g(p)
Benjamin Petersonab078e92016-07-13 21:13:29 -07001929 self.assertIn(
1930 str(re_one.exception),
1931 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1932 "or <class 'collections.abc.Iterable'>"),
1933 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1934 "or <class 'collections.abc.Container'>")),
1935 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001936 class Q(c.Sized):
1937 def __len__(self):
1938 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001939 q = Q()
1940 self.assertEqual(g(q), "sized")
1941 c.Iterable.register(Q)
1942 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1943 c.Set.register(Q)
1944 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001945 # c.Sized and c.Iterable
1946 @functools.singledispatch
1947 def h(arg):
1948 return "base"
1949 @h.register(c.Sized)
1950 def _(arg):
1951 return "sized"
1952 @h.register(c.Container)
1953 def _(arg):
1954 return "container"
1955 # Even though Sized and Container are explicit bases of MutableMapping,
1956 # this ABC is implicitly registered on defaultdict which makes all of
1957 # MutableMapping's bases implicit as well from defaultdict's
1958 # perspective.
1959 with self.assertRaises(RuntimeError) as re_two:
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001960 h(collections.defaultdict(lambda: 0))
Benjamin Petersonab078e92016-07-13 21:13:29 -07001961 self.assertIn(
1962 str(re_two.exception),
1963 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1964 "or <class 'collections.abc.Sized'>"),
1965 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1966 "or <class 'collections.abc.Container'>")),
1967 )
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001968 class R(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001969 pass
1970 c.MutableSequence.register(R)
1971 @functools.singledispatch
1972 def i(arg):
1973 return "base"
1974 @i.register(c.MutableMapping)
1975 def _(arg):
1976 return "mapping"
1977 @i.register(c.MutableSequence)
1978 def _(arg):
1979 return "sequence"
1980 r = R()
1981 self.assertEqual(i(r), "sequence")
1982 class S:
1983 pass
1984 class T(S, c.Sized):
1985 def __len__(self):
1986 return 0
1987 t = T()
1988 self.assertEqual(h(t), "sized")
1989 c.Container.register(T)
1990 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1991 class U:
1992 def __len__(self):
1993 return 0
1994 u = U()
1995 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1996 # from the existence of __len__()
1997 c.Container.register(U)
1998 # There is no preference for registered versus inferred ABCs.
1999 with self.assertRaises(RuntimeError) as re_three:
2000 h(u)
Benjamin Petersonab078e92016-07-13 21:13:29 -07002001 self.assertIn(
2002 str(re_three.exception),
2003 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
2004 "or <class 'collections.abc.Sized'>"),
2005 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
2006 "or <class 'collections.abc.Container'>")),
2007 )
Łukasz Langa3720c772013-07-01 16:00:38 +02002008 class V(c.Sized, S):
2009 def __len__(self):
2010 return 0
2011 @functools.singledispatch
2012 def j(arg):
2013 return "base"
2014 @j.register(S)
2015 def _(arg):
2016 return "s"
2017 @j.register(c.Container)
2018 def _(arg):
2019 return "container"
2020 v = V()
2021 self.assertEqual(j(v), "s")
2022 c.Container.register(V)
2023 self.assertEqual(j(v), "container") # because it ends up right after
2024 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002025
2026 def test_cache_invalidation(self):
2027 from collections import UserDict
2028 class TracingDict(UserDict):
2029 def __init__(self, *args, **kwargs):
2030 super(TracingDict, self).__init__(*args, **kwargs)
2031 self.set_ops = []
2032 self.get_ops = []
2033 def __getitem__(self, key):
2034 result = self.data[key]
2035 self.get_ops.append(key)
2036 return result
2037 def __setitem__(self, key, value):
2038 self.set_ops.append(key)
2039 self.data[key] = value
2040 def clear(self):
2041 self.data.clear()
2042 _orig_wkd = functools.WeakKeyDictionary
2043 td = TracingDict()
2044 functools.WeakKeyDictionary = lambda: td
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03002045 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02002046 @functools.singledispatch
2047 def g(arg):
2048 return "base"
2049 d = {}
2050 l = []
2051 self.assertEqual(len(td), 0)
2052 self.assertEqual(g(d), "base")
2053 self.assertEqual(len(td), 1)
2054 self.assertEqual(td.get_ops, [])
2055 self.assertEqual(td.set_ops, [dict])
2056 self.assertEqual(td.data[dict], g.registry[object])
2057 self.assertEqual(g(l), "base")
2058 self.assertEqual(len(td), 2)
2059 self.assertEqual(td.get_ops, [])
2060 self.assertEqual(td.set_ops, [dict, list])
2061 self.assertEqual(td.data[dict], g.registry[object])
2062 self.assertEqual(td.data[list], g.registry[object])
2063 self.assertEqual(td.data[dict], td.data[list])
2064 self.assertEqual(g(l), "base")
2065 self.assertEqual(g(d), "base")
2066 self.assertEqual(td.get_ops, [list, dict])
2067 self.assertEqual(td.set_ops, [dict, list])
2068 g.register(list, lambda arg: "list")
2069 self.assertEqual(td.get_ops, [list, dict])
2070 self.assertEqual(len(td), 0)
2071 self.assertEqual(g(d), "base")
2072 self.assertEqual(len(td), 1)
2073 self.assertEqual(td.get_ops, [list, dict])
2074 self.assertEqual(td.set_ops, [dict, list, dict])
2075 self.assertEqual(td.data[dict],
2076 functools._find_impl(dict, g.registry))
2077 self.assertEqual(g(l), "list")
2078 self.assertEqual(len(td), 2)
2079 self.assertEqual(td.get_ops, [list, dict])
2080 self.assertEqual(td.set_ops, [dict, list, dict, list])
2081 self.assertEqual(td.data[list],
2082 functools._find_impl(list, g.registry))
2083 class X:
2084 pass
2085 c.MutableMapping.register(X) # Will not invalidate the cache,
2086 # not using ABCs yet.
2087 self.assertEqual(g(d), "base")
2088 self.assertEqual(g(l), "list")
2089 self.assertEqual(td.get_ops, [list, dict, dict, list])
2090 self.assertEqual(td.set_ops, [dict, list, dict, list])
2091 g.register(c.Sized, lambda arg: "sized")
2092 self.assertEqual(len(td), 0)
2093 self.assertEqual(g(d), "sized")
2094 self.assertEqual(len(td), 1)
2095 self.assertEqual(td.get_ops, [list, dict, dict, list])
2096 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
2097 self.assertEqual(g(l), "list")
2098 self.assertEqual(len(td), 2)
2099 self.assertEqual(td.get_ops, [list, dict, dict, list])
2100 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2101 self.assertEqual(g(l), "list")
2102 self.assertEqual(g(d), "sized")
2103 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
2104 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2105 g.dispatch(list)
2106 g.dispatch(dict)
2107 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
2108 list, dict])
2109 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
2110 c.MutableSet.register(X) # Will invalidate the cache.
2111 self.assertEqual(len(td), 2) # Stale cache.
2112 self.assertEqual(g(l), "list")
2113 self.assertEqual(len(td), 1)
2114 g.register(c.MutableMapping, lambda arg: "mutablemapping")
2115 self.assertEqual(len(td), 0)
2116 self.assertEqual(g(d), "mutablemapping")
2117 self.assertEqual(len(td), 1)
2118 self.assertEqual(g(l), "list")
2119 self.assertEqual(len(td), 2)
2120 g.register(dict, lambda arg: "dict")
2121 self.assertEqual(g(d), "dict")
2122 self.assertEqual(g(l), "list")
2123 g._clear_cache()
2124 self.assertEqual(len(td), 0)
2125 functools.WeakKeyDictionary = _orig_wkd
2126
2127
# Run all test cases of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()