blob: 824549b80342ed8f72483208a4baa892c8b6c715 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
12from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100013import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030014try:
15 import threading
16except ImportError:
17 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Two independent copies of functools: the first with the _functools C
# accelerator blocked (pure Python), the second with _functools freshly
# imported.  c_functools is falsy when the accelerator cannot be imported,
# which the C-specific test classes check before using it.
py_functools, c_functools = (
    support.import_fresh_module('functools', blocked=['_functools']),
    support.import_fresh_module('functools', fresh=['_functools']),
)

# decimal backed by the C _decimal module (presumably used by tests further
# down this file -- not visible here).
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered on entry is put back when the ``with``
    block exits, whether it exits normally or through an exception.
    Raises KeyError if *name* is not present in ``sys.modules`` on entry.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo a call back as an ``(args, kwargs)`` pair.

    Positional arguments are returned as a tuple, keyword arguments as a dict.
    """
    result = (args, kw)
    return result
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass with no overridden behaviour."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
50
class MyDict(dict):
    """Plain dict subclass with no overridden behaviour."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase: concrete subclasses also inherit from
    ``unittest.TestCase`` and supply two class attributes -- ``partial``,
    the callable under test, and ``AllowPickle``, a context-manager class
    that is entered around the pickling tests.
    """

    def _expected_repr_name(self):
        """Return the class name that ``repr()`` of ``self.partial`` shows.

        The stock implementations report themselves as ``functools.partial``;
        anything else (e.g. a subclass) is reported by its own ``__name__``.
        ``c_functools`` is falsy when the C accelerator is unavailable, so it
        must not be dereferenced unconditionally.
        """
        stock_partials = [py_functools.partial]
        if c_functools:
            stock_partials.append(c_functools.partial)
        if self.partial in stock_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly; each half
        # of the result is compared separately so failures show the values
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            # break the reference cycle
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000369
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    # The class body executes even when the skip applies, so only touch
    # c_functools when the accelerator was importable.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: no setup is done here around pickling."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            # Never suppress exceptions raised inside the block.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        replacements = (
            ('func', map),
            ('args', (1, 2)),
            ('keywords', dict(a=1, b=2)),
        )
        for attr_name, new_value in replacements:
            self.assertRaises(AttributeError, setattr, p, attr_name, new_value)

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200395
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager mapping ``sys.modules["functools"]`` to py_functools.

        Presumably needed so that pickle's lookup of ``functools.partial``
        finds the pure-Python class under test rather than the C one.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200406
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type (no overrides)."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100410
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial class (no overrides)."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200413
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a subclass of the C partial type."""

    if c_functools:
        partial = CPartialSubclass

    # Partial subclasses are not optimized for nested calls, so shadow the
    # inherited test with a non-callable to disable it here.
    test_nested_optimization = None
421
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a subclass of it."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200424
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an abstract *class* (built by the ABCMeta
        # metaclass), not a subclass of ABCMeta itself; otherwise the ABC
        # machinery never inspects the partialmethod at all.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must recognise the partialmethod as abstract too, which
        # makes the class non-instantiable.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
537
538
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata of *wrapped*.

        Every attribute named in *assigned* must be the very same object on
        both functions.  Every mapping attribute named in *updated* must
        hold each entry of the wrapped mapping (except ``__wrapped__`` in
        ``__dict__``).  ``wrapper.__wrapped__`` must point back at *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # __wrapped__ is overwritten by the update code itself.
                if attr == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Apply update_wrapper with default settings.

        Returns a ``(wrapper, wrapped)`` pair.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # A stale __wrapped__ on f must not leak onto the wrapper.
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _wrapped = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        # Nothing but __wrapped__ was transferred.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned_names = ('attr',)
        updated_names = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned_names, updated_names)
        self.check_wrapper(wrapper, f, assigned_names, updated_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        do_update = functools.partial(functools.update_wrapper,
                                      wrapper, f, ('attr',), ('dict_attr',))
        # Missing attributes on wrapped object are ignored
        do_update()
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        self.assertRaises(AttributeError, do_update)
        wrapper.dict_attr = 1
        self.assertRaises(AttributeError, do_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000651
Łukasz Langa6f692512013-06-05 12:20:24 +0200652
class TestWraps(TestUpdateWrapper):
    # Exercises functools.wraps(), the decorator form of update_wrapper().

    def _default_update(self):
        # Return (wrapper, f), where wrapper was decorated with the default
        # functools.wraps(f).  f carries a docstring, an extra attribute and
        # a bogus __wrapped__ value; presumably check_wrapper() verifies that
        # wraps() still points wrapper.__wrapped__ at f.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # Name, qualified name and extra attributes all come from f.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # With empty assigned/updated tuples nothing is copied from f.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly listed attributes are assigned or updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the wrapper an empty dict for wraps() to update().
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
713
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # Tests for the C implementation of reduce(); the class is skipped
    # when the _functools accelerator is unavailable.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # The first `max` squares, computed on demand and iterated via
            # the legacy __getitem__ protocol (there is no __iter__).
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Ordinary reductions, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong argument counts and non-iterable sequences.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the function is never called, so a
        # non-callable "function" goes unnoticed.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence without an initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # An empty sequence with an initial value returns that value.
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    def test_iterator_usage(self):
        # reduce() must drive objects that only implement __getitem__
        # and signal exhaustion with IndexError.
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
795
Łukasz Langa6f692512013-06-05 12:20:24 +0200796
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200797class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700798
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000799 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700800 def cmp1(x, y):
801 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100802 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700803 self.assertEqual(key(3), key(3))
804 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100805 self.assertGreaterEqual(key(3), key(3))
806
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700807 def cmp2(x, y):
808 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100809 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700810 self.assertEqual(key(4.0), key('4'))
811 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100812 self.assertLessEqual(key(2), key('35'))
813 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700814
815 def test_cmp_to_key_arguments(self):
816 def cmp1(x, y):
817 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100818 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700819 self.assertEqual(key(obj=3), key(obj=3))
820 self.assertGreater(key(obj=3), key(obj=1))
821 with self.assertRaises((TypeError, AttributeError)):
822 key(3) > 1 # rhs is not a K object
823 with self.assertRaises((TypeError, AttributeError)):
824 1 < key(3) # lhs is not a K object
825 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100826 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700827 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200828 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100829 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700830 with self.assertRaises(TypeError):
831 key() # too few args
832 with self.assertRaises(TypeError):
833 key(None, None) # too many args
834
835 def test_bad_cmp(self):
836 def cmp1(x, y):
837 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100838 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700839 with self.assertRaises(ZeroDivisionError):
840 key(3) > key(1)
841
842 class BadCmp:
843 def __lt__(self, other):
844 raise ZeroDivisionError
845 def cmp1(x, y):
846 return BadCmp()
847 with self.assertRaises(ZeroDivisionError):
848 key(3) > key(1)
849
850 def test_obj_field(self):
851 def cmp1(x, y):
852 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100853 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700854 self.assertEqual(key(50).obj, 50)
855
856 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000857 def mycmp(x, y):
858 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100859 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000860 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000861
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700862 def test_sort_int_str(self):
863 def mycmp(x, y):
864 x, y = int(x), int(y)
865 return (x > y) - (x < y)
866 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100867 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700868 self.assertEqual([int(value) for value in values],
869 [0, 1, 1, 2, 3, 4, 5, 7, 10])
870
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000871 def test_hash(self):
872 def mycmp(x, y):
873 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100874 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000875 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700876 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700877 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000878
Łukasz Langa6f692512013-06-05 12:20:24 +0200879
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey mixin against the C accelerator.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100884
Łukasz Langa6f692512013-06-05 12:20:24 +0200885
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the TestCmpToKey mixin against the pure-Python implementation.
    # staticmethod() stops the plain Python function from binding to the
    # test instance when accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
888
Łukasz Langa6f692512013-06-05 12:20:24 +0200889
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000890class TestTotalOrdering(unittest.TestCase):
891
892 def test_total_ordering_lt(self):
893 @functools.total_ordering
894 class A:
895 def __init__(self, value):
896 self.value = value
897 def __lt__(self, other):
898 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000899 def __eq__(self, other):
900 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000901 self.assertTrue(A(1) < A(2))
902 self.assertTrue(A(2) > A(1))
903 self.assertTrue(A(1) <= A(2))
904 self.assertTrue(A(2) >= A(1))
905 self.assertTrue(A(2) <= A(2))
906 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000907 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000908
909 def test_total_ordering_le(self):
910 @functools.total_ordering
911 class A:
912 def __init__(self, value):
913 self.value = value
914 def __le__(self, other):
915 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000916 def __eq__(self, other):
917 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000918 self.assertTrue(A(1) < A(2))
919 self.assertTrue(A(2) > A(1))
920 self.assertTrue(A(1) <= A(2))
921 self.assertTrue(A(2) >= A(1))
922 self.assertTrue(A(2) <= A(2))
923 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000924 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000925
926 def test_total_ordering_gt(self):
927 @functools.total_ordering
928 class A:
929 def __init__(self, value):
930 self.value = value
931 def __gt__(self, other):
932 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000933 def __eq__(self, other):
934 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000935 self.assertTrue(A(1) < A(2))
936 self.assertTrue(A(2) > A(1))
937 self.assertTrue(A(1) <= A(2))
938 self.assertTrue(A(2) >= A(1))
939 self.assertTrue(A(2) <= A(2))
940 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000941 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000942
943 def test_total_ordering_ge(self):
944 @functools.total_ordering
945 class A:
946 def __init__(self, value):
947 self.value = value
948 def __ge__(self, other):
949 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000950 def __eq__(self, other):
951 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000952 self.assertTrue(A(1) < A(2))
953 self.assertTrue(A(2) > A(1))
954 self.assertTrue(A(1) <= A(2))
955 self.assertTrue(A(2) >= A(1))
956 self.assertTrue(A(2) <= A(2))
957 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000958 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000959
960 def test_total_ordering_no_overwrite(self):
961 # new methods should not overwrite existing
962 @functools.total_ordering
963 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000964 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000965 self.assertTrue(A(1) < A(2))
966 self.assertTrue(A(2) > A(1))
967 self.assertTrue(A(1) <= A(2))
968 self.assertTrue(A(2) >= A(1))
969 self.assertTrue(A(2) <= A(2))
970 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000971
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000972 def test_no_operations_defined(self):
973 with self.assertRaises(ValueError):
974 @functools.total_ordering
975 class A:
976 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000977
Nick Coghlanf05d9812013-10-02 00:02:03 +1000978 def test_type_error_when_not_implemented(self):
979 # bug 10042; ensure stack overflow does not occur
980 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000981 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000982 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000983 def __init__(self, value):
984 self.value = value
985 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000986 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000987 return self.value == other.value
988 return False
989 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000990 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000991 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000992 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000993
Nick Coghlanf05d9812013-10-02 00:02:03 +1000994 @functools.total_ordering
995 class ImplementsGreaterThan:
996 def __init__(self, value):
997 self.value = value
998 def __eq__(self, other):
999 if isinstance(other, ImplementsGreaterThan):
1000 return self.value == other.value
1001 return False
1002 def __gt__(self, other):
1003 if isinstance(other, ImplementsGreaterThan):
1004 return self.value > other.value
1005 return NotImplemented
1006
1007 @functools.total_ordering
1008 class ImplementsLessThanEqualTo:
1009 def __init__(self, value):
1010 self.value = value
1011 def __eq__(self, other):
1012 if isinstance(other, ImplementsLessThanEqualTo):
1013 return self.value == other.value
1014 return False
1015 def __le__(self, other):
1016 if isinstance(other, ImplementsLessThanEqualTo):
1017 return self.value <= other.value
1018 return NotImplemented
1019
1020 @functools.total_ordering
1021 class ImplementsGreaterThanEqualTo:
1022 def __init__(self, value):
1023 self.value = value
1024 def __eq__(self, other):
1025 if isinstance(other, ImplementsGreaterThanEqualTo):
1026 return self.value == other.value
1027 return False
1028 def __ge__(self, other):
1029 if isinstance(other, ImplementsGreaterThanEqualTo):
1030 return self.value >= other.value
1031 return NotImplemented
1032
1033 @functools.total_ordering
1034 class ComparatorNotImplemented:
1035 def __init__(self, value):
1036 self.value = value
1037 def __eq__(self, other):
1038 if isinstance(other, ComparatorNotImplemented):
1039 return self.value == other.value
1040 return False
1041 def __lt__(self, other):
1042 return NotImplemented
1043
1044 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1045 ImplementsLessThan(-1) < 1
1046
1047 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1048 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1049
1050 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1051 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1052
1053 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1054 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1055
1056 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1057 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1058
1059 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1060 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1061
1062 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1063 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1064
1065 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1066 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1067
1068 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1069 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1070
1071 with self.subTest("GE when equal"):
1072 a = ComparatorNotImplemented(8)
1073 b = ComparatorNotImplemented(8)
1074 self.assertEqual(a, b)
1075 with self.assertRaises(TypeError):
1076 a >= b
1077
1078 with self.subTest("LE when equal"):
1079 a = ComparatorNotImplemented(9)
1080 b = ComparatorNotImplemented(9)
1081 self.assertEqual(a, b)
1082 with self.assertRaises(TypeError):
1083 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001084
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001085 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001086 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001087 for name in '__lt__', '__gt__', '__le__', '__ge__':
1088 with self.subTest(method=name, proto=proto):
1089 method = getattr(Orderable_LT, name)
1090 method_copy = pickle.loads(pickle.dumps(method, proto))
1091 self.assertIs(method_copy, method)
1092
@functools.total_ordering
class Orderable_LT:
    """Module-level class defining only __eq__ and __lt__.

    total_ordering supplies the remaining comparisons.  The class lives at
    module level so its methods can be pickled by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1101
1102
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001103class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001104
1105 def test_lru(self):
1106 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001107 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001108 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001109 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001110 self.assertEqual(maxsize, 20)
1111 self.assertEqual(currsize, 0)
1112 self.assertEqual(hits, 0)
1113 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001114
1115 domain = range(5)
1116 for i in range(1000):
1117 x, y = choice(domain), choice(domain)
1118 actual = f(x, y)
1119 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001120 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001121 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001122 self.assertTrue(hits > misses)
1123 self.assertEqual(hits + misses, 1000)
1124 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001125
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001126 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001127 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001128 self.assertEqual(hits, 0)
1129 self.assertEqual(misses, 0)
1130 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001131 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001132 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001133 self.assertEqual(hits, 0)
1134 self.assertEqual(misses, 1)
1135 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001136
Nick Coghlan98876832010-08-17 06:17:18 +00001137 # Test bypassing the cache
1138 self.assertIs(f.__wrapped__, orig)
1139 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001140 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001141 self.assertEqual(hits, 0)
1142 self.assertEqual(misses, 1)
1143 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001144
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001145 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001146 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001147 def f():
1148 nonlocal f_cnt
1149 f_cnt += 1
1150 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001151 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001152 f_cnt = 0
1153 for i in range(5):
1154 self.assertEqual(f(), 20)
1155 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001156 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001157 self.assertEqual(hits, 0)
1158 self.assertEqual(misses, 5)
1159 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001160
1161 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001162 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001163 def f():
1164 nonlocal f_cnt
1165 f_cnt += 1
1166 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001167 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001168 f_cnt = 0
1169 for i in range(5):
1170 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001171 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001172 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001173 self.assertEqual(hits, 4)
1174 self.assertEqual(misses, 1)
1175 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001176
Raymond Hettingerf3098282010-08-15 03:30:45 +00001177 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001178 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001179 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001180 nonlocal f_cnt
1181 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001182 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001183 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001184 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001185 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1186 # * * * *
1187 self.assertEqual(f(x), x*10)
1188 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001189 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001190 self.assertEqual(hits, 12)
1191 self.assertEqual(misses, 4)
1192 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001193
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001194 def test_lru_reentrancy_with_len(self):
1195 # Test to make sure the LRU cache code isn't thrown-off by
1196 # caching the built-in len() function. Since len() can be
1197 # cached, we shouldn't use it inside the lru code itself.
1198 old_len = builtins.len
1199 try:
1200 builtins.len = self.module.lru_cache(4)(len)
1201 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1202 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1203 finally:
1204 builtins.len = old_len
1205
Yury Selivanov46a02db2016-11-09 18:55:45 -05001206 def test_lru_type_error(self):
1207 # Regression test for issue #28653.
1208 # lru_cache was leaking when one of the arguments
1209 # wasn't cacheable.
1210
1211 @functools.lru_cache(maxsize=None)
1212 def infinite_cache(o):
1213 pass
1214
1215 @functools.lru_cache(maxsize=10)
1216 def limited_cache(o):
1217 pass
1218
1219 with self.assertRaises(TypeError):
1220 infinite_cache([])
1221
1222 with self.assertRaises(TypeError):
1223 limited_cache([])
1224
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001225 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001226 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001227 def fib(n):
1228 if n < 2:
1229 return n
1230 return fib(n-1) + fib(n-2)
1231 self.assertEqual([fib(n) for n in range(16)],
1232 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1233 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001234 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001235 fib.cache_clear()
1236 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001237 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1238
1239 def test_lru_with_maxsize_negative(self):
1240 @self.module.lru_cache(maxsize=-10)
1241 def eq(n):
1242 return n
1243 for i in (0, 1):
1244 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1245 self.assertEqual(eq.cache_info(),
1246 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001247
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001248 def test_lru_with_exceptions(self):
1249 # Verify that user_function exceptions get passed through without
1250 # creating a hard-to-read chained exception.
1251 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001252 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001253 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001254 def func(i):
1255 return 'abc'[i]
1256 self.assertEqual(func(0), 'a')
1257 with self.assertRaises(IndexError) as cm:
1258 func(15)
1259 self.assertIsNone(cm.exception.__context__)
1260 # Verify that the previous exception did not result in a cached entry
1261 with self.assertRaises(IndexError):
1262 func(15)
1263
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001264 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001265 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001266 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001267 def square(x):
1268 return x * x
1269 self.assertEqual(square(3), 9)
1270 self.assertEqual(type(square(3)), type(9))
1271 self.assertEqual(square(3.0), 9.0)
1272 self.assertEqual(type(square(3.0)), type(9.0))
1273 self.assertEqual(square(x=3), 9)
1274 self.assertEqual(type(square(x=3)), type(9))
1275 self.assertEqual(square(x=3.0), 9.0)
1276 self.assertEqual(type(square(x=3.0)), type(9.0))
1277 self.assertEqual(square.cache_info().hits, 4)
1278 self.assertEqual(square.cache_info().misses, 4)
1279
Antoine Pitroub5b37142012-11-13 21:35:40 +01001280 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001281 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001282 def fib(n):
1283 if n < 2:
1284 return n
1285 return fib(n=n-1) + fib(n=n-2)
1286 self.assertEqual(
1287 [fib(n=number) for number in range(16)],
1288 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1289 )
1290 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001291 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001292 fib.cache_clear()
1293 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001294 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001295
1296 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001297 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001298 def fib(n):
1299 if n < 2:
1300 return n
1301 return fib(n=n-1) + fib(n=n-2)
1302 self.assertEqual([fib(n=number) for number in range(16)],
1303 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1304 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001305 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001306 fib.cache_clear()
1307 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001308 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1309
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001310 def test_kwargs_order(self):
1311 # PEP 468: Preserving Keyword Argument Order
1312 @self.module.lru_cache(maxsize=10)
1313 def f(**kwargs):
1314 return list(kwargs.items())
1315 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1316 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1317 self.assertEqual(f.cache_info(),
1318 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1319
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001320 def test_lru_cache_decoration(self):
1321 def f(zomg: 'zomg_annotation'):
1322 """f doc string"""
1323 return 42
1324 g = self.module.lru_cache()(f)
1325 for attr in self.module.WRAPPER_ASSIGNMENTS:
1326 self.assertEqual(getattr(g, attr), getattr(f, attr))
1327
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    # Hammer one cache from several threads: first lookups only, then
    # lookups racing against cache_clear().
    n, m = 5, 11  # n threads (one distinct key each), m calls per thread

    def orig(x, y):
        return 3 * x + y

    f = self.module.lru_cache(maxsize=n * m)(orig)
    self.assertEqual(f.cache_info().currsize, 0)

    start = threading.Event()

    def full(k):
        # Wait for the common start signal, then look up (k, 0) m times.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2 * m):
            f.cache_clear()

    def run_together(threads):
        # support.start_threads() presumably starts the threads on entry
        # and joins them on exit; setting the event releases them together.
        with support.start_threads(threads):
            start.set()

    saved_interval = sys.getswitchinterval()
    # A tiny switch interval provokes frequent thread switches.
    support.setswitchinterval(1e-6)
    try:
        # Phase 1: n threads fill the cache.
        run_together([threading.Thread(target=full, args=[k])
                      for k in range(n)])

        hits, misses, _, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: it is unclear why the pure-Python version may not reach
            # exact equality here.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m * n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m * n - misses)
        self.assertEqual(currsize, n)

        # Phase 2: the same n fillers plus one thread clearing the cache.
        start.clear()
        run_together([threading.Thread(target=clear)] +
                     [threading.Thread(target=full, args=[k])
                      for k in range(n)])
    finally:
        sys.setswitchinterval(saved_interval)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001376
    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments.
        # In each of m rounds, n worker threads call f(i) with the same i.
        # f blocks on the `pause` barrier, which only trips once the main
        # thread also waits on it. So all n calls are in flight, and all of
        # them miss, before any of them can store a result.
        n, m = 5, 7
        # Each barrier synchronizes the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        # cache_info() is (hits, misses, maxsize, currsize).
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # Release the workers into round i.
                start.wait(10)
                # Each reset() below happens at a point where, by this
                # sequencing, no thread is waiting on that barrier.
                stop.reset()
                # Let all n blocked calls of f(i) proceed together.
                pause.wait(10)
                start.reset()
                # Wait until every worker has finished its call for round i.
                stop.wait(10)
                pause.reset()
                # Every call so far missed (n per round, no hits), yet each
                # round's single key occupies only one cache slot.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1404
Serhiy Storchaka67796522017-01-12 18:34:33 +02001405 @unittest.skipUnless(threading, 'This test requires threading.')
1406 def test_lru_cache_threaded3(self):
1407 @self.module.lru_cache(maxsize=2)
1408 def f(x):
1409 time.sleep(.01)
1410 return 3 * x
1411 def test(i, x):
1412 with self.subTest(thread=i):
1413 self.assertEqual(f(x), 3 * x, i)
1414 threads = [threading.Thread(target=test, args=(i, v))
1415 for i, v in enumerate([1, 2, 2, 3, 2])]
1416 with support.start_threads(threads):
1417 pass
1418
Raymond Hettinger03923422013-03-04 02:52:50 -05001419 def test_need_for_rlock(self):
1420 # This will deadlock on an LRU cache that uses a regular lock
1421
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001422 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001423 def test_func(x):
1424 'Used to demonstrate a reentrant lru_cache call within a single thread'
1425 return x
1426
1427 class DoubleEq:
1428 'Demonstrate a reentrant lru_cache call within a single thread'
1429 def __init__(self, x):
1430 self.x = x
1431 def __hash__(self):
1432 return self.x
1433 def __eq__(self, other):
1434 if self.x == 2:
1435 test_func(DoubleEq(1))
1436 return self.x == other.x
1437
1438 test_func(DoubleEq(1)) # Load the cache
1439 test_func(DoubleEq(2)) # Load the cache
1440 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1441 DoubleEq(2)) # Verify the correct return value
1442
Raymond Hettinger4d588972014-08-12 12:44:52 -07001443 def test_early_detection_of_bad_call(self):
1444 # Issue #22184
1445 with self.assertRaises(TypeError):
1446 @functools.lru_cache
1447 def f():
1448 pass
1449
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001450 def test_lru_method(self):
1451 class X(int):
1452 f_cnt = 0
1453 @self.module.lru_cache(2)
1454 def f(self, x):
1455 self.f_cnt += 1
1456 return x*10+self
1457 a = X(5)
1458 b = X(5)
1459 c = X(7)
1460 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1461
1462 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1463 self.assertEqual(a.f(x), x*10 + 5)
1464 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1465 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1466
1467 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1468 self.assertEqual(b.f(x), x*10 + 5)
1469 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1470 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1471
1472 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1473 self.assertEqual(c.f(x), x*10 + 7)
1474 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1475 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1476
1477 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1478 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1479 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1480
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001481 def test_pickle(self):
1482 cls = self.__class__
1483 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1484 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1485 with self.subTest(proto=proto, func=f):
1486 f_copy = pickle.loads(pickle.dumps(f, proto))
1487 self.assertIs(f_copy, f)
1488
1489 def test_copy(self):
1490 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001491 def orig(x, y):
1492 return 3 * x + y
1493 part = self.module.partial(orig, 2)
1494 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1495 self.module.lru_cache(2)(part))
1496 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001497 with self.subTest(func=f):
1498 f_copy = copy.copy(f)
1499 self.assertIs(f_copy, f)
1500
1501 def test_deepcopy(self):
1502 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001503 def orig(x, y):
1504 return 3 * x + y
1505 part = self.module.partial(orig, 2)
1506 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1507 self.module.lru_cache(2)(part))
1508 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001509 with self.subTest(func=f):
1510 f_copy = copy.deepcopy(f)
1511 self.assertIs(f_copy, f)
1512
1513
# Module-level function cached with the pure-Python lru_cache. It is kept at
# module scope, presumably so pickle can locate it by name (see TestLRUPy).
@py_functools.lru_cache()
def py_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1517
# Module-level function cached with the C-accelerated lru_cache. It is kept
# at module scope, presumably so pickle can locate it by name (see TestLRUC).
@c_functools.lru_cache()
def c_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1521
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001522
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the pure-Python functools.
    module = py_functools
    # Stored inside a 1-tuple so that looking it up through the class or an
    # instance does not bind the cached function as a method (lru_cache
    # wrappers bind like ordinary methods).
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1535
1536
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU checks against the C-accelerated functools.
    module = c_functools
    # Stored inside a 1-tuple so that looking it up through the class or an
    # instance does not bind the cached function as a method (lru_cache
    # wrappers bind like ordinary methods).
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001549
Raymond Hettinger03923422013-03-04 02:52:50 -05001550
class TestSingleDispatch(unittest.TestCase):
    # Tests for functools.singledispatch and the private helpers it relies on
    # (functools._compose_mro, functools._c3_mro, functools._find_impl).

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: g.dispatch(object) in the assert above is the undecorated
        # base implementation, not g: @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped under -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        # Every loop therefore passes the current permutation (`haystack`)
        # to _compose_mro and expects the same result for each ordering.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Register progressively more specific ABCs and concrete types, and
        # check that each object dispatches to its most specific match.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # A metaclass defining __len__ makes the *class* look Sized; that
        # must not confuse dispatch on instances.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        class TracingDict(UserDict):
            # Mapping that records every key read (get_ops) and written
            # (set_ops) so the test can observe dispatch-cache traffic.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        # singledispatch presumably creates its dispatch cache through
        # functools.WeakKeyDictionary() when g is decorated below. Make that
        # factory return one shared TracingDict so its reads, writes and
        # clears become observable.
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        # Restore the real class even if an assertion below fails. Otherwise
        # the patch would leak into every later singledispatch function.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2045
2046
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()