blob: 75427dfad3420caf6bfc79cc53c0ab4495fbe6c2 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100011import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030012try:
13 import threading
14except ImportError:
15 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000016
Antoine Pitroub5b37142012-11-13 21:35:40 +010017import functools
18
# functools imported with the _functools accelerator blocked, so the
# pure-Python implementation is what gets exercised.
py_functools = support.import_fresh_module('functools',
                                           blocked=('_functools',))
# functools imported with a freshly loaded _functools accelerator; the
# `if c_functools:` / skipUnless guards below treat a falsy value as
# "accelerator unavailable".
c_functools = support.import_fresh_module('functools',
                                          fresh=('_functools',))

# decimal imported with a freshly loaded _decimal accelerator.
decimal = support.import_fresh_module('decimal', fresh=('_decimal',))
23
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The entry currently registered under *name* is read before the swap
    (a missing entry raises KeyError). It is written back when the ``with``
    block exits, whether the block finished normally or raised.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the stand-in.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020032
def capture(*args, **kw):
    """Echo back the call: return ``(args, kw)`` exactly as received."""
    captured = (args, kw)
    return captured
36
Łukasz Langa6f692512013-06-05 12:20:24 +020037
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000041
class MyTuple(tuple):
    """Trivial tuple subclass, used as a partial-state ``args`` entry."""
44
class BadTuple(tuple):
    """tuple subclass whose ``+`` yields a list instead of a tuple.

    test_setstate_subclasses feeds it to ``__setstate__`` and checks that the
    stored args, and the args actually passed on a call, are plain tuples.
    """

    def __add__(self, other):
        return [*self, *other]
48
class MyDict(dict):
    """Trivial dict subclass, used as a partial-state ``keywords`` entry."""
51
Łukasz Langa6f692512013-06-05 12:20:24 +020052
class TestPartial:
    """Implementation-agnostic tests for ``functools.partial``.

    This is a mixin. Concrete subclasses combine it with ``unittest.TestCase``
    and provide two class attributes:

    * ``partial`` -- the partial type (or subclass) under test;
    * ``AllowPickle`` -- a zero-argument callable returning a context manager
      that is entered around the pickling tests.
    """

    def _expected_repr_name(self):
        """Return the prefix that ``repr()`` of ``self.partial`` objects uses.

        The stock C and pure-Python types report ``functools.partial``, while
        subclasses report their own ``__name__``. ``c_functools`` is falsy
        when the C accelerator is unavailable; the pure-Python test classes
        still run then, so it must not be dereferenced unconditionally.
        """
        stock_types = [py_functools.partial]
        if c_functools:
            stock_types.append(c_functools.partial)
        if self.partial in stock_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # Construction and the call are both inside the block, so the
        # callability check may fire at either point.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # separate assertions so a failure reports the actual values
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned, so accept either order
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes f refer to itself; the finally clause resets the
        # state so f no longer refers to itself afterwards.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature check for a None keywords entry is
        # disabled; confirm whether it can be re-enabled.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000367
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    # The class body executes even when the skip applies, hence the guard.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager; the C partial is pickled without setup."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
392 self.fail('partial object allowed __dict__ to be deleted')
Łukasz Langa6f692512013-06-05 12:20:24 +0200393
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200394class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000395 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000396
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000397 class AllowPickle:
398 def __init__(self):
399 self._cm = replaced_module("functools", py_functools)
400 def __enter__(self):
401 return self._cm.__enter__()
402 def __exit__(self, type, value, tb):
403 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200404
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200405if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000406 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200407 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100408
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000409class PyPartialSubclass(py_functools.partial):
410 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200411
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200412@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200413class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200414 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000415 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000416
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300417 # partial subclasses are not optimized for nested calls
418 test_nested_optimization = None
419
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000420class TestPartialPySubclass(TestPartialPy):
421 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200422
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000423class TestPartialMethod(unittest.TestCase):
424
425 class A(object):
426 nothing = functools.partialmethod(capture)
427 positional = functools.partialmethod(capture, 1)
428 keywords = functools.partialmethod(capture, a=2)
429 both = functools.partialmethod(capture, 3, b=4)
430
431 nested = functools.partialmethod(positional, 5)
432
433 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
434
435 static = functools.partialmethod(staticmethod(capture), 8)
436 cls = functools.partialmethod(classmethod(capture), d=9)
437
438 a = A()
439
440 def test_arg_combinations(self):
441 self.assertEqual(self.a.nothing(), ((self.a,), {}))
442 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
443 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
444 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
445
446 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
447 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
448 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
449 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
450
451 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
452 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
453 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
454 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
455
456 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
457 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
458 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
459 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
460
461 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
462
463 def test_nested(self):
464 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
465 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
466 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
467 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
468
469 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
470
471 def test_over_partial(self):
472 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
473 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
474 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
475 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
476
477 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
478
479 def test_bound_method_introspection(self):
480 obj = self.a
481 self.assertIs(obj.both.__self__, obj)
482 self.assertIs(obj.nested.__self__, obj)
483 self.assertIs(obj.over_partial.__self__, obj)
484 self.assertIs(obj.cls.__self__, self.A)
485 self.assertIs(self.A.cls.__self__, self.A)
486
487 def test_unbound_method_retrieval(self):
488 obj = self.A
489 self.assertFalse(hasattr(obj.both, "__self__"))
490 self.assertFalse(hasattr(obj.nested, "__self__"))
491 self.assertFalse(hasattr(obj.over_partial, "__self__"))
492 self.assertFalse(hasattr(obj.static, "__self__"))
493 self.assertFalse(hasattr(self.a.static, "__self__"))
494
495 def test_descriptors(self):
496 for obj in [self.A, self.a]:
497 with self.subTest(obj=obj):
498 self.assertEqual(obj.static(), ((8,), {}))
499 self.assertEqual(obj.static(5), ((8, 5), {}))
500 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
501 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
502
503 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
504 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
505 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
506 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
507
508 def test_overriding_keywords(self):
509 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
510 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
511
512 def test_invalid_args(self):
513 with self.assertRaises(TypeError):
514 class B(object):
515 method = functools.partialmethod(None, 1)
516
517 def test_repr(self):
518 self.assertEqual(repr(vars(self.A)['both']),
519 'functools.partialmethod({}, 3, b=4)'.format(capture))
520
521 def test_abstract(self):
522 class Abstract(abc.ABCMeta):
523
524 @abc.abstractmethod
525 def add(self, x, y):
526 pass
527
528 add5 = functools.partialmethod(add, 5)
529
530 self.assertTrue(Abstract.add.__isabstractmethod__)
531 self.assertTrue(Abstract.add5.__isabstractmethod__)
532
533 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
534 self.assertFalse(getattr(func, '__isabstractmethod__', False))
535
536
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000537class TestUpdateWrapper(unittest.TestCase):
538
539 def check_wrapper(self, wrapper, wrapped,
540 assigned=functools.WRAPPER_ASSIGNMENTS,
541 updated=functools.WRAPPER_UPDATES):
542 # Check attributes were assigned
543 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000544 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000545 # Check attributes were updated
546 for name in updated:
547 wrapper_attr = getattr(wrapper, name)
548 wrapped_attr = getattr(wrapped, name)
549 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000550 if name == "__dict__" and key == "__wrapped__":
551 # __wrapped__ is overwritten by the update code
552 continue
553 self.assertIs(wrapped_attr[key], wrapper_attr[key])
554 # Check __wrapped__
555 self.assertIs(wrapper.__wrapped__, wrapped)
556
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000557
R. David Murray378c0cf2010-02-24 01:46:21 +0000558 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000559 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000560 """This is a test"""
561 pass
562 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000563 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000564 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000565 pass
566 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000567 return wrapper, f
568
569 def test_default_update(self):
570 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000571 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000572 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000573 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600574 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000575 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000576 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
577 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000578
R. David Murray378c0cf2010-02-24 01:46:21 +0000579 @unittest.skipIf(sys.flags.optimize >= 2,
580 "Docstrings are omitted with -O2 and above")
581 def test_default_update_doc(self):
582 wrapper, f = self._default_update()
583 self.assertEqual(wrapper.__doc__, 'This is a test')
584
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000585 def test_no_update(self):
586 def f():
587 """This is a test"""
588 pass
589 f.attr = 'This is also a test'
590 def wrapper():
591 pass
592 functools.update_wrapper(wrapper, f, (), ())
593 self.check_wrapper(wrapper, f, (), ())
594 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600595 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000596 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000597 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000598 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000599
600 def test_selective_update(self):
601 def f():
602 pass
603 f.attr = 'This is a different test'
604 f.dict_attr = dict(a=1, b=2, c=3)
605 def wrapper():
606 pass
607 wrapper.dict_attr = {}
608 assign = ('attr',)
609 update = ('dict_attr',)
610 functools.update_wrapper(wrapper, f, assign, update)
611 self.check_wrapper(wrapper, f, assign, update)
612 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600613 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000614 self.assertEqual(wrapper.__doc__, None)
615 self.assertEqual(wrapper.attr, 'This is a different test')
616 self.assertEqual(wrapper.dict_attr, f.dict_attr)
617
Nick Coghlan98876832010-08-17 06:17:18 +0000618 def test_missing_attributes(self):
619 def f():
620 pass
621 def wrapper():
622 pass
623 wrapper.dict_attr = {}
624 assign = ('attr',)
625 update = ('dict_attr',)
626 # Missing attributes on wrapped object are ignored
627 functools.update_wrapper(wrapper, f, assign, update)
628 self.assertNotIn('attr', wrapper.__dict__)
629 self.assertEqual(wrapper.dict_attr, {})
630 # Wrapper must have expected attributes for updating
631 del wrapper.dict_attr
632 with self.assertRaises(AttributeError):
633 functools.update_wrapper(wrapper, f, assign, update)
634 wrapper.dict_attr = 1
635 with self.assertRaises(AttributeError):
636 functools.update_wrapper(wrapper, f, assign, update)
637
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200638 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000639 @unittest.skipIf(sys.flags.optimize >= 2,
640 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000641 def test_builtin_update(self):
642 # Test for bug #1576241
643 def wrapper():
644 pass
645 functools.update_wrapper(wrapper, max)
646 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000647 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000648 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000649
Łukasz Langa6f692512013-06-05 12:20:24 +0200650
class TestWraps(TestUpdateWrapper):
    # Re-runs the update_wrapper() scenarios through the functools.wraps()
    # decorator form; check_wrapper() is inherited from TestUpdateWrapper.

    def _default_update(self):
        # Return (wrapper, wrapped), where wrapper was decorated with
        # functools.wraps(f) using the default assigned/updated names.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # A stale __wrapped__ in f's __dict__: after wrapping, the wrapper's
        # __wrapped__ is expected to refer to f itself, not to this string.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: no metadata is copied onto wrapper.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is merged.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def with_empty_dict_attr(func):
            # The updated attribute is merged into the wrapper's existing
            # value, so the wrapper needs a dict_attr of its own first.
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
711
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # reduce() tests, run against the C accelerator only.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Squares of 0 .. max-1, computed on demand; __len__ reports how
            # many have been computed so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Ordinary reductions, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong argument counts / non-iterables.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the function is never called, so 42 is fine.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, FailingIter())

        # For an empty sequence the initial value comes back unchanged.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: items 0 .. n-1, then IndexError.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys, in insertion order.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
793
Łukasz Langa6f692512013-06-05 12:20:24 +0200794
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200795class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700796
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000797 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700798 def cmp1(x, y):
799 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100800 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700801 self.assertEqual(key(3), key(3))
802 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100803 self.assertGreaterEqual(key(3), key(3))
804
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700805 def cmp2(x, y):
806 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100807 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700808 self.assertEqual(key(4.0), key('4'))
809 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100810 self.assertLessEqual(key(2), key('35'))
811 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700812
813 def test_cmp_to_key_arguments(self):
814 def cmp1(x, y):
815 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100816 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700817 self.assertEqual(key(obj=3), key(obj=3))
818 self.assertGreater(key(obj=3), key(obj=1))
819 with self.assertRaises((TypeError, AttributeError)):
820 key(3) > 1 # rhs is not a K object
821 with self.assertRaises((TypeError, AttributeError)):
822 1 < key(3) # lhs is not a K object
823 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100824 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700825 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200826 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100827 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700828 with self.assertRaises(TypeError):
829 key() # too few args
830 with self.assertRaises(TypeError):
831 key(None, None) # too many args
832
833 def test_bad_cmp(self):
834 def cmp1(x, y):
835 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100836 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700837 with self.assertRaises(ZeroDivisionError):
838 key(3) > key(1)
839
840 class BadCmp:
841 def __lt__(self, other):
842 raise ZeroDivisionError
843 def cmp1(x, y):
844 return BadCmp()
845 with self.assertRaises(ZeroDivisionError):
846 key(3) > key(1)
847
848 def test_obj_field(self):
849 def cmp1(x, y):
850 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100851 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700852 self.assertEqual(key(50).obj, 50)
853
854 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000855 def mycmp(x, y):
856 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100857 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000858 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000859
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700860 def test_sort_int_str(self):
861 def mycmp(x, y):
862 x, y = int(x), int(y)
863 return (x > y) - (x < y)
864 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 self.assertEqual([int(value) for value in values],
867 [0, 1, 1, 2, 3, 4, 5, 7, 10])
868
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000869 def test_hash(self):
870 def mycmp(x, y):
871 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000873 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700874 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700875 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000876
Łukasz Langa6f692512013-06-05 12:20:24 +0200877
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey mixin against _functools.cmp_to_key.  The
    # attribute is only defined when the C module could be imported; the
    # class is skipped otherwise.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100882
Łukasz Langa6f692512013-06-05 12:20:24 +0200883
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey mixin against the pure-Python cmp_to_key.
    # staticmethod() keeps the plain function from binding to ``self``.
    cmp_to_key = staticmethod(
        py_functools.cmp_to_key
    )
886
Łukasz Langa6f692512013-06-05 12:20:24 +0200887
class TestTotalOrdering(unittest.TestCase):
    # functools.total_ordering: given __eq__ and one rich comparison, the
    # decorator must supply the remaining ordering methods consistently.

    def _check_total_order(self, cls):
        # Shared checks for a class whose instances order like their
        # constructor argument (1 < 2, and 2 equals 2).
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_total_order(A)
        # int already provides all four ordering methods, so the decorator
        # must not have added any of its own to A's namespace.  (The
        # comparisons above would pass even if it had.)
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            with self.subTest(method=name):
                self.assertNotIn(name, vars(A))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # Methods (including the derived ones) must pickle by reference and
        # round-trip to the very same object, for every protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1090
@functools.total_ordering
class Orderable_LT:
    # Module-level class (so its methods pickle by reference) used by
    # TestTotalOrdering.test_pickle.  Only __eq__ and __lt__ are written
    # by hand; total_ordering supplies the other ordering methods.

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1099
1100
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001101class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001102
def test_lru(self):
    # Core lru_cache behaviour: statistics, bounded size, cache_clear(),
    # bypassing via __wrapped__, and the small sizes 0, 1 and 2.
    def orig(x, y):
        return 3 * x + y
    f = self.module.lru_cache(maxsize=20)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(maxsize, 20)
    self.assertEqual(currsize, 0)
    self.assertEqual(hits, 0)
    self.assertEqual(misses, 0)

    # 1000 random calls over 25 distinct argument pairs, room for 20.
    domain = range(5)
    for _ in range(1000):
        x, y = choice(domain), choice(domain)
        actual = f(x, y)
        expected = orig(x, y)
        self.assertEqual(actual, expected)
    hits, misses, maxsize, currsize = f.cache_info()
    # assertGreater reports both values on failure, unlike assertTrue(a > b)
    self.assertGreater(hits, misses)
    self.assertEqual(hits + misses, 1000)
    self.assertEqual(currsize, 20)

    f.cache_clear()   # test clearing
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 0)
    self.assertEqual(misses, 0)
    self.assertEqual(currsize, 0)
    f(x, y)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 0)
    self.assertEqual(misses, 1)
    self.assertEqual(currsize, 1)

    # Test bypassing the cache
    self.assertIs(f.__wrapped__, orig)
    f.__wrapped__(x, y)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 0)
    self.assertEqual(misses, 1)
    self.assertEqual(currsize, 1)

    # test size zero (which means "never-cache")
    @self.module.lru_cache(0)
    def f():
        nonlocal f_cnt
        f_cnt += 1
        return 20
    self.assertEqual(f.cache_info().maxsize, 0)
    f_cnt = 0
    for _ in range(5):
        self.assertEqual(f(), 20)
    self.assertEqual(f_cnt, 5)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 0)
    self.assertEqual(misses, 5)
    self.assertEqual(currsize, 0)

    # test size one
    @self.module.lru_cache(1)
    def f():
        nonlocal f_cnt
        f_cnt += 1
        return 20
    self.assertEqual(f.cache_info().maxsize, 1)
    f_cnt = 0
    for _ in range(5):
        self.assertEqual(f(), 20)
    self.assertEqual(f_cnt, 1)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 4)
    self.assertEqual(misses, 1)
    self.assertEqual(currsize, 1)

    # test size two
    @self.module.lru_cache(2)
    def f(x):
        nonlocal f_cnt
        f_cnt += 1
        return x*10
    self.assertEqual(f.cache_info().maxsize, 2)
    f_cnt = 0
    for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
        #    *  *              *                          *
        self.assertEqual(f(x), x*10)
    self.assertEqual(f_cnt, 4)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(hits, 12)
    self.assertEqual(misses, 4)
    self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001191
def test_lru_type_error(self):
    # Regression test for issue #28653.
    # lru_cache was leaking when one of the arguments
    # wasn't cacheable.
    # Use self.module (not the global functools) so that both the C and
    # the pure-Python variants of this test class exercise their own
    # implementation, via the unbounded and the bounded code paths.

    @self.module.lru_cache(maxsize=None)
    def infinite_cache(o):
        pass

    @self.module.lru_cache(maxsize=10)
    def limited_cache(o):
        pass

    with self.assertRaises(TypeError):
        infinite_cache([])

    with self.assertRaises(TypeError):
        limited_cache([])
1210
def test_lru_with_maxsize_none(self):
    # An unbounded cache memoises every distinct argument of a recursive
    # function; cache_clear() resets both statistics and contents.
    @self.module.lru_cache(maxsize=None)
    def fib(n):
        return n if n < 2 else fib(n-1) + fib(n-2)

    expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
    self.assertEqual(list(map(fib, range(16))), expected)

    info_cls = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=0, misses=0, maxsize=None, currsize=0))
1224
def test_lru_with_maxsize_negative(self):
    # With a negative maxsize no call is ever served from the cache: two
    # passes over 150 arguments yield 300 misses and no hits.
    @self.module.lru_cache(maxsize=-10)
    def eq(n):
        return n

    for _ in range(2):
        self.assertEqual(list(map(eq, range(150))), list(range(150)))

    expected_info = self.module._CacheInfo(
        hits=0, misses=300, maxsize=-10, currsize=1)
    self.assertEqual(eq.cache_info(), expected_info)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001233
def test_lru_with_exceptions(self):
    # An exception from the user function must pass straight through,
    # without an implicit __context__ making the traceback hard to read,
    # and must not leave a cache entry behind.
    # http://bugs.python.org/issue13177
    for maxsize in (None, 128):
        @self.module.lru_cache(maxsize)
        def func(i):
            return 'abc'[i]

        self.assertEqual(func(0), 'a')

        with self.assertRaises(IndexError) as ctx:
            func(15)
        self.assertIsNone(ctx.exception.__context__)

        # Raising again shows that no result was stored for 15.
        with self.assertRaises(IndexError):
            func(15)
1249
def test_lru_with_types(self):
    # typed=True keeps 3 and 3.0 in separate slots, for positional and
    # keyword calls alike: each of the four call forms misses once, then
    # hits on the repeat call.
    for maxsize in (None, 128):
        @self.module.lru_cache(maxsize=maxsize, typed=True)
        def square(x):
            return x * x

        calls = (
            (lambda: square(3), 9),
            (lambda: square(3.0), 9.0),
            (lambda: square(x=3), 9),
            (lambda: square(x=3.0), 9.0),
        )
        for call, expected in calls:
            self.assertEqual(call(), expected)
            self.assertEqual(type(call()), type(expected))

        info = square.cache_info()
        self.assertEqual(info.hits, 4)
        self.assertEqual(info.misses, 4)
1265
def test_lru_with_keyword_args(self):
    # Keyword-only calls are cached too; lru_cache() with no arguments
    # reports maxsize=128 in the statistics checked below.
    @self.module.lru_cache()
    def fib(n):
        return n if n < 2 else fib(n=n-1) + fib(n=n-2)

    results = [fib(n=k) for k in range(16)]
    self.assertEqual(
        results,
        [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])

    info_cls = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=28, misses=16, maxsize=128, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001281
def test_lru_with_keyword_args_maxsize_none(self):
    # Same as test_lru_with_keyword_args, but with an unbounded cache.
    @self.module.lru_cache(maxsize=None)
    def fib(n):
        return n if n < 2 else fib(n=n-1) + fib(n=n-2)

    results = [fib(n=k) for k in range(16)]
    self.assertEqual(
        results,
        [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])

    info_cls = self.module._CacheInfo
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=28, misses=16, maxsize=None, currsize=16))
    fib.cache_clear()
    self.assertEqual(fib.cache_info(),
                     info_cls(hits=0, misses=0, maxsize=None, currsize=0))
1295
def test_lru_cache_decoration(self):
    # The cached wrapper must carry every attribute listed in
    # WRAPPER_ASSIGNMENTS with the same value as the original function.
    def f(zomg: 'zomg_annotation'):
        """f doc string"""
        return 42

    cached = self.module.lru_cache()(f)
    for name in self.module.WRAPPER_ASSIGNMENTS:
        self.assertEqual(getattr(cached, name), getattr(f, name))
1303
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    # Stress the cache from several threads.  Phase 1: n "filler" threads
    # each call f(k, 0) m times for their own k.  Phase 2: the same filler
    # threads run again while one extra thread repeatedly clears the cache.
    n, m = 5, 11
    def orig(x, y):
        return 3 * x + y
    # maxsize n*m exceeds the n distinct keys (k, 0) used below.
    f = self.module.lru_cache(maxsize=n*m)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # Every worker blocks on this event so they all start together.
    start = threading.Event()
    def full(k):
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    # A tiny switch interval presumably maximises thread interleaving so
    # races inside the cache are more likely to surface; it is restored in
    # the finally clause.
    orig_si = sys.getswitchinterval()
    sys.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with support.start_threads(threads):
            start.set()

        # Expected: one miss per key and every other call a hit.
        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: Why can be not equal?
            # (For the pure-Python version only upper bounds are asserted;
            # the reason for the discrepancy has not been established.)
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        # (no statistics are checked afterwards; this phase only has to
        # finish without errors in the worker threads).
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with support.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001352
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
    # Several threads call the cached function with the *same* argument at
    # the same moment.  Three barriers, each sized for all worker threads
    # plus this (main) thread, move everybody through each round together:
    #   start -- workers begin the round's call together;
    #   pause -- held inside the cached function, so every worker is still
    #            computing before any result has been stored;
    #   stop  -- every worker has finished the round.
    workers, rounds = 5, 7
    parties = workers + 1
    start, pause, stop = (threading.Barrier(parties) for _ in range(3))
    capacity = rounds * workers

    @self.module.lru_cache(maxsize=capacity)
    def triple(x):
        pause.wait(10)
        return 3 * x

    self.assertEqual(triple.cache_info(), (0, 0, capacity, 0))

    def worker():
        for round_no in range(rounds):
            start.wait(10)
            self.assertEqual(triple(round_no), 3 * round_no)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(workers)]
    with support.start_threads(threads):
        for round_no in range(rounds):
            # The order of these barrier waits/resets is significant: each
            # barrier is reset only while no thread can be waiting on it.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # Every worker missed in every round so far (no hits), and
            # each round added exactly one cache entry.
            expected = (0, (round_no + 1) * workers, capacity, round_no + 1)
            self.assertEqual(triple.cache_info(), expected)
1380
Raymond Hettinger03923422013-03-04 02:52:50 -05001381 def test_need_for_rlock(self):
1382 # This will deadlock on an LRU cache that uses a regular lock
1383
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001384 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001385 def test_func(x):
1386 'Used to demonstrate a reentrant lru_cache call within a single thread'
1387 return x
1388
1389 class DoubleEq:
1390 'Demonstrate a reentrant lru_cache call within a single thread'
1391 def __init__(self, x):
1392 self.x = x
1393 def __hash__(self):
1394 return self.x
1395 def __eq__(self, other):
1396 if self.x == 2:
1397 test_func(DoubleEq(1))
1398 return self.x == other.x
1399
1400 test_func(DoubleEq(1)) # Load the cache
1401 test_func(DoubleEq(2)) # Load the cache
1402 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1403 DoubleEq(2)) # Verify the correct return value
1404
Raymond Hettinger4d588972014-08-12 12:44:52 -07001405 def test_early_detection_of_bad_call(self):
1406 # Issue #22184
1407 with self.assertRaises(TypeError):
1408 @functools.lru_cache
1409 def f():
1410 pass
1411
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001412 def test_lru_method(self):
1413 class X(int):
1414 f_cnt = 0
1415 @self.module.lru_cache(2)
1416 def f(self, x):
1417 self.f_cnt += 1
1418 return x*10+self
1419 a = X(5)
1420 b = X(5)
1421 c = X(7)
1422 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1423
1424 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1425 self.assertEqual(a.f(x), x*10 + 5)
1426 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1427 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1428
1429 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1430 self.assertEqual(b.f(x), x*10 + 5)
1431 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1432 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1433
1434 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1435 self.assertEqual(c.f(x), x*10 + 7)
1436 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1437 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1438
1439 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1440 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1441 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1442
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001443 def test_pickle(self):
1444 cls = self.__class__
1445 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1446 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1447 with self.subTest(proto=proto, func=f):
1448 f_copy = pickle.loads(pickle.dumps(f, proto))
1449 self.assertIs(f_copy, f)
1450
1451 def test_copy(self):
1452 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001453 def orig(x, y):
1454 return 3 * x + y
1455 part = self.module.partial(orig, 2)
1456 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1457 self.module.lru_cache(2)(part))
1458 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001459 with self.subTest(func=f):
1460 f_copy = copy.copy(f)
1461 self.assertIs(f_copy, f)
1462
1463 def test_deepcopy(self):
1464 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001465 def orig(x, y):
1466 return 3 * x + y
1467 part = self.module.partial(orig, 2)
1468 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1469 self.module.lru_cache(2)(part))
1470 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001471 with self.subTest(func=f):
1472 f_copy = copy.deepcopy(f)
1473 self.assertIs(f_copy, f)
1474
1475
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Module-level function cached with the pure-Python lru_cache.

    Used by TestLRUPy (as ``cached_func``) for the pickle/copy checks.
    """
    scaled = 3 * x
    return scaled + y
1479
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Module-level function cached with the C-accelerated lru_cache.

    Used by TestLRUC (as ``cached_func``) for the pickle/copy checks.
    """
    scaled = 3 * x
    return scaled + y
1483
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001484
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # The module-level cached function is stored in a 1-tuple and read back
    # as ``cached_func[0]``; presumably this keeps it from ever being bound
    # as a method of the test case -- TODO confirm.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        """lru_cache-wrapped method used by the pickle/copy tests."""
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        """lru_cache-wrapped static method used by the pickle/copy tests."""
        return 3 * x + y
1497
1498
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # The module-level cached function is stored in a 1-tuple and read back
    # as ``cached_func[0]``; presumably this keeps it from ever being bound
    # as a method of the test case -- TODO confirm.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        """lru_cache-wrapped method used by the pickle/copy tests."""
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        """lru_cache-wrapped static method used by the pickle/copy tests."""
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001511
Raymond Hettinger03923422013-03-04 02:52:50 -05001512
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and its MRO helper functions.

    Several tests register local classes with the ``collections`` ABCs;
    ABC registrations cannot be undone, so they persist for the process.
    """

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        # D's MRO is D, C, B, A: B is found before A.
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: the object returned by g.dispatch(object) above is not ``g``:
        # @singledispatch binds ``g`` to a wrapper, while the registry holds
        # the original undecorated function.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        if sys.flags.optimize < 2:
            # Docstrings are stripped under -OO.
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering, so
        # every loop below feeds each permutation of ``bases`` to mro().
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            # Previously a fixed list was passed here, so the permutations
            # were never exercised.
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            # Previously ``bases`` itself was passed here, so only one
            # ordering was ever exercised.
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572: a metaclass defining __len__ must not confuse
        # dispatch for instances of its classes.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict

        class TracingDict(UserDict):
            # Dict that records every key read (get_ops) and written
            # (set_ops), so the test can observe the dispatch cache.
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        # Restore the real class even if an assertion below fails.
        # Otherwise functools.WeakKeyDictionary would keep returning this
        # shared tracing dict for the rest of the test run.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2007
2008
if __name__ == '__main__':
    # Run this module's tests when it is executed as a script.
    unittest.main()