blob: 3a40861594693f18d0d2ab8b2e711ae24ec723f8 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
10import unittest
11from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100012import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030013try:
14 import threading
15except ImportError:
16 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000017
Antoine Pitroub5b37142012-11-13 21:35:40 +010018import functools
19
Antoine Pitroub5b37142012-11-13 21:35:40 +010020py_functools = support.import_fresh_module('functools', blocked=['_functools'])
21c_functools = support.import_fresh_module('functools', fresh=['_functools'])
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
24
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    On exit (normal or via an exception) the previous entry is restored.
    If *name* was not present in ``sys.modules`` on entry, the temporary
    entry is removed again rather than failing with ``KeyError``.
    """
    missing = object()
    original_module = sys.modules.get(name, missing)
    sys.modules[name] = replacement
    try:
        yield
    finally:
        if original_module is missing:
            # The module had not been imported before; leave no trace.
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = original_module
Łukasz Langa6f692512013-06-05 12:20:24 +020033
def capture(*args, **kw):
    """Echo a call back as an ``(args, kwargs)`` pair.

    Serves as the target callable so tests can inspect exactly which
    positional and keyword arguments were forwarded to it.
    """
    result = (args, kw)
    return result
37
Łukasz Langa6f692512013-06-05 12:20:24 +020038
def signature(part):
    """Summarize the state of the partial object *part* as a tuple.

    Returns ``(func, args, keywords, __dict__)`` so two partial objects can
    be compared for equivalent state with a single ``==``.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000042
class MyTuple(tuple):
    """Plain tuple subclass; partial must normalize it to a real tuple."""
45
class BadTuple(tuple):
    """Tuple subclass whose ``+`` yields a list instead of a tuple.

    Used to check that a partial built from it still exposes and forwards
    plain tuples.
    """

    def __add__(self, other):
        return [*self, *other]
49
class MyDict(dict):
    """Plain dict subclass; partial must normalize it to a real dict."""
52
Łukasz Langa6f692512013-06-05 12:20:24 +020053
class TestPartial:
    """Implementation-independent tests for a ``functools.partial`` type.

    This is a mixin, not a TestCase: concrete subclasses also inherit from
    ``unittest.TestCase`` and must define ``partial`` (the type under test)
    and ``AllowPickle`` (a context-manager class entered around pickling).
    """

    def _repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``.

        Both stdlib implementations report ``functools.partial``; any other
        type is expected to use its own ``__name__``.  ``c_functools`` is
        None when the C accelerator is unavailable, so it is only consulted
        when present.
        """
        stdlib_types = [py_functools.partial]
        if c_functools:
            stdlib_types.append(c_functools.partial)
        if self.partial in stdlib_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial object is created or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None  # drop the only strong reference
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case builds a self-referencing partial and then breaks the
        # cycle in ``finally`` so the object can be freed.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000368
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the TestPartial suite against the C ``functools.partial``."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing is patched to pickle this type."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # __dict__ itself must not be deletable either.
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200394
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the TestPartial suite against the pure-Python ``functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Map ``sys.modules['functools']`` to the pure-Python module.

        While active, ``functools.partial`` resolves to the pure-Python
        class, which lets pickle find it by module and name.
        """

        def __init__(self):
            # Built eagerly; entered and exited only by the with-statement.
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            swap = self._swap
            return swap.__enter__()

        def __exit__(self, type, value, tb):
            swap = self._swap
            return swap.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200405
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type."""

class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial type."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200412
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the inherited C partial tests against CPartialSubclass."""

    if c_functools:
        partial = CPartialSubclass

    # Partial subclasses are not optimized for nested calls.  Replacing the
    # inherited test with a non-callable disables it: unittest only collects
    # callable ``test_*`` attributes.
    test_nested_optimization = None
420
class TestPartialPySubclass(TestPartialPy):
    """Run the inherited pure-Python partial tests against PyPartialSubclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200423
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain partial object
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping staticmethod / classmethod descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass* (not a base class) for the class
        # machinery to collect abstract methods; subclassing ABCMeta would
        # only define a new metaclass and never exercise that path.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must also see the partialmethod wrapper as abstract,
        # which keeps the class from being instantiated.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
536
537
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000538class TestUpdateWrapper(unittest.TestCase):
539
540 def check_wrapper(self, wrapper, wrapped,
541 assigned=functools.WRAPPER_ASSIGNMENTS,
542 updated=functools.WRAPPER_UPDATES):
543 # Check attributes were assigned
544 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000545 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000546 # Check attributes were updated
547 for name in updated:
548 wrapper_attr = getattr(wrapper, name)
549 wrapped_attr = getattr(wrapped, name)
550 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000551 if name == "__dict__" and key == "__wrapped__":
552 # __wrapped__ is overwritten by the update code
553 continue
554 self.assertIs(wrapped_attr[key], wrapper_attr[key])
555 # Check __wrapped__
556 self.assertIs(wrapper.__wrapped__, wrapped)
557
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000558
R. David Murray378c0cf2010-02-24 01:46:21 +0000559 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000560 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000561 """This is a test"""
562 pass
563 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000564 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000565 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000566 pass
567 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000568 return wrapper, f
569
570 def test_default_update(self):
571 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000572 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000573 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000574 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600575 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000576 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000577 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
578 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000579
R. David Murray378c0cf2010-02-24 01:46:21 +0000580 @unittest.skipIf(sys.flags.optimize >= 2,
581 "Docstrings are omitted with -O2 and above")
582 def test_default_update_doc(self):
583 wrapper, f = self._default_update()
584 self.assertEqual(wrapper.__doc__, 'This is a test')
585
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000586 def test_no_update(self):
587 def f():
588 """This is a test"""
589 pass
590 f.attr = 'This is also a test'
591 def wrapper():
592 pass
593 functools.update_wrapper(wrapper, f, (), ())
594 self.check_wrapper(wrapper, f, (), ())
595 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600596 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000597 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000598 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000599 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000600
601 def test_selective_update(self):
602 def f():
603 pass
604 f.attr = 'This is a different test'
605 f.dict_attr = dict(a=1, b=2, c=3)
606 def wrapper():
607 pass
608 wrapper.dict_attr = {}
609 assign = ('attr',)
610 update = ('dict_attr',)
611 functools.update_wrapper(wrapper, f, assign, update)
612 self.check_wrapper(wrapper, f, assign, update)
613 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600614 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000615 self.assertEqual(wrapper.__doc__, None)
616 self.assertEqual(wrapper.attr, 'This is a different test')
617 self.assertEqual(wrapper.dict_attr, f.dict_attr)
618
Nick Coghlan98876832010-08-17 06:17:18 +0000619 def test_missing_attributes(self):
620 def f():
621 pass
622 def wrapper():
623 pass
624 wrapper.dict_attr = {}
625 assign = ('attr',)
626 update = ('dict_attr',)
627 # Missing attributes on wrapped object are ignored
628 functools.update_wrapper(wrapper, f, assign, update)
629 self.assertNotIn('attr', wrapper.__dict__)
630 self.assertEqual(wrapper.dict_attr, {})
631 # Wrapper must have expected attributes for updating
632 del wrapper.dict_attr
633 with self.assertRaises(AttributeError):
634 functools.update_wrapper(wrapper, f, assign, update)
635 wrapper.dict_attr = 1
636 with self.assertRaises(AttributeError):
637 functools.update_wrapper(wrapper, f, assign, update)
638
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200639 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000640 @unittest.skipIf(sys.flags.optimize >= 2,
641 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000642 def test_builtin_update(self):
643 # Test for bug #1576241
644 def wrapper():
645 pass
646 functools.update_wrapper(wrapper, max)
647 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000648 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000649 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000650
Łukasz Langa6f692512013-06-05 12:20:24 +0200651
class TestWraps(TestUpdateWrapper):
    """Tests for functools.wraps(), the decorator form of update_wrapper()."""

    def _default_update(self):
        # Build a function with a docstring, an extra attribute and a bogus
        # __wrapped__, then wrap a fresh function with wraps(f).
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        decorate = functools.wraps(f)

        def wrapper():
            pass
        wrapper = decorate(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # Identity metadata is copied and f.__dict__ is merged in.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapped_fn = self._default_update()[0]
        self.assertEqual(wrapped_fn.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty "assigned" and "updated" tuples leave the wrapper untouched.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only the explicitly listed attributes are assigned / updated.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
712
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of reduce()."""

    # Only bound when the accelerator imported; otherwise the decorator
    # above skips the whole class.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Old-style sequence that lazily computes 0, 1, 4, ... (max-1)**2.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                for k in range(len(self.sofar), i + 1):
                    self.sofar.append(k * k)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        func = self.func
        self.assertEqual(func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(func(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(func(mul, range(2, 8), 1), 5040)
        self.assertEqual(func(mul, range(2, 21), 1), 2432902008176640000)
        # Sum of the squares 0..81, with and without an initial value.
        self.assertEqual(func(add, Squares(10)), 285)
        self.assertEqual(func(add, Squares(10), 0), 285)
        self.assertEqual(func(add, Squares(0), 0), 0)

        # Bad argument counts / non-callable combining function.
        self.assertRaises(TypeError, func)
        self.assertRaises(TypeError, func, 42, 42)
        self.assertRaises(TypeError, func, 42, 42, 42)
        # func is never called with one item, so 42 is never invoked here.
        self.assertEqual(func(42, "1"), "1")
        self.assertEqual(func(42, "", "1"), "1")
        self.assertRaises(TypeError, func, 42, (42, 42))
        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, func, add, empty)
        self.assertRaises(TypeError, func, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, func, add, FailingIter())

        # With an empty sequence the initial value is returned as-is.
        self.assertIsNone(func(add, [], None))
        self.assertEqual(func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            # Old-style sequence yielding 0..n-1, then IndexError.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        func = self.func
        self.assertEqual(func(add, SequenceClass(5)), 10)
        self.assertEqual(func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, func, add, SequenceClass(0))
        self.assertEqual(func(add, SequenceClass(0), 42), 42)
        self.assertEqual(func(add, SequenceClass(1)), 0)
        self.assertEqual(func(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys, so they get concatenated.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(func(add, d), "".join(d))
794
Łukasz Langa6f692512013-06-05 12:20:24 +0200795
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200796class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700797
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000798 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700799 def cmp1(x, y):
800 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100801 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700802 self.assertEqual(key(3), key(3))
803 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100804 self.assertGreaterEqual(key(3), key(3))
805
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700806 def cmp2(x, y):
807 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100808 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700809 self.assertEqual(key(4.0), key('4'))
810 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100811 self.assertLessEqual(key(2), key('35'))
812 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700813
814 def test_cmp_to_key_arguments(self):
815 def cmp1(x, y):
816 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100817 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700818 self.assertEqual(key(obj=3), key(obj=3))
819 self.assertGreater(key(obj=3), key(obj=1))
820 with self.assertRaises((TypeError, AttributeError)):
821 key(3) > 1 # rhs is not a K object
822 with self.assertRaises((TypeError, AttributeError)):
823 1 < key(3) # lhs is not a K object
824 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100825 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700826 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200827 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100828 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700829 with self.assertRaises(TypeError):
830 key() # too few args
831 with self.assertRaises(TypeError):
832 key(None, None) # too many args
833
834 def test_bad_cmp(self):
835 def cmp1(x, y):
836 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100837 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700838 with self.assertRaises(ZeroDivisionError):
839 key(3) > key(1)
840
841 class BadCmp:
842 def __lt__(self, other):
843 raise ZeroDivisionError
844 def cmp1(x, y):
845 return BadCmp()
846 with self.assertRaises(ZeroDivisionError):
847 key(3) > key(1)
848
849 def test_obj_field(self):
850 def cmp1(x, y):
851 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100852 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700853 self.assertEqual(key(50).obj, 50)
854
855 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000856 def mycmp(x, y):
857 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100858 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000859 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000860
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861 def test_sort_int_str(self):
862 def mycmp(x, y):
863 x, y = int(x), int(y)
864 return (x > y) - (x < y)
865 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100866 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700867 self.assertEqual([int(value) for value in values],
868 [0, 1, 1, 2, 3, 4, 5, 7, 10])
869
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000870 def test_hash(self):
871 def mycmp(x, y):
872 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100873 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000874 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700875 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700876 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000877
Łukasz Langa6f692512013-06-05 12:20:24 +0200878
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    # Stored without staticmethod(), unlike the pure-Python variant; the
    # assignment is guarded because c_functools may be unavailable.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100883
Łukasz Langa6f692512013-06-05 12:20:24 +0200884
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""

    # staticmethod keeps the plain Python function from being bound to
    # the test instance when accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
887
Łukasz Langa6f692512013-06-05 12:20:24 +0200888
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000889class TestTotalOrdering(unittest.TestCase):
890
891 def test_total_ordering_lt(self):
892 @functools.total_ordering
893 class A:
894 def __init__(self, value):
895 self.value = value
896 def __lt__(self, other):
897 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000898 def __eq__(self, other):
899 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000900 self.assertTrue(A(1) < A(2))
901 self.assertTrue(A(2) > A(1))
902 self.assertTrue(A(1) <= A(2))
903 self.assertTrue(A(2) >= A(1))
904 self.assertTrue(A(2) <= A(2))
905 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000906 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000907
908 def test_total_ordering_le(self):
909 @functools.total_ordering
910 class A:
911 def __init__(self, value):
912 self.value = value
913 def __le__(self, other):
914 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000915 def __eq__(self, other):
916 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000917 self.assertTrue(A(1) < A(2))
918 self.assertTrue(A(2) > A(1))
919 self.assertTrue(A(1) <= A(2))
920 self.assertTrue(A(2) >= A(1))
921 self.assertTrue(A(2) <= A(2))
922 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000923 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000924
925 def test_total_ordering_gt(self):
926 @functools.total_ordering
927 class A:
928 def __init__(self, value):
929 self.value = value
930 def __gt__(self, other):
931 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000932 def __eq__(self, other):
933 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000934 self.assertTrue(A(1) < A(2))
935 self.assertTrue(A(2) > A(1))
936 self.assertTrue(A(1) <= A(2))
937 self.assertTrue(A(2) >= A(1))
938 self.assertTrue(A(2) <= A(2))
939 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000940 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000941
942 def test_total_ordering_ge(self):
943 @functools.total_ordering
944 class A:
945 def __init__(self, value):
946 self.value = value
947 def __ge__(self, other):
948 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000949 def __eq__(self, other):
950 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000951 self.assertTrue(A(1) < A(2))
952 self.assertTrue(A(2) > A(1))
953 self.assertTrue(A(1) <= A(2))
954 self.assertTrue(A(2) >= A(1))
955 self.assertTrue(A(2) <= A(2))
956 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000957 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000958
959 def test_total_ordering_no_overwrite(self):
960 # new methods should not overwrite existing
961 @functools.total_ordering
962 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000963 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000964 self.assertTrue(A(1) < A(2))
965 self.assertTrue(A(2) > A(1))
966 self.assertTrue(A(1) <= A(2))
967 self.assertTrue(A(2) >= A(1))
968 self.assertTrue(A(2) <= A(2))
969 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000970
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000971 def test_no_operations_defined(self):
972 with self.assertRaises(ValueError):
973 @functools.total_ordering
974 class A:
975 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000976
Nick Coghlanf05d9812013-10-02 00:02:03 +1000977 def test_type_error_when_not_implemented(self):
978 # bug 10042; ensure stack overflow does not occur
979 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000980 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000981 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000982 def __init__(self, value):
983 self.value = value
984 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000985 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000986 return self.value == other.value
987 return False
988 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000989 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000990 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000991 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000992
Nick Coghlanf05d9812013-10-02 00:02:03 +1000993 @functools.total_ordering
994 class ImplementsGreaterThan:
995 def __init__(self, value):
996 self.value = value
997 def __eq__(self, other):
998 if isinstance(other, ImplementsGreaterThan):
999 return self.value == other.value
1000 return False
1001 def __gt__(self, other):
1002 if isinstance(other, ImplementsGreaterThan):
1003 return self.value > other.value
1004 return NotImplemented
1005
1006 @functools.total_ordering
1007 class ImplementsLessThanEqualTo:
1008 def __init__(self, value):
1009 self.value = value
1010 def __eq__(self, other):
1011 if isinstance(other, ImplementsLessThanEqualTo):
1012 return self.value == other.value
1013 return False
1014 def __le__(self, other):
1015 if isinstance(other, ImplementsLessThanEqualTo):
1016 return self.value <= other.value
1017 return NotImplemented
1018
1019 @functools.total_ordering
1020 class ImplementsGreaterThanEqualTo:
1021 def __init__(self, value):
1022 self.value = value
1023 def __eq__(self, other):
1024 if isinstance(other, ImplementsGreaterThanEqualTo):
1025 return self.value == other.value
1026 return False
1027 def __ge__(self, other):
1028 if isinstance(other, ImplementsGreaterThanEqualTo):
1029 return self.value >= other.value
1030 return NotImplemented
1031
1032 @functools.total_ordering
1033 class ComparatorNotImplemented:
1034 def __init__(self, value):
1035 self.value = value
1036 def __eq__(self, other):
1037 if isinstance(other, ComparatorNotImplemented):
1038 return self.value == other.value
1039 return False
1040 def __lt__(self, other):
1041 return NotImplemented
1042
1043 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1044 ImplementsLessThan(-1) < 1
1045
1046 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1047 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1048
1049 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1050 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1051
1052 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1053 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1054
1055 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1056 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1057
1058 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1059 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1060
1061 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1062 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1063
1064 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1065 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1066
1067 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1068 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1069
1070 with self.subTest("GE when equal"):
1071 a = ComparatorNotImplemented(8)
1072 b = ComparatorNotImplemented(8)
1073 self.assertEqual(a, b)
1074 with self.assertRaises(TypeError):
1075 a >= b
1076
1077 with self.subTest("LE when equal"):
1078 a = ComparatorNotImplemented(9)
1079 b = ComparatorNotImplemented(9)
1080 self.assertEqual(a, b)
1081 with self.assertRaises(TypeError):
1082 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001083
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001084 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001085 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001086 for name in '__lt__', '__gt__', '__le__', '__ge__':
1087 with self.subTest(method=name, proto=proto):
1088 method = getattr(Orderable_LT, name)
1089 method_copy = pickle.loads(pickle.dumps(method, proto))
1090 self.assertIs(method_copy, method)
1091
@functools.total_ordering
class Orderable_LT:
    """Module-level ordering fixture used by TestTotalOrdering.test_pickle.

    Only ``__eq__`` and ``__lt__`` are written by hand; total_ordering
    supplies the remaining rich comparisons.  Living at module level lets
    its methods be located by qualified name when pickled.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1100
1101
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001102class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001103
1104 def test_lru(self):
1105 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001106 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001107 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001108 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001109 self.assertEqual(maxsize, 20)
1110 self.assertEqual(currsize, 0)
1111 self.assertEqual(hits, 0)
1112 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001113
1114 domain = range(5)
1115 for i in range(1000):
1116 x, y = choice(domain), choice(domain)
1117 actual = f(x, y)
1118 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001119 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001120 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001121 self.assertTrue(hits > misses)
1122 self.assertEqual(hits + misses, 1000)
1123 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001124
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001125 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001126 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001127 self.assertEqual(hits, 0)
1128 self.assertEqual(misses, 0)
1129 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001130 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001131 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001132 self.assertEqual(hits, 0)
1133 self.assertEqual(misses, 1)
1134 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001135
Nick Coghlan98876832010-08-17 06:17:18 +00001136 # Test bypassing the cache
1137 self.assertIs(f.__wrapped__, orig)
1138 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001139 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001140 self.assertEqual(hits, 0)
1141 self.assertEqual(misses, 1)
1142 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001143
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001144 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001145 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001146 def f():
1147 nonlocal f_cnt
1148 f_cnt += 1
1149 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001150 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001151 f_cnt = 0
1152 for i in range(5):
1153 self.assertEqual(f(), 20)
1154 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001155 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001156 self.assertEqual(hits, 0)
1157 self.assertEqual(misses, 5)
1158 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001159
1160 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001161 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001162 def f():
1163 nonlocal f_cnt
1164 f_cnt += 1
1165 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001166 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001167 f_cnt = 0
1168 for i in range(5):
1169 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001170 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001171 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001172 self.assertEqual(hits, 4)
1173 self.assertEqual(misses, 1)
1174 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001175
Raymond Hettingerf3098282010-08-15 03:30:45 +00001176 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001177 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001178 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001179 nonlocal f_cnt
1180 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001181 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001182 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001183 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001184 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1185 # * * * *
1186 self.assertEqual(f(x), x*10)
1187 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001188 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001189 self.assertEqual(hits, 12)
1190 self.assertEqual(misses, 4)
1191 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001192
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001193 def test_lru_reentrancy_with_len(self):
1194 # Test to make sure the LRU cache code isn't thrown-off by
1195 # caching the built-in len() function. Since len() can be
1196 # cached, we shouldn't use it inside the lru code itself.
1197 old_len = builtins.len
1198 try:
1199 builtins.len = self.module.lru_cache(4)(len)
1200 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1201 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1202 finally:
1203 builtins.len = old_len
1204
Yury Selivanov46a02db2016-11-09 18:55:45 -05001205 def test_lru_type_error(self):
1206 # Regression test for issue #28653.
1207 # lru_cache was leaking when one of the arguments
1208 # wasn't cacheable.
1209
1210 @functools.lru_cache(maxsize=None)
1211 def infinite_cache(o):
1212 pass
1213
1214 @functools.lru_cache(maxsize=10)
1215 def limited_cache(o):
1216 pass
1217
1218 with self.assertRaises(TypeError):
1219 infinite_cache([])
1220
1221 with self.assertRaises(TypeError):
1222 limited_cache([])
1223
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001224 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001225 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001226 def fib(n):
1227 if n < 2:
1228 return n
1229 return fib(n-1) + fib(n-2)
1230 self.assertEqual([fib(n) for n in range(16)],
1231 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1232 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001233 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001234 fib.cache_clear()
1235 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001236 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1237
1238 def test_lru_with_maxsize_negative(self):
1239 @self.module.lru_cache(maxsize=-10)
1240 def eq(n):
1241 return n
1242 for i in (0, 1):
1243 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1244 self.assertEqual(eq.cache_info(),
1245 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001246
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001247 def test_lru_with_exceptions(self):
1248 # Verify that user_function exceptions get passed through without
1249 # creating a hard-to-read chained exception.
1250 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001251 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001252 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001253 def func(i):
1254 return 'abc'[i]
1255 self.assertEqual(func(0), 'a')
1256 with self.assertRaises(IndexError) as cm:
1257 func(15)
1258 self.assertIsNone(cm.exception.__context__)
1259 # Verify that the previous exception did not result in a cached entry
1260 with self.assertRaises(IndexError):
1261 func(15)
1262
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001263 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001264 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001265 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001266 def square(x):
1267 return x * x
1268 self.assertEqual(square(3), 9)
1269 self.assertEqual(type(square(3)), type(9))
1270 self.assertEqual(square(3.0), 9.0)
1271 self.assertEqual(type(square(3.0)), type(9.0))
1272 self.assertEqual(square(x=3), 9)
1273 self.assertEqual(type(square(x=3)), type(9))
1274 self.assertEqual(square(x=3.0), 9.0)
1275 self.assertEqual(type(square(x=3.0)), type(9.0))
1276 self.assertEqual(square.cache_info().hits, 4)
1277 self.assertEqual(square.cache_info().misses, 4)
1278
Antoine Pitroub5b37142012-11-13 21:35:40 +01001279 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001280 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001281 def fib(n):
1282 if n < 2:
1283 return n
1284 return fib(n=n-1) + fib(n=n-2)
1285 self.assertEqual(
1286 [fib(n=number) for number in range(16)],
1287 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1288 )
1289 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001290 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001291 fib.cache_clear()
1292 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001293 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001294
1295 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001296 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001297 def fib(n):
1298 if n < 2:
1299 return n
1300 return fib(n=n-1) + fib(n=n-2)
1301 self.assertEqual([fib(n=number) for number in range(16)],
1302 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1303 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001304 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001305 fib.cache_clear()
1306 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001307 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1308
1309 def test_lru_cache_decoration(self):
1310 def f(zomg: 'zomg_annotation'):
1311 """f doc string"""
1312 return 42
1313 g = self.module.lru_cache()(f)
1314 for attr in self.module.WRAPPER_ASSIGNMENTS:
1315 self.assertEqual(getattr(g, attr), getattr(f, attr))
1316
1317 @unittest.skipUnless(threading, 'This test requires threading.')
1318 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001319 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001320 def orig(x, y):
1321 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001322 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001323 hits, misses, maxsize, currsize = f.cache_info()
1324 self.assertEqual(currsize, 0)
1325
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001326 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001327 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001328 start.wait(10)
1329 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001330 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001331
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001332 def clear():
1333 start.wait(10)
1334 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001335 f.cache_clear()
1336
1337 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001338 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001339 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001340 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001341 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001342 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001343 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001344 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001345
1346 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001347 if self.module is py_functools:
1348 # XXX: Why can be not equal?
1349 self.assertLessEqual(misses, n)
1350 self.assertLessEqual(hits, m*n - misses)
1351 else:
1352 self.assertEqual(misses, n)
1353 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001354 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001355
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001356 # create n threads in order to fill cache and 1 to clear it
1357 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001358 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001359 for k in range(n)]
1360 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001361 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001362 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001363 finally:
1364 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001365
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001366 @unittest.skipUnless(threading, 'This test requires threading.')
1367 def test_lru_cache_threaded2(self):
1368 # Simultaneous call with the same arguments
1369 n, m = 5, 7
1370 start = threading.Barrier(n+1)
1371 pause = threading.Barrier(n+1)
1372 stop = threading.Barrier(n+1)
1373 @self.module.lru_cache(maxsize=m*n)
1374 def f(x):
1375 pause.wait(10)
1376 return 3 * x
1377 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1378 def test():
1379 for i in range(m):
1380 start.wait(10)
1381 self.assertEqual(f(i), 3 * i)
1382 stop.wait(10)
1383 threads = [threading.Thread(target=test) for k in range(n)]
1384 with support.start_threads(threads):
1385 for i in range(m):
1386 start.wait(10)
1387 stop.reset()
1388 pause.wait(10)
1389 start.reset()
1390 stop.wait(10)
1391 pause.reset()
1392 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1393
Raymond Hettinger03923422013-03-04 02:52:50 -05001394 def test_need_for_rlock(self):
1395 # This will deadlock on an LRU cache that uses a regular lock
1396
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001397 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001398 def test_func(x):
1399 'Used to demonstrate a reentrant lru_cache call within a single thread'
1400 return x
1401
1402 class DoubleEq:
1403 'Demonstrate a reentrant lru_cache call within a single thread'
1404 def __init__(self, x):
1405 self.x = x
1406 def __hash__(self):
1407 return self.x
1408 def __eq__(self, other):
1409 if self.x == 2:
1410 test_func(DoubleEq(1))
1411 return self.x == other.x
1412
1413 test_func(DoubleEq(1)) # Load the cache
1414 test_func(DoubleEq(2)) # Load the cache
1415 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1416 DoubleEq(2)) # Verify the correct return value
1417
Raymond Hettinger4d588972014-08-12 12:44:52 -07001418 def test_early_detection_of_bad_call(self):
1419 # Issue #22184
1420 with self.assertRaises(TypeError):
1421 @functools.lru_cache
1422 def f():
1423 pass
1424
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001425 def test_lru_method(self):
1426 class X(int):
1427 f_cnt = 0
1428 @self.module.lru_cache(2)
1429 def f(self, x):
1430 self.f_cnt += 1
1431 return x*10+self
1432 a = X(5)
1433 b = X(5)
1434 c = X(7)
1435 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1436
1437 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1438 self.assertEqual(a.f(x), x*10 + 5)
1439 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1440 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1441
1442 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1443 self.assertEqual(b.f(x), x*10 + 5)
1444 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1445 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1446
1447 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1448 self.assertEqual(c.f(x), x*10 + 7)
1449 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1450 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1451
1452 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1453 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1454 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1455
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001456 def test_pickle(self):
1457 cls = self.__class__
1458 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1459 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1460 with self.subTest(proto=proto, func=f):
1461 f_copy = pickle.loads(pickle.dumps(f, proto))
1462 self.assertIs(f_copy, f)
1463
1464 def test_copy(self):
1465 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001466 def orig(x, y):
1467 return 3 * x + y
1468 part = self.module.partial(orig, 2)
1469 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1470 self.module.lru_cache(2)(part))
1471 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001472 with self.subTest(func=f):
1473 f_copy = copy.copy(f)
1474 self.assertIs(f_copy, f)
1475
1476 def test_deepcopy(self):
1477 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001478 def orig(x, y):
1479 return 3 * x + y
1480 part = self.module.partial(orig, 2)
1481 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1482 self.module.lru_cache(2)(part))
1483 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001484 with self.subTest(func=f):
1485 f_copy = copy.deepcopy(f)
1486 self.assertIs(f_copy, f)
1487
1488
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y, memoized by the pure-Python lru_cache.

    Kept at module level; the pickle/copy tests look it up via
    TestLRUPy.cached_func.
    """
    tripled = 3 * x
    return tripled + y
1492
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y, memoized by the C-accelerated lru_cache.

    Kept at module level; the pickle/copy tests look it up via
    TestLRUC.cached_func.
    """
    tripled = 3 * x
    return tripled + y
1496
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001497
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU cases against the pure-Python lru_cache."""

    module = py_functools
    # Wrapped in a 1-tuple; the tests read it back as cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        # Cached method used by the pickle/copy tests.
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        # Cached static method used by the pickle/copy tests.
        tripled = 3 * x
        return tripled + y
1510
1511
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU cases against the C-accelerated lru_cache."""

    module = c_functools
    # Wrapped in a 1-tuple; the tests read it back as cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        # Cached method used by the pickle/copy tests.
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        # Cached static method used by the pickle/copy tests.
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001524
Raymond Hettinger03923422013-03-04 02:52:50 -05001525
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and the private helpers it uses
    to order candidate classes (_compose_mro, _c3_mro, _find_impl)."""

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        # D's MRO is D, C, B, A, object: B is found before A.
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: g.dispatch(object) above is not g itself, because
        # @singledispatch returns a wrapper around the original function.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running under -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering;
        # every loop below passes each permutation to _compose_mro to
        # verify exactly that.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Each registration of a more specific class or ABC must take over
        # dispatch for the matching objects only.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # MetaA.__len__ returns 0, so the class objects A and AA are falsy;
        # dispatch must still find the implementation registered for A.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        class TracingDict(UserDict):
            """UserDict that records every key read and written."""
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Clear self.data directly so that cache invalidation is not
                # recorded in get_ops (the inherited clear() presumably goes
                # through __getitem__).
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        # Restore the real factory even if an assertion below fails.
        # Otherwise every singledispatch function created later in the
        # process would share this single TracingDict as its cache.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2020
2021
# Run this module's tests when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()