blob: 9ea6747188bd462da94cde37ed81d6dd4057693c [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100011import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030012try:
13 import threading
14except ImportError:
15 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000016
Antoine Pitroub5b37142012-11-13 21:35:40 +010017import functools
18
# Independent, freshly imported copies of functools: one with the
# _functools accelerator blocked (pure-Python implementation) and one that
# re-imports _functools (C implementation).  c_functools may be None; the
# C-specific test classes below guard on it.
py_functools = support.import_fresh_module('functools', blocked=('_functools',))
c_functools = support.import_fresh_module('functools', fresh=('_functools',))

# decimal module backed by a freshly imported _decimal accelerator.
decimal = support.import_fresh_module('decimal', fresh=('_decimal',))
23
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previous entry is put back on exit, even when the body raises.
    *name* must already be present in ``sys.modules`` (KeyError otherwise).
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the swap.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020032
def capture(*args, **kw):
    """Return the call's arguments unchanged, as an ``(args, kwargs)`` pair.

    Used as the wrapped callable throughout these tests so the exact
    arguments a partial object forwards can be inspected.
    """
    captured = (args, kw)
    return captured
36
Łukasz Langa6f692512013-06-05 12:20:24 +020037
def signature(part):
    """Return a partial object's state as ``(func, args, keywords, __dict__)``.

    Two partial objects with equal signatures wrap the same callable with the
    same frozen arguments and carry the same instance attributes.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000041
class MyTuple(tuple):
    """Plain tuple subclass; partial must normalise it back to ``tuple``."""
44
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple.

    Lets the tests check that partial does not depend on the stored args
    object's own ``__add__`` when combining stored and call-time arguments.
    """

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
48
class MyDict(dict):
    """Plain dict subclass; partial must normalise it back to ``dict``."""
51
Łukasz Langa6f692512013-06-05 12:20:24 +020052
class TestPartial:
    """Behavioural tests shared by every partial() implementation.

    This is a mixin rather than a TestCase.  Concrete subclasses combine it
    with unittest.TestCase and provide two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the tests
      that pickle partial objects.
    """

    def _expected_repr_name(self):
        """Return the type name expected at the start of repr() output.

        The stdlib implementations (pure Python and, when available, C)
        report themselves as ``functools.partial``; any other type, such as
        a subclass, is expected to use its own ``__name__``.
        """
        # c_functools may be None (the C test classes are skipped then), but
        # the pure-Python classes still run these tests, so never touch
        # c_functools.partial unless the module is actually present.
        stdlib_partials = [py_functools.partial]
        if c_functools:
            stdlib_partials.append(c_functools.partial)
        if self.partial in stdlib_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes f refer to itself, then restores a non-recursive
        # state in ``finally`` so the self-reference does not outlive it.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # The signature check for a None keywords state was left disabled in
        # the original suite; only the call behaviour is verified here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000367
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the C implementation."""

    # The class body executes even when the class is skipped, so guard the
    # attribute access for builds without the C accelerator.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager; the C partial needs no setup to pickle."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Never suppress an exception raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        read_only = (('func', map),
                     ('args', (1, 2)),
                     ('keywords', dict(a=1, b=2)))
        for attr, value in read_only:
            with self.assertRaises(AttributeError):
                setattr(p, attr, value)

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200393
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` refer to py_functools while active.

        pickle locates classes by module and qualified name, so the
        pure-Python partial (which lives in a fresh copy of functools) only
        round-trips while ``functools`` resolves to that copy.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc_value, traceback):
            return self._swap.__exit__(exc_type, exc_value, traceback)
Łukasz Langa6f692512013-06-05 12:20:24 +0200404
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type (only defined when the C
        accelerator is available)."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100408
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200411
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests through a trivial subclass."""

    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls; rebinding the
    # inherited test to a non-callable hides it from the unittest loader.
    test_nested_optimization = None
419
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests through a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200422
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta (via abc.ABC), not subclass it:
        # subclassing ABCMeta would define a new metaclass and never
        # exercise the ABC machinery this test is about.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must also record the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
535
536
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapped*'s metadata was copied onto *wrapper*.

        Every attribute named in *assigned* must be the identical object on
        both functions.  Every mapping named in *updated* must contain all of
        *wrapped*'s entries, except ``__wrapped__`` inside ``__dict__``,
        which the update itself overwrites.  Finally,
        ``wrapper.__wrapped__`` must point back at *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            merged = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                if (attr_name, key) == ("__dict__", "__wrapped__"):
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Wrap an annotated function *f* with update_wrapper() using the
        default settings; return ``(wrapper, f)``."""
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # A bogus pre-existing __wrapped__; the wrapper must still end up
        # pointing at f itself.
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # The wrapped function's annotations replace the wrapper's own.
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples nothing is copied; only the
        # __wrapped__ link (checked by check_wrapper) is set.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the named attribute is assigned and only the named mapping
        # is merged.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: first the
        # attribute is absent, then it exists but is not a mapping.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin must work.
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000649
Łukasz Langa6f692512013-06-05 12:20:24 +0200650
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() scenarios through @functools.wraps."""

    def _default_update(self):
        # Wrap a documented function using wraps()' default arguments and
        # return (wrapper, wrapped).  The pre-set __wrapped__ string is
        # deliberately wrong; presumably check_wrapper() (defined in the base
        # class) verifies that wraps() points __wrapped__ at the real function.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        # Empty assigned/updated tuples: no attribute of f is assigned to or
        # merged into the wrapper.
        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Attributes named in `updated` are merged into an existing
            # attribute of the wrapper, so the wrapper must already have one.
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
711
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C accelerated functools.reduce()."""
    if c_functools:
        # A builtin function is not bound as a method, so self.func is the
        # plain reduce() callable.
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Old-style sequence (no __iter__) that computes squares lazily.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    n = len(self.sofar)
                    self.sofar.append(n * n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.func

        # Ordinary folds, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Missing arguments, non-iterable sequences and non-callable
        # functions are TypeErrors ...
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # ... but with a single item the function is never called.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence without an initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        # Errors raised while creating the iterator propagate unchanged.
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # An empty sequence yields the initial value untouched.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # reduce() must consume inputs through the iterator protocol, including
    # sequences that only implement __getitem__ and plain dicts.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
793
Łukasz Langa6f692512013-06-05 12:20:24 +0200794
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200795class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700796
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000797 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700798 def cmp1(x, y):
799 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100800 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700801 self.assertEqual(key(3), key(3))
802 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100803 self.assertGreaterEqual(key(3), key(3))
804
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700805 def cmp2(x, y):
806 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100807 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700808 self.assertEqual(key(4.0), key('4'))
809 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100810 self.assertLessEqual(key(2), key('35'))
811 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700812
813 def test_cmp_to_key_arguments(self):
814 def cmp1(x, y):
815 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100816 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700817 self.assertEqual(key(obj=3), key(obj=3))
818 self.assertGreater(key(obj=3), key(obj=1))
819 with self.assertRaises((TypeError, AttributeError)):
820 key(3) > 1 # rhs is not a K object
821 with self.assertRaises((TypeError, AttributeError)):
822 1 < key(3) # lhs is not a K object
823 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100824 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700825 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200826 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100827 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700828 with self.assertRaises(TypeError):
829 key() # too few args
830 with self.assertRaises(TypeError):
831 key(None, None) # too many args
832
833 def test_bad_cmp(self):
834 def cmp1(x, y):
835 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100836 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700837 with self.assertRaises(ZeroDivisionError):
838 key(3) > key(1)
839
840 class BadCmp:
841 def __lt__(self, other):
842 raise ZeroDivisionError
843 def cmp1(x, y):
844 return BadCmp()
845 with self.assertRaises(ZeroDivisionError):
846 key(3) > key(1)
847
848 def test_obj_field(self):
849 def cmp1(x, y):
850 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100851 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700852 self.assertEqual(key(50).obj, 50)
853
854 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000855 def mycmp(x, y):
856 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100857 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000858 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000859
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700860 def test_sort_int_str(self):
861 def mycmp(x, y):
862 x, y = int(x), int(y)
863 return (x > y) - (x < y)
864 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100865 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700866 self.assertEqual([int(value) for value in values],
867 [0, 1, 1, 2, 3, 4, 5, 7, 10])
868
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000869 def test_hash(self):
870 def mycmp(x, y):
871 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000873 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700874 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700875 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000876
Łukasz Langa6f692512013-06-05 12:20:24 +0200877
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C implementation."""
    if c_functools:
        # staticmethod mirrors TestCmpToKeyPy; for a builtin it is a no-op,
        # since builtins are never bound as methods.
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100882
Łukasz Langa6f692512013-06-05 12:20:24 +0200883
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python implementation."""

    # A plain Python function stored on the class would be bound as a method
    # and receive ``self``; staticmethod keeps it a free function.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
886
Łukasz Langa6f692512013-06-05 12:20:24 +0200887
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000888class TestTotalOrdering(unittest.TestCase):
889
890 def test_total_ordering_lt(self):
891 @functools.total_ordering
892 class A:
893 def __init__(self, value):
894 self.value = value
895 def __lt__(self, other):
896 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000897 def __eq__(self, other):
898 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000899 self.assertTrue(A(1) < A(2))
900 self.assertTrue(A(2) > A(1))
901 self.assertTrue(A(1) <= A(2))
902 self.assertTrue(A(2) >= A(1))
903 self.assertTrue(A(2) <= A(2))
904 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000905 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000906
907 def test_total_ordering_le(self):
908 @functools.total_ordering
909 class A:
910 def __init__(self, value):
911 self.value = value
912 def __le__(self, other):
913 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000914 def __eq__(self, other):
915 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000916 self.assertTrue(A(1) < A(2))
917 self.assertTrue(A(2) > A(1))
918 self.assertTrue(A(1) <= A(2))
919 self.assertTrue(A(2) >= A(1))
920 self.assertTrue(A(2) <= A(2))
921 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000922 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000923
924 def test_total_ordering_gt(self):
925 @functools.total_ordering
926 class A:
927 def __init__(self, value):
928 self.value = value
929 def __gt__(self, other):
930 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000931 def __eq__(self, other):
932 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000933 self.assertTrue(A(1) < A(2))
934 self.assertTrue(A(2) > A(1))
935 self.assertTrue(A(1) <= A(2))
936 self.assertTrue(A(2) >= A(1))
937 self.assertTrue(A(2) <= A(2))
938 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000939 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000940
941 def test_total_ordering_ge(self):
942 @functools.total_ordering
943 class A:
944 def __init__(self, value):
945 self.value = value
946 def __ge__(self, other):
947 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000948 def __eq__(self, other):
949 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000950 self.assertTrue(A(1) < A(2))
951 self.assertTrue(A(2) > A(1))
952 self.assertTrue(A(1) <= A(2))
953 self.assertTrue(A(2) >= A(1))
954 self.assertTrue(A(2) <= A(2))
955 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000956 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000957
958 def test_total_ordering_no_overwrite(self):
959 # new methods should not overwrite existing
960 @functools.total_ordering
961 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000962 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000963 self.assertTrue(A(1) < A(2))
964 self.assertTrue(A(2) > A(1))
965 self.assertTrue(A(1) <= A(2))
966 self.assertTrue(A(2) >= A(1))
967 self.assertTrue(A(2) <= A(2))
968 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000969
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000970 def test_no_operations_defined(self):
971 with self.assertRaises(ValueError):
972 @functools.total_ordering
973 class A:
974 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000975
Nick Coghlanf05d9812013-10-02 00:02:03 +1000976 def test_type_error_when_not_implemented(self):
977 # bug 10042; ensure stack overflow does not occur
978 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000979 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000980 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000981 def __init__(self, value):
982 self.value = value
983 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000984 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000985 return self.value == other.value
986 return False
987 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000988 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000989 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000990 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000991
Nick Coghlanf05d9812013-10-02 00:02:03 +1000992 @functools.total_ordering
993 class ImplementsGreaterThan:
994 def __init__(self, value):
995 self.value = value
996 def __eq__(self, other):
997 if isinstance(other, ImplementsGreaterThan):
998 return self.value == other.value
999 return False
1000 def __gt__(self, other):
1001 if isinstance(other, ImplementsGreaterThan):
1002 return self.value > other.value
1003 return NotImplemented
1004
1005 @functools.total_ordering
1006 class ImplementsLessThanEqualTo:
1007 def __init__(self, value):
1008 self.value = value
1009 def __eq__(self, other):
1010 if isinstance(other, ImplementsLessThanEqualTo):
1011 return self.value == other.value
1012 return False
1013 def __le__(self, other):
1014 if isinstance(other, ImplementsLessThanEqualTo):
1015 return self.value <= other.value
1016 return NotImplemented
1017
1018 @functools.total_ordering
1019 class ImplementsGreaterThanEqualTo:
1020 def __init__(self, value):
1021 self.value = value
1022 def __eq__(self, other):
1023 if isinstance(other, ImplementsGreaterThanEqualTo):
1024 return self.value == other.value
1025 return False
1026 def __ge__(self, other):
1027 if isinstance(other, ImplementsGreaterThanEqualTo):
1028 return self.value >= other.value
1029 return NotImplemented
1030
1031 @functools.total_ordering
1032 class ComparatorNotImplemented:
1033 def __init__(self, value):
1034 self.value = value
1035 def __eq__(self, other):
1036 if isinstance(other, ComparatorNotImplemented):
1037 return self.value == other.value
1038 return False
1039 def __lt__(self, other):
1040 return NotImplemented
1041
1042 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1043 ImplementsLessThan(-1) < 1
1044
1045 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1046 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1047
1048 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1049 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1050
1051 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1052 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1053
1054 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1055 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1056
1057 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1058 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1059
1060 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1061 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1062
1063 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1064 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1065
1066 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1067 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1068
1069 with self.subTest("GE when equal"):
1070 a = ComparatorNotImplemented(8)
1071 b = ComparatorNotImplemented(8)
1072 self.assertEqual(a, b)
1073 with self.assertRaises(TypeError):
1074 a >= b
1075
1076 with self.subTest("LE when equal"):
1077 a = ComparatorNotImplemented(9)
1078 b = ComparatorNotImplemented(9)
1079 self.assertEqual(a, b)
1080 with self.assertRaises(TypeError):
1081 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001082
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001083 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001084 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001085 for name in '__lt__', '__gt__', '__le__', '__ge__':
1086 with self.subTest(method=name, proto=proto):
1087 method = getattr(Orderable_LT, name)
1088 method_copy = pickle.loads(pickle.dumps(method, proto))
1089 self.assertIs(method_copy, method)
1090
@functools.total_ordering
class Orderable_LT:
    """Minimal total_ordering client built on __lt__ and __eq__.

    Kept at module level so that its methods can be pickled by qualified
    name (exercised by TestTotalOrdering.test_pickle).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1099
1100
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001101class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001102
1103 def test_lru(self):
1104 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001105 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001106 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001107 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001108 self.assertEqual(maxsize, 20)
1109 self.assertEqual(currsize, 0)
1110 self.assertEqual(hits, 0)
1111 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001112
1113 domain = range(5)
1114 for i in range(1000):
1115 x, y = choice(domain), choice(domain)
1116 actual = f(x, y)
1117 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001118 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001119 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001120 self.assertTrue(hits > misses)
1121 self.assertEqual(hits + misses, 1000)
1122 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001123
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001124 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001125 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001126 self.assertEqual(hits, 0)
1127 self.assertEqual(misses, 0)
1128 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001129 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001130 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001131 self.assertEqual(hits, 0)
1132 self.assertEqual(misses, 1)
1133 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001134
Nick Coghlan98876832010-08-17 06:17:18 +00001135 # Test bypassing the cache
1136 self.assertIs(f.__wrapped__, orig)
1137 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001138 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001139 self.assertEqual(hits, 0)
1140 self.assertEqual(misses, 1)
1141 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001142
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001143 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001144 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001145 def f():
1146 nonlocal f_cnt
1147 f_cnt += 1
1148 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001149 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001150 f_cnt = 0
1151 for i in range(5):
1152 self.assertEqual(f(), 20)
1153 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001154 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001155 self.assertEqual(hits, 0)
1156 self.assertEqual(misses, 5)
1157 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001158
1159 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001160 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001161 def f():
1162 nonlocal f_cnt
1163 f_cnt += 1
1164 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001165 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001166 f_cnt = 0
1167 for i in range(5):
1168 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001169 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001170 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001171 self.assertEqual(hits, 4)
1172 self.assertEqual(misses, 1)
1173 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001174
Raymond Hettingerf3098282010-08-15 03:30:45 +00001175 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001176 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001177 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001178 nonlocal f_cnt
1179 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001180 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001181 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001182 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001183 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1184 # * * * *
1185 self.assertEqual(f(x), x*10)
1186 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001187 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001188 self.assertEqual(hits, 12)
1189 self.assertEqual(misses, 4)
1190 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001191
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001192 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001193 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001194 def fib(n):
1195 if n < 2:
1196 return n
1197 return fib(n-1) + fib(n-2)
1198 self.assertEqual([fib(n) for n in range(16)],
1199 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1200 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001201 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001202 fib.cache_clear()
1203 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001204 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1205
1206 def test_lru_with_maxsize_negative(self):
1207 @self.module.lru_cache(maxsize=-10)
1208 def eq(n):
1209 return n
1210 for i in (0, 1):
1211 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1212 self.assertEqual(eq.cache_info(),
1213 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001214
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001215 def test_lru_with_exceptions(self):
1216 # Verify that user_function exceptions get passed through without
1217 # creating a hard-to-read chained exception.
1218 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001219 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001220 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001221 def func(i):
1222 return 'abc'[i]
1223 self.assertEqual(func(0), 'a')
1224 with self.assertRaises(IndexError) as cm:
1225 func(15)
1226 self.assertIsNone(cm.exception.__context__)
1227 # Verify that the previous exception did not result in a cached entry
1228 with self.assertRaises(IndexError):
1229 func(15)
1230
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001231 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001232 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001233 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001234 def square(x):
1235 return x * x
1236 self.assertEqual(square(3), 9)
1237 self.assertEqual(type(square(3)), type(9))
1238 self.assertEqual(square(3.0), 9.0)
1239 self.assertEqual(type(square(3.0)), type(9.0))
1240 self.assertEqual(square(x=3), 9)
1241 self.assertEqual(type(square(x=3)), type(9))
1242 self.assertEqual(square(x=3.0), 9.0)
1243 self.assertEqual(type(square(x=3.0)), type(9.0))
1244 self.assertEqual(square.cache_info().hits, 4)
1245 self.assertEqual(square.cache_info().misses, 4)
1246
Antoine Pitroub5b37142012-11-13 21:35:40 +01001247 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001248 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001249 def fib(n):
1250 if n < 2:
1251 return n
1252 return fib(n=n-1) + fib(n=n-2)
1253 self.assertEqual(
1254 [fib(n=number) for number in range(16)],
1255 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1256 )
1257 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001258 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001259 fib.cache_clear()
1260 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001261 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001262
1263 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001264 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001265 def fib(n):
1266 if n < 2:
1267 return n
1268 return fib(n=n-1) + fib(n=n-2)
1269 self.assertEqual([fib(n=number) for number in range(16)],
1270 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1271 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001272 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001273 fib.cache_clear()
1274 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001275 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1276
1277 def test_lru_cache_decoration(self):
1278 def f(zomg: 'zomg_annotation'):
1279 """f doc string"""
1280 return 42
1281 g = self.module.lru_cache()(f)
1282 for attr in self.module.WRAPPER_ASSIGNMENTS:
1283 self.assertEqual(getattr(g, attr), getattr(f, attr))
1284
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    n, m = 5, 11

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n*m)(orig)
    self.assertEqual(f.cache_info().currsize, 0)

    start = threading.Event()

    def full(k):
        # Once released, look up this worker's own key m times.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    saved_interval = sys.getswitchinterval()
    # Switch threads as often as possible to provoke races inside the cache.
    sys.setswitchinterval(1e-6)
    try:
        # Phase 1: n threads fill the cache concurrently.
        workers = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with support.start_threads(workers):
            start.set()

        hits, misses, _, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: the pure-Python statistics are only bounded here; why
            # they can differ from the exact counts is not understood.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: the same n fillers race against one thread that keeps
        # clearing the cache.
        workers = [threading.Thread(target=clear)]
        workers += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with support.start_threads(workers):
            start.set()
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001333
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
    # n threads call f with the same argument at the same moment; according
    # to the statistics checked below, every one of those overlapping calls
    # counts as a miss and only one entry is stored per argument.
    n, m = 5, 7
    start, pause, stop = (threading.Barrier(n+1) for _ in range(3))

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        # Hold every caller inside f until all of them (and the main
        # thread) arrive, so their cache lookups overlap.
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def worker():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)

    threads = [threading.Thread(target=worker) for _ in range(n)]
    with support.start_threads(threads):
        for i in range(m):
            # Barrier choreography: release the round, let f return, then
            # re-arm each barrier for the next round.
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1361
Raymond Hettinger03923422013-03-04 02:52:50 -05001362 def test_need_for_rlock(self):
1363 # This will deadlock on an LRU cache that uses a regular lock
1364
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001365 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001366 def test_func(x):
1367 'Used to demonstrate a reentrant lru_cache call within a single thread'
1368 return x
1369
1370 class DoubleEq:
1371 'Demonstrate a reentrant lru_cache call within a single thread'
1372 def __init__(self, x):
1373 self.x = x
1374 def __hash__(self):
1375 return self.x
1376 def __eq__(self, other):
1377 if self.x == 2:
1378 test_func(DoubleEq(1))
1379 return self.x == other.x
1380
1381 test_func(DoubleEq(1)) # Load the cache
1382 test_func(DoubleEq(2)) # Load the cache
1383 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1384 DoubleEq(2)) # Verify the correct return value
1385
Raymond Hettinger4d588972014-08-12 12:44:52 -07001386 def test_early_detection_of_bad_call(self):
1387 # Issue #22184
1388 with self.assertRaises(TypeError):
1389 @functools.lru_cache
1390 def f():
1391 pass
1392
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001393 def test_lru_method(self):
1394 class X(int):
1395 f_cnt = 0
1396 @self.module.lru_cache(2)
1397 def f(self, x):
1398 self.f_cnt += 1
1399 return x*10+self
1400 a = X(5)
1401 b = X(5)
1402 c = X(7)
1403 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1404
1405 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1406 self.assertEqual(a.f(x), x*10 + 5)
1407 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1408 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1409
1410 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1411 self.assertEqual(b.f(x), x*10 + 5)
1412 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1413 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1414
1415 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1416 self.assertEqual(c.f(x), x*10 + 7)
1417 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1418 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1419
1420 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1421 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1422 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1423
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001424 def test_pickle(self):
1425 cls = self.__class__
1426 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1427 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1428 with self.subTest(proto=proto, func=f):
1429 f_copy = pickle.loads(pickle.dumps(f, proto))
1430 self.assertIs(f_copy, f)
1431
1432 def test_copy(self):
1433 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001434 def orig(x, y):
1435 return 3 * x + y
1436 part = self.module.partial(orig, 2)
1437 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1438 self.module.lru_cache(2)(part))
1439 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001440 with self.subTest(func=f):
1441 f_copy = copy.copy(f)
1442 self.assertIs(f_copy, f)
1443
1444 def test_deepcopy(self):
1445 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001446 def orig(x, y):
1447 return 3 * x + y
1448 part = self.module.partial(orig, 2)
1449 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1450 self.module.lru_cache(2)(part))
1451 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001452 with self.subTest(func=f):
1453 f_copy = copy.deepcopy(f)
1454 self.assertIs(f_copy, f)
1455
1456
# Module-level function cached with the pure-Python lru_cache; TestLRUPy
# exposes it as cached_func for the pickle/copy tests.
@py_functools.lru_cache()
def py_cached_func(x, y):
    result = 3 * x + y
    return result
1460
# Module-level function cached with the C-accelerated lru_cache; TestLRUC
# exposes it as cached_func for the pickle/copy tests.
@c_functools.lru_cache()
def c_cached_func(x, y):
    result = 3 * x + y
    return result
1464
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001465
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU cases against the pure-Python functools."""

    module = py_functools
    # The tests read this as ``cached_func[0]``, so it is kept in a 1-tuple.
    cached_func = (py_cached_func,)

    # lru_cache applied to a plain method (used by test_pickle/test_copy).
    @py_functools.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # lru_cache beneath staticmethod: the cache wraps the raw function.
    @staticmethod
    @py_functools.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1478
1479
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU cases against the C-accelerated functools."""

    module = c_functools
    # The tests read this as ``cached_func[0]``, so it is kept in a 1-tuple.
    cached_func = (c_cached_func,)

    # lru_cache applied to a plain method (used by test_pickle/test_copy).
    @c_functools.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    # lru_cache beneath staticmethod: the cache wraps the raw function.
    @staticmethod
    @c_functools.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001492
Raymond Hettinger03923422013-03-04 02:52:50 -05001493
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and the private MRO helpers it
    relies on (_compose_mro, _c3_mro, _find_impl)."""

    def test_simple_overloads(self):
        # The int implementation is used only for ints; every other
        # argument type falls back to the base function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch walks the argument type's MRO. D(C, B) linearizes to
        # D, C, B, A, object, so B's implementation beats A's for D, while
        # C (which only inherits from A) gets A's implementation.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) can be used as a decorator; dispatch(int) then
        # reports exactly the decorated implementation.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The wrapper carries the wrapped function's __name__ and __doc__.
        # Docstrings are stripped under -OO (sys.flags.optimize >= 2), so
        # __doc__ is only checked otherwise.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes from the C _decimal module; a later
        # registration for the Subnormal subclass overrides only Subnormal.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering,
        # so every permutation of the haystack must give the same MRO.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            # Pass the permutation itself so the loop really checks that
            # the result is independent of haystack order.
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Registering implementations for progressively more specific ABCs
        # and concrete types: each call must pick the most specific match.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping") # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro inserts the ABCs a class implements (explicitly registered
        # or inferred from special methods) next to the class that
        # introduces them; unrelated ABCs are left out.
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # MetaA gives the *class object* A a __len__; dispatch for an
        # instance of the subclass AA must still pick A's implementation.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # ABCs that appear explicitly in __mro__ win over merely registered
        # ones; two equally-ranked registered/inferred ABCs with separate
        # implementations make dispatch raise RuntimeError.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Trace every get/set on the dispatch cache to check when
        # singledispatch fills it, reads it, and clears it (on register()
        # and on ABC registration changes).
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        # singledispatch builds its cache through functools.WeakKeyDictionary;
        # swap in a factory returning td. The try/finally guarantees the real
        # class is restored even when an assertion below fails, so the patch
        # cannot leak into other tests.
        functools.WeakKeyDictionary = lambda: td
        try:
            c = collections
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
        finally:
            functools.WeakKeyDictionary = _orig_wkd
1988
1989
if __name__ == '__main__':
    # Executed directly (not imported by a test runner): run every test here.
    unittest.main()