blob: 29ea49362262dd0d4d30172abd094a0cab930f93 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080012import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020013from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100014import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030015try:
16 import threading
17except ImportError:
18 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000019
Antoine Pitroub5b37142012-11-13 21:35:40 +010020import functools
21
Antoine Pitroub5b37142012-11-13 21:35:40 +010022py_functools = support.import_fresh_module('functools', blocked=['_functools'])
23c_functools = support.import_fresh_module('functools', fresh=['_functools'])
24
Łukasz Langa6f692512013-06-05 12:20:24 +020025decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
26
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module registered under *name* on entry is put back when the
    ``with`` block finishes, whether it exits normally or by raising.
    A KeyError is raised on entry if *name* is not in ``sys.modules``.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the managed block raised.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020035
def capture(*args, **kw):
    """Echo the call back to the caller.

    Returns ``(args, kw)``: the positional arguments as a tuple and the
    keyword arguments as a dict, exactly as received.
    """
    received = (args, kw)
    return received
39
Łukasz Langa6f692512013-06-05 12:20:24 +020040
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``.

    Two partials with equal summaries wrap the same callable with the same
    frozen arguments and the same instance attributes.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000044
class MyTuple(tuple):
    """A ``tuple`` subclass that adds no behaviour of its own."""
47
class BadTuple(tuple):
    """A ``tuple`` subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        # Deliberately non-standard: concatenation yields a list.
        return [*self, *other]
51
class MyDict(dict):
    """A ``dict`` subclass that adds no behaviour of its own."""
54
Łukasz Langa6f692512013-06-05 12:20:24 +020055
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin, not a TestCase.  Concrete subclasses combine it with
    ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial implementation (or subclass) under test;
    * ``AllowPickle`` -- a context-manager class that the pickling tests
      run under.
    """

    def _partial_repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``.

        The stock implementations report themselves as
        ``functools.partial``, while subclasses report their own
        ``__name__``.  ``c_functools`` is None when the C accelerator cannot
        be imported, so it is only consulted when present.
        """
        stock = [py_functools.partial]
        if c_functools:
            stock.append(c_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected with TypeError,
        # either when the partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a,'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._partial_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._partial_repr_name()

        # Each case makes the partial reference itself, checks that repr
        # prints "..." instead of recursing, then breaks the cycle.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # A shallow copy shares args, keywords and instance attributes.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # A deep copy duplicates args, keywords and their mutable contents.
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        # State is (func, args, keywords, __dict__); None is accepted for
        # keywords and __dict__.
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check is disabled upstream; confirm
        # whether all implementations now normalise keywords=None to {}.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are stored as exact tuple/dict.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; calling must still build a tuple.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside args survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references inside keywords survive a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000379
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C ``_functools.partial``, plus
    checks that only apply to the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing is patched for the C partial."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
430
431
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python ``functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Swap ``sys.modules['functools']`` for the pure-Python module.

        The swap is done through ``replaced_module`` and undone on exit.
        NOTE(review): presumably needed because pickle resolves classes by
        module name, and "functools" normally names the accelerated module.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200442
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """A subclass of the C partial type that adds no behaviour."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100446
class PyPartialSubclass(py_functools.partial):
    """A subclass of the pure-Python partial type that adds no behaviour."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using a trivial subclass of it."""

    if c_functools:
        partial = CPartialSubclass

    # Subclasses of partial are not flattened when nested, so the
    # inherited nested-optimization test is switched off.
    test_nested_optimization = None
457
class TestPartialPySubclass(TestPartialPy):
    """Repeat the pure-Python partial tests using a trivial subclass of it."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200460
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        # Each attribute pre-binds arguments of `capture`, which echoes back
        # (args, kwargs), so the tests can see exactly what was passed.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain functools.partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an abstract base class (created *by* ABCMeta),
        # not a subclass of the metaclass itself.  Otherwise ABCMeta never
        # processes this class body and abstract bookkeeping goes untested.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must record both the abstract method and the
        # partialmethod derived from it.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
573
574
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each attribute named in *assigned* must be the very same object on
        both.  Each mapping named in *updated* must contain every entry of
        the wrapped object's mapping (identity-equal), except
        ``__dict__['__wrapped__']``.  Finally, ``wrapper.__wrapped__`` must
        be *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name), getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                # update_wrapper overwrites __wrapped__, so the wrapped
                # object's own __wrapped__ entry is not expected to survive.
                overwritten = attr_name == "__dict__" and key == "__wrapped__"
                if not overwritten:
                    self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        """Wrap an annotated function with default settings.

        Returns the pair ``(wrapper, f)``.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Decoy value that update_wrapper is expected to replace.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # __annotations__ is assigned wholesale, not merged.
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        # With nothing assigned or updated, wrapper keeps its own metadata.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        only_assign = ('attr',)
        only_update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, only_assign, only_update)
        self.check_wrapper(wrapper, f, only_assign, only_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        only_assign = ('attr',)
        only_update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, only_assign, only_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, only_assign, only_update)
        # ...and they must support .update()
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, only_assign, only_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000687
Łukasz Langa6f692512013-06-05 12:20:24 +0200688
class TestWraps(TestUpdateWrapper):
    """Run the TestUpdateWrapper checks with wrappers built by @functools.wraps.

    Tests not overridden here are inherited from TestUpdateWrapper unchanged.
    """

    def _default_update(self):
        # Return (wrapper, f), where wrapper is decorated with the default
        # @functools.wraps(f).
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Deliberately misleading __wrapped__ on f; presumably
        # check_wrapper() confirms the wrapper's __wrapped__ is f itself.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # Name, qualname and f's extra attributes end up on the wrapper.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        # With empty assigned/updated tuples the wrapper keeps its own
        # name, qualname and docstring and gains no attributes from f.
        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        assigned_names = ('attr',)
        updated_names = ('dict_attr',)

        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        # Equivalent to stacking @functools.wraps(...) over a decorator that
        # gives the wrapper an empty dict_attr for the update to merge into.
        def wrapper():
            pass
        wrapper.dict_attr = {}
        wrapper = functools.wraps(f, assigned_names, updated_names)(wrapper)

        self.check_wrapper(wrapper, f, assigned_names, updated_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
749
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of reduce()."""

    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily built sequence 0, 1, 4, 9, ... of `limit` items,
            # reachable only through the __getitem__ protocol.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Ordinary folds, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong argument counts and a non-callable function.
        with self.assertRaises(TypeError):
            reduce()
        with self.assertRaises(TypeError):
            reduce(42, 42)
        with self.assertRaises(TypeError):
            reduce(42, 42, 42)
        # With a single item the function is never called, so 42 is fine.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))

        # An empty sequence with no initial value is rejected.
        for empty in ([], "", ()):
            with self.assertRaises(TypeError):
                reduce(add, empty)
        with self.assertRaises(TypeError):
            reduce(add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        # An empty sequence with an initial value returns that value.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    def test_iterator_usage(self):
        # Test reduce()'s use of iterators (via the __getitem__ protocol).
        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
831
Łukasz Langa6f692512013-06-05 12:20:24 +0200832
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200833class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700834
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000835 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700836 def cmp1(x, y):
837 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100838 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700839 self.assertEqual(key(3), key(3))
840 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100841 self.assertGreaterEqual(key(3), key(3))
842
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700843 def cmp2(x, y):
844 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100845 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700846 self.assertEqual(key(4.0), key('4'))
847 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100848 self.assertLessEqual(key(2), key('35'))
849 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700850
851 def test_cmp_to_key_arguments(self):
852 def cmp1(x, y):
853 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100854 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700855 self.assertEqual(key(obj=3), key(obj=3))
856 self.assertGreater(key(obj=3), key(obj=1))
857 with self.assertRaises((TypeError, AttributeError)):
858 key(3) > 1 # rhs is not a K object
859 with self.assertRaises((TypeError, AttributeError)):
860 1 < key(3) # lhs is not a K object
861 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100862 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200864 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 with self.assertRaises(TypeError):
867 key() # too few args
868 with self.assertRaises(TypeError):
869 key(None, None) # too many args
870
871 def test_bad_cmp(self):
872 def cmp1(x, y):
873 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100874 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700875 with self.assertRaises(ZeroDivisionError):
876 key(3) > key(1)
877
878 class BadCmp:
879 def __lt__(self, other):
880 raise ZeroDivisionError
881 def cmp1(x, y):
882 return BadCmp()
883 with self.assertRaises(ZeroDivisionError):
884 key(3) > key(1)
885
886 def test_obj_field(self):
887 def cmp1(x, y):
888 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100889 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700890 self.assertEqual(key(50).obj, 50)
891
892 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000893 def mycmp(x, y):
894 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100895 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000896 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000897
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700898 def test_sort_int_str(self):
899 def mycmp(x, y):
900 x, y = int(x), int(y)
901 return (x > y) - (x < y)
902 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100903 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700904 self.assertEqual([int(value) for value in values],
905 [0, 1, 1, 2, 3, 4, 5, 7, 10])
906
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000907 def test_hash(self):
908 def mycmp(x, y):
909 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100910 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000911 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700912 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700913 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000914
Łukasz Langa6f692512013-06-05 12:20:24 +0200915
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200916@unittest.skipUnless(c_functools, 'requires the C _functools module')
917class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
918 if c_functools:
919 cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100920
Łukasz Langa6f692512013-06-05 12:20:24 +0200921
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python implementation."""

    # staticmethod() stops the plain Python function from being bound to the
    # test instance when the tests call self.cmp_to_key(...).
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
924
Łukasz Langa6f692512013-06-05 12:20:24 +0200925
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000926class TestTotalOrdering(unittest.TestCase):
927
928 def test_total_ordering_lt(self):
929 @functools.total_ordering
930 class A:
931 def __init__(self, value):
932 self.value = value
933 def __lt__(self, other):
934 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000935 def __eq__(self, other):
936 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000937 self.assertTrue(A(1) < A(2))
938 self.assertTrue(A(2) > A(1))
939 self.assertTrue(A(1) <= A(2))
940 self.assertTrue(A(2) >= A(1))
941 self.assertTrue(A(2) <= A(2))
942 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000943 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000944
945 def test_total_ordering_le(self):
946 @functools.total_ordering
947 class A:
948 def __init__(self, value):
949 self.value = value
950 def __le__(self, other):
951 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000952 def __eq__(self, other):
953 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000954 self.assertTrue(A(1) < A(2))
955 self.assertTrue(A(2) > A(1))
956 self.assertTrue(A(1) <= A(2))
957 self.assertTrue(A(2) >= A(1))
958 self.assertTrue(A(2) <= A(2))
959 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000960 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000961
962 def test_total_ordering_gt(self):
963 @functools.total_ordering
964 class A:
965 def __init__(self, value):
966 self.value = value
967 def __gt__(self, other):
968 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000969 def __eq__(self, other):
970 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000971 self.assertTrue(A(1) < A(2))
972 self.assertTrue(A(2) > A(1))
973 self.assertTrue(A(1) <= A(2))
974 self.assertTrue(A(2) >= A(1))
975 self.assertTrue(A(2) <= A(2))
976 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000977 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000978
979 def test_total_ordering_ge(self):
980 @functools.total_ordering
981 class A:
982 def __init__(self, value):
983 self.value = value
984 def __ge__(self, other):
985 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000986 def __eq__(self, other):
987 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000988 self.assertTrue(A(1) < A(2))
989 self.assertTrue(A(2) > A(1))
990 self.assertTrue(A(1) <= A(2))
991 self.assertTrue(A(2) >= A(1))
992 self.assertTrue(A(2) <= A(2))
993 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000994 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000995
996 def test_total_ordering_no_overwrite(self):
997 # new methods should not overwrite existing
998 @functools.total_ordering
999 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +00001000 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +00001001 self.assertTrue(A(1) < A(2))
1002 self.assertTrue(A(2) > A(1))
1003 self.assertTrue(A(1) <= A(2))
1004 self.assertTrue(A(2) >= A(1))
1005 self.assertTrue(A(2) <= A(2))
1006 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001007
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001008 def test_no_operations_defined(self):
1009 with self.assertRaises(ValueError):
1010 @functools.total_ordering
1011 class A:
1012 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001013
Nick Coghlanf05d9812013-10-02 00:02:03 +10001014 def test_type_error_when_not_implemented(self):
1015 # bug 10042; ensure stack overflow does not occur
1016 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001017 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001018 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001019 def __init__(self, value):
1020 self.value = value
1021 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001022 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001023 return self.value == other.value
1024 return False
1025 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001026 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001027 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001028 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001029
Nick Coghlanf05d9812013-10-02 00:02:03 +10001030 @functools.total_ordering
1031 class ImplementsGreaterThan:
1032 def __init__(self, value):
1033 self.value = value
1034 def __eq__(self, other):
1035 if isinstance(other, ImplementsGreaterThan):
1036 return self.value == other.value
1037 return False
1038 def __gt__(self, other):
1039 if isinstance(other, ImplementsGreaterThan):
1040 return self.value > other.value
1041 return NotImplemented
1042
1043 @functools.total_ordering
1044 class ImplementsLessThanEqualTo:
1045 def __init__(self, value):
1046 self.value = value
1047 def __eq__(self, other):
1048 if isinstance(other, ImplementsLessThanEqualTo):
1049 return self.value == other.value
1050 return False
1051 def __le__(self, other):
1052 if isinstance(other, ImplementsLessThanEqualTo):
1053 return self.value <= other.value
1054 return NotImplemented
1055
1056 @functools.total_ordering
1057 class ImplementsGreaterThanEqualTo:
1058 def __init__(self, value):
1059 self.value = value
1060 def __eq__(self, other):
1061 if isinstance(other, ImplementsGreaterThanEqualTo):
1062 return self.value == other.value
1063 return False
1064 def __ge__(self, other):
1065 if isinstance(other, ImplementsGreaterThanEqualTo):
1066 return self.value >= other.value
1067 return NotImplemented
1068
1069 @functools.total_ordering
1070 class ComparatorNotImplemented:
1071 def __init__(self, value):
1072 self.value = value
1073 def __eq__(self, other):
1074 if isinstance(other, ComparatorNotImplemented):
1075 return self.value == other.value
1076 return False
1077 def __lt__(self, other):
1078 return NotImplemented
1079
1080 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1081 ImplementsLessThan(-1) < 1
1082
1083 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1084 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1085
1086 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1087 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1088
1089 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1090 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1091
1092 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1093 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1094
1095 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1096 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1097
1098 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1099 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1100
1101 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1102 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1103
1104 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1105 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1106
1107 with self.subTest("GE when equal"):
1108 a = ComparatorNotImplemented(8)
1109 b = ComparatorNotImplemented(8)
1110 self.assertEqual(a, b)
1111 with self.assertRaises(TypeError):
1112 a >= b
1113
1114 with self.subTest("LE when equal"):
1115 a = ComparatorNotImplemented(9)
1116 b = ComparatorNotImplemented(9)
1117 self.assertEqual(a, b)
1118 with self.assertRaises(TypeError):
1119 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001120
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001121 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001122 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001123 for name in '__lt__', '__gt__', '__le__', '__ge__':
1124 with self.subTest(method=name, proto=proto):
1125 method = getattr(Orderable_LT, name)
1126 method_copy = pickle.loads(pickle.dumps(method, proto))
1127 self.assertIs(method_copy, method)
1128
@functools.total_ordering
class Orderable_LT:
    """Value wrapper ordered via __lt__/__eq__ plus total_ordering.

    Kept at module level so that TestTotalOrdering.test_pickle can pickle
    its comparison methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1137
1138
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001139class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001140
1141 def test_lru(self):
1142 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001143 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001144 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001145 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001146 self.assertEqual(maxsize, 20)
1147 self.assertEqual(currsize, 0)
1148 self.assertEqual(hits, 0)
1149 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001150
1151 domain = range(5)
1152 for i in range(1000):
1153 x, y = choice(domain), choice(domain)
1154 actual = f(x, y)
1155 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001156 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001157 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001158 self.assertTrue(hits > misses)
1159 self.assertEqual(hits + misses, 1000)
1160 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001161
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001162 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001163 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001164 self.assertEqual(hits, 0)
1165 self.assertEqual(misses, 0)
1166 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001167 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001168 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001169 self.assertEqual(hits, 0)
1170 self.assertEqual(misses, 1)
1171 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001172
Nick Coghlan98876832010-08-17 06:17:18 +00001173 # Test bypassing the cache
1174 self.assertIs(f.__wrapped__, orig)
1175 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001176 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001177 self.assertEqual(hits, 0)
1178 self.assertEqual(misses, 1)
1179 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001180
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001181 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001182 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001183 def f():
1184 nonlocal f_cnt
1185 f_cnt += 1
1186 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001187 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001188 f_cnt = 0
1189 for i in range(5):
1190 self.assertEqual(f(), 20)
1191 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001192 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(hits, 0)
1194 self.assertEqual(misses, 5)
1195 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001196
1197 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001198 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001199 def f():
1200 nonlocal f_cnt
1201 f_cnt += 1
1202 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001203 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001204 f_cnt = 0
1205 for i in range(5):
1206 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001207 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001208 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001209 self.assertEqual(hits, 4)
1210 self.assertEqual(misses, 1)
1211 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001212
Raymond Hettingerf3098282010-08-15 03:30:45 +00001213 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001214 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001215 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001216 nonlocal f_cnt
1217 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001218 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001219 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001220 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001221 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1222 # * * * *
1223 self.assertEqual(f(x), x*10)
1224 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001225 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001226 self.assertEqual(hits, 12)
1227 self.assertEqual(misses, 4)
1228 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001229
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001230 def test_lru_hash_only_once(self):
1231 # To protect against weird reentrancy bugs and to improve
1232 # efficiency when faced with slow __hash__ methods, the
1233 # LRU cache guarantees that it will only call __hash__
1234 # only once per use as an argument to the cached function.
1235
1236 @self.module.lru_cache(maxsize=1)
1237 def f(x, y):
1238 return x * 3 + y
1239
1240 # Simulate the integer 5
1241 mock_int = unittest.mock.Mock()
1242 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1243 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1244
1245 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001246 self.assertEqual(f(mock_int, 1), 16)
1247 self.assertEqual(mock_int.__hash__.call_count, 1)
1248 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001249
1250 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001251 self.assertEqual(f(mock_int, 1), 16)
1252 self.assertEqual(mock_int.__hash__.call_count, 2)
1253 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001254
1255 # Cache eviction: No use as an argument gives no additonal call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001256 self.assertEqual(f(6, 2), 20)
1257 self.assertEqual(mock_int.__hash__.call_count, 2)
1258 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001259
1260 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001261 self.assertEqual(f(mock_int, 1), 16)
1262 self.assertEqual(mock_int.__hash__.call_count, 3)
1263 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001264
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001265 def test_lru_reentrancy_with_len(self):
1266 # Test to make sure the LRU cache code isn't thrown-off by
1267 # caching the built-in len() function. Since len() can be
1268 # cached, we shouldn't use it inside the lru code itself.
1269 old_len = builtins.len
1270 try:
1271 builtins.len = self.module.lru_cache(4)(len)
1272 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1273 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1274 finally:
1275 builtins.len = old_len
1276
Raymond Hettinger605a4472017-01-09 07:50:19 -08001277 def test_lru_star_arg_handling(self):
1278 # Test regression that arose in ea064ff3c10f
1279 @functools.lru_cache()
1280 def f(*args):
1281 return args
1282
1283 self.assertEqual(f(1, 2), (1, 2))
1284 self.assertEqual(f((1, 2)), ((1, 2),))
1285
Yury Selivanov46a02db2016-11-09 18:55:45 -05001286 def test_lru_type_error(self):
1287 # Regression test for issue #28653.
1288 # lru_cache was leaking when one of the arguments
1289 # wasn't cacheable.
1290
1291 @functools.lru_cache(maxsize=None)
1292 def infinite_cache(o):
1293 pass
1294
1295 @functools.lru_cache(maxsize=10)
1296 def limited_cache(o):
1297 pass
1298
1299 with self.assertRaises(TypeError):
1300 infinite_cache([])
1301
1302 with self.assertRaises(TypeError):
1303 limited_cache([])
1304
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001305 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001306 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001307 def fib(n):
1308 if n < 2:
1309 return n
1310 return fib(n-1) + fib(n-2)
1311 self.assertEqual([fib(n) for n in range(16)],
1312 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1313 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001314 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001315 fib.cache_clear()
1316 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001317 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1318
1319 def test_lru_with_maxsize_negative(self):
1320 @self.module.lru_cache(maxsize=-10)
1321 def eq(n):
1322 return n
1323 for i in (0, 1):
1324 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1325 self.assertEqual(eq.cache_info(),
1326 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001327
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001328 def test_lru_with_exceptions(self):
1329 # Verify that user_function exceptions get passed through without
1330 # creating a hard-to-read chained exception.
1331 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001332 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001333 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001334 def func(i):
1335 return 'abc'[i]
1336 self.assertEqual(func(0), 'a')
1337 with self.assertRaises(IndexError) as cm:
1338 func(15)
1339 self.assertIsNone(cm.exception.__context__)
1340 # Verify that the previous exception did not result in a cached entry
1341 with self.assertRaises(IndexError):
1342 func(15)
1343
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001344 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001345 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001346 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001347 def square(x):
1348 return x * x
1349 self.assertEqual(square(3), 9)
1350 self.assertEqual(type(square(3)), type(9))
1351 self.assertEqual(square(3.0), 9.0)
1352 self.assertEqual(type(square(3.0)), type(9.0))
1353 self.assertEqual(square(x=3), 9)
1354 self.assertEqual(type(square(x=3)), type(9))
1355 self.assertEqual(square(x=3.0), 9.0)
1356 self.assertEqual(type(square(x=3.0)), type(9.0))
1357 self.assertEqual(square.cache_info().hits, 4)
1358 self.assertEqual(square.cache_info().misses, 4)
1359
Antoine Pitroub5b37142012-11-13 21:35:40 +01001360 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001361 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001362 def fib(n):
1363 if n < 2:
1364 return n
1365 return fib(n=n-1) + fib(n=n-2)
1366 self.assertEqual(
1367 [fib(n=number) for number in range(16)],
1368 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1369 )
1370 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001371 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001372 fib.cache_clear()
1373 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001374 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001375
1376 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001377 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001378 def fib(n):
1379 if n < 2:
1380 return n
1381 return fib(n=n-1) + fib(n=n-2)
1382 self.assertEqual([fib(n=number) for number in range(16)],
1383 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1384 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001385 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001386 fib.cache_clear()
1387 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001388 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1389
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001390 def test_kwargs_order(self):
1391 # PEP 468: Preserving Keyword Argument Order
1392 @self.module.lru_cache(maxsize=10)
1393 def f(**kwargs):
1394 return list(kwargs.items())
1395 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1396 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1397 self.assertEqual(f.cache_info(),
1398 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1399
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001400 def test_lru_cache_decoration(self):
1401 def f(zomg: 'zomg_annotation'):
1402 """f doc string"""
1403 return 42
1404 g = self.module.lru_cache()(f)
1405 for attr in self.module.WRAPPER_ASSIGNMENTS:
1406 self.assertEqual(getattr(g, attr), getattr(f, attr))
1407
1408 @unittest.skipUnless(threading, 'This test requires threading.')
1409 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001410 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001411 def orig(x, y):
1412 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001413 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001414 hits, misses, maxsize, currsize = f.cache_info()
1415 self.assertEqual(currsize, 0)
1416
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001417 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001418 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001419 start.wait(10)
1420 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001421 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001422
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001423 def clear():
1424 start.wait(10)
1425 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001426 f.cache_clear()
1427
1428 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001429 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001430 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001431 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001432 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001433 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001434 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001435 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001436
1437 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001438 if self.module is py_functools:
1439 # XXX: Why can be not equal?
1440 self.assertLessEqual(misses, n)
1441 self.assertLessEqual(hits, m*n - misses)
1442 else:
1443 self.assertEqual(misses, n)
1444 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001445 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001446
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001447 # create n threads in order to fill cache and 1 to clear it
1448 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001449 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001450 for k in range(n)]
1451 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001452 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001453 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001454 finally:
1455 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001456
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001457 @unittest.skipUnless(threading, 'This test requires threading.')
1458 def test_lru_cache_threaded2(self):
1459 # Simultaneous call with the same arguments
1460 n, m = 5, 7
1461 start = threading.Barrier(n+1)
1462 pause = threading.Barrier(n+1)
1463 stop = threading.Barrier(n+1)
1464 @self.module.lru_cache(maxsize=m*n)
1465 def f(x):
1466 pause.wait(10)
1467 return 3 * x
1468 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1469 def test():
1470 for i in range(m):
1471 start.wait(10)
1472 self.assertEqual(f(i), 3 * i)
1473 stop.wait(10)
1474 threads = [threading.Thread(target=test) for k in range(n)]
1475 with support.start_threads(threads):
1476 for i in range(m):
1477 start.wait(10)
1478 stop.reset()
1479 pause.wait(10)
1480 start.reset()
1481 stop.wait(10)
1482 pause.reset()
1483 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1484
Serhiy Storchaka67796522017-01-12 18:34:33 +02001485 @unittest.skipUnless(threading, 'This test requires threading.')
1486 def test_lru_cache_threaded3(self):
1487 @self.module.lru_cache(maxsize=2)
1488 def f(x):
1489 time.sleep(.01)
1490 return 3 * x
1491 def test(i, x):
1492 with self.subTest(thread=i):
1493 self.assertEqual(f(x), 3 * x, i)
1494 threads = [threading.Thread(target=test, args=(i, v))
1495 for i, v in enumerate([1, 2, 2, 3, 2])]
1496 with support.start_threads(threads):
1497 pass
1498
Raymond Hettinger03923422013-03-04 02:52:50 -05001499 def test_need_for_rlock(self):
1500 # This will deadlock on an LRU cache that uses a regular lock
1501
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001502 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001503 def test_func(x):
1504 'Used to demonstrate a reentrant lru_cache call within a single thread'
1505 return x
1506
1507 class DoubleEq:
1508 'Demonstrate a reentrant lru_cache call within a single thread'
1509 def __init__(self, x):
1510 self.x = x
1511 def __hash__(self):
1512 return self.x
1513 def __eq__(self, other):
1514 if self.x == 2:
1515 test_func(DoubleEq(1))
1516 return self.x == other.x
1517
1518 test_func(DoubleEq(1)) # Load the cache
1519 test_func(DoubleEq(2)) # Load the cache
1520 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1521 DoubleEq(2)) # Verify the correct return value
1522
Raymond Hettinger4d588972014-08-12 12:44:52 -07001523 def test_early_detection_of_bad_call(self):
1524 # Issue #22184
1525 with self.assertRaises(TypeError):
1526 @functools.lru_cache
1527 def f():
1528 pass
1529
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001530 def test_lru_method(self):
1531 class X(int):
1532 f_cnt = 0
1533 @self.module.lru_cache(2)
1534 def f(self, x):
1535 self.f_cnt += 1
1536 return x*10+self
1537 a = X(5)
1538 b = X(5)
1539 c = X(7)
1540 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1541
1542 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1543 self.assertEqual(a.f(x), x*10 + 5)
1544 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1545 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1546
1547 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1548 self.assertEqual(b.f(x), x*10 + 5)
1549 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1550 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1551
1552 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1553 self.assertEqual(c.f(x), x*10 + 7)
1554 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1555 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1556
1557 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1558 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1559 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1560
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001561 def test_pickle(self):
1562 cls = self.__class__
1563 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1564 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1565 with self.subTest(proto=proto, func=f):
1566 f_copy = pickle.loads(pickle.dumps(f, proto))
1567 self.assertIs(f_copy, f)
1568
1569 def test_copy(self):
1570 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001571 def orig(x, y):
1572 return 3 * x + y
1573 part = self.module.partial(orig, 2)
1574 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1575 self.module.lru_cache(2)(part))
1576 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001577 with self.subTest(func=f):
1578 f_copy = copy.copy(f)
1579 self.assertIs(f_copy, f)
1580
1581 def test_deepcopy(self):
1582 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001583 def orig(x, y):
1584 return 3 * x + y
1585 part = self.module.partial(orig, 2)
1586 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1587 self.module.lru_cache(2)(part))
1588 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001589 with self.subTest(func=f):
1590 f_copy = copy.deepcopy(f)
1591 self.assertIs(f_copy, f)
1592
1593
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level function cached with the pure-Python lru_cache; used by
    # TestLRUPy's pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1597
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level function cached with the C-accelerated lru_cache; used by
    # TestLRUC's pickle/copy tests.
    tripled = 3 * x
    return tripled + y
1601
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001602
class TestLRUPy(TestLRU, unittest.TestCase):
    # Run the shared TestLRU tests against py_functools, the copy of
    # functools imported with the _functools accelerator blocked.
    module = py_functools
    # NOTE(review): presumably kept in a 1-tuple so the cached function is
    # stored as plain data rather than as a method of this class; the tests
    # read it back as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1615
1616
class TestLRUC(TestLRU, unittest.TestCase):
    # Run the shared TestLRU tests against c_functools, the copy of functools
    # imported with a fresh _functools C accelerator.
    module = c_functools
    # NOTE(review): presumably kept in a 1-tuple so the cached function is
    # stored as plain data rather than as a method of this class; the tests
    # read it back as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001629
Raymond Hettinger03923422013-03-04 02:52:50 -05001630
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and its MRO helper functions."""

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped under -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        # Every loop therefore passes each permutation (`haystack`) to
        # _compose_mro, so that claim is actually exercised.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        # A UserDict that records every key read or written, used to observe
        # singledispatch's dispatch cache.
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        # Schedule the restore before patching, so a failing assertion below
        # cannot leave functools.WeakKeyDictionary replaced for the rest of
        # the test run.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2125
2126
if __name__ == '__main__':
    # Executed directly as a script: run every test case defined above.
    unittest.main()