blob: cd4664cec0874567dead66d8fd8b900deace242d [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
12from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100013import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030014try:
15 import threading
16except ImportError:
17 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
Antoine Pitroub5b37142012-11-13 21:35:40 +010021py_functools = support.import_fresh_module('functools', blocked=['_functools'])
22c_functools = support.import_fresh_module('functools', fresh=['_functools'])
23
Łukasz Langa6f692512013-06-05 12:20:24 +020024decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The previous entry is looked up first (so a missing *name* raises
    KeyError before anything is changed) and is put back on exit, whether
    the ``with`` body finishes normally or raises.
    """
    modules = sys.modules
    saved = modules[name]
    modules[name] = replacement
    try:
        yield
    finally:
        modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo the call back: return the positional tuple and keyword dict."""
    captured = (args, kw)
    return captured
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Summarize a partial object as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Plain tuple subclass with no overrides (used as partial state)."""
46
class BadTuple(tuple):
    """tuple subclass whose ``+`` misbehaves by returning a list."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
50
class MyDict(dict):
    """Plain dict subclass with no overrides (used as partial state)."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Implementation-independent tests for ``functools.partial``.

    This is a mixin rather than a ``TestCase``: concrete subclasses combine it
    with ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the pickling
      tests.
    """

    def _repr_name(self):
        """Return the type name that ``repr()`` of a partial should start with.

        The stock C and pure-Python types both report ``functools.partial``;
        any other (subclass) type is expected to use its own ``__name__``.
        """
        stock = [py_functools.partial]
        # The C accelerator may be unavailable, in which case c_functools is
        # a false value with no ``partial`` attribute to compare against.
        if c_functools:
            stock.append(c_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                self.assertEqual(p('x'), (args + ('x',), {}))

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                self.assertEqual(p(x=None), ((), {'a': a, 'x': None}))

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last strong reference frees the object immediately
        # only under reference counting; force a collection for other GCs.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The keyword order in the repr is not assumed; accept either one.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # func refers back to the partial itself
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            # break the reference cycle
            f.__setstate__((capture, (), {}, {}))

        # a positional argument refers back to the partial
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        # a keyword argument refers back to the partial
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # With a None keywords entry only the call behaviour is checked.
        f.__setstate__((capture, (1,), None, None))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation, plus
    checks that only apply to it."""

    # The class body executes even when the decorator skips the class, so
    # c_functools.partial is only touched if the C module was importable.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: no module swapping is done for this type."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
428 self.assertIn("['sth']", r)
429
430
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200431class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000432 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000433
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000434 class AllowPickle:
435 def __init__(self):
436 self._cm = replaced_module("functools", py_functools)
437 def __enter__(self):
438 return self._cm.__enter__()
439 def __exit__(self, type, value, tb):
440 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200442if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000443 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200444 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000446class PyPartialSubclass(py_functools.partial):
447 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200449@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200450class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200451 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000452 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000453
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300454 # partial subclasses are not optimized for nested calls
455 test_nested_optimization = None
456
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000457class TestPartialPySubclass(TestPartialPy):
458 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000460class TestPartialMethod(unittest.TestCase):
461
462 class A(object):
463 nothing = functools.partialmethod(capture)
464 positional = functools.partialmethod(capture, 1)
465 keywords = functools.partialmethod(capture, a=2)
466 both = functools.partialmethod(capture, 3, b=4)
467
468 nested = functools.partialmethod(positional, 5)
469
470 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
471
472 static = functools.partialmethod(staticmethod(capture), 8)
473 cls = functools.partialmethod(classmethod(capture), d=9)
474
475 a = A()
476
477 def test_arg_combinations(self):
478 self.assertEqual(self.a.nothing(), ((self.a,), {}))
479 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
480 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
481 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
482
483 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
484 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
485 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
486 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
487
488 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
489 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
490 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
491 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
492
493 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
494 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
495 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
496 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
497
498 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
499
500 def test_nested(self):
501 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
502 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
503 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
504 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
505
506 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
507
508 def test_over_partial(self):
509 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
510 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
511 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
512 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
513
514 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
515
516 def test_bound_method_introspection(self):
517 obj = self.a
518 self.assertIs(obj.both.__self__, obj)
519 self.assertIs(obj.nested.__self__, obj)
520 self.assertIs(obj.over_partial.__self__, obj)
521 self.assertIs(obj.cls.__self__, self.A)
522 self.assertIs(self.A.cls.__self__, self.A)
523
524 def test_unbound_method_retrieval(self):
525 obj = self.A
526 self.assertFalse(hasattr(obj.both, "__self__"))
527 self.assertFalse(hasattr(obj.nested, "__self__"))
528 self.assertFalse(hasattr(obj.over_partial, "__self__"))
529 self.assertFalse(hasattr(obj.static, "__self__"))
530 self.assertFalse(hasattr(self.a.static, "__self__"))
531
532 def test_descriptors(self):
533 for obj in [self.A, self.a]:
534 with self.subTest(obj=obj):
535 self.assertEqual(obj.static(), ((8,), {}))
536 self.assertEqual(obj.static(5), ((8, 5), {}))
537 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
538 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
539
540 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
541 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
542 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
543 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
544
545 def test_overriding_keywords(self):
546 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
547 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
548
549 def test_invalid_args(self):
550 with self.assertRaises(TypeError):
551 class B(object):
552 method = functools.partialmethod(None, 1)
553
554 def test_repr(self):
555 self.assertEqual(repr(vars(self.A)['both']),
556 'functools.partialmethod({}, 3, b=4)'.format(capture))
557
558 def test_abstract(self):
559 class Abstract(abc.ABCMeta):
560
561 @abc.abstractmethod
562 def add(self, x, y):
563 pass
564
565 add5 = functools.partialmethod(add, 5)
566
567 self.assertTrue(Abstract.add.__isabstractmethod__)
568 self.assertTrue(Abstract.add5.__isabstractmethod__)
569
570 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
571 self.assertFalse(getattr(func, '__isabstractmethod__', False))
572
573
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000574class TestUpdateWrapper(unittest.TestCase):
575
576 def check_wrapper(self, wrapper, wrapped,
577 assigned=functools.WRAPPER_ASSIGNMENTS,
578 updated=functools.WRAPPER_UPDATES):
579 # Check attributes were assigned
580 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000581 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000582 # Check attributes were updated
583 for name in updated:
584 wrapper_attr = getattr(wrapper, name)
585 wrapped_attr = getattr(wrapped, name)
586 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000587 if name == "__dict__" and key == "__wrapped__":
588 # __wrapped__ is overwritten by the update code
589 continue
590 self.assertIs(wrapped_attr[key], wrapper_attr[key])
591 # Check __wrapped__
592 self.assertIs(wrapper.__wrapped__, wrapped)
593
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000594
R. David Murray378c0cf2010-02-24 01:46:21 +0000595 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000596 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000597 """This is a test"""
598 pass
599 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000600 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000601 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000602 pass
603 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000604 return wrapper, f
605
606 def test_default_update(self):
607 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000608 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000609 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000610 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600611 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000612 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000613 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
614 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000615
R. David Murray378c0cf2010-02-24 01:46:21 +0000616 @unittest.skipIf(sys.flags.optimize >= 2,
617 "Docstrings are omitted with -O2 and above")
618 def test_default_update_doc(self):
619 wrapper, f = self._default_update()
620 self.assertEqual(wrapper.__doc__, 'This is a test')
621
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000622 def test_no_update(self):
623 def f():
624 """This is a test"""
625 pass
626 f.attr = 'This is also a test'
627 def wrapper():
628 pass
629 functools.update_wrapper(wrapper, f, (), ())
630 self.check_wrapper(wrapper, f, (), ())
631 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600632 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000633 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000634 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000635 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000636
637 def test_selective_update(self):
638 def f():
639 pass
640 f.attr = 'This is a different test'
641 f.dict_attr = dict(a=1, b=2, c=3)
642 def wrapper():
643 pass
644 wrapper.dict_attr = {}
645 assign = ('attr',)
646 update = ('dict_attr',)
647 functools.update_wrapper(wrapper, f, assign, update)
648 self.check_wrapper(wrapper, f, assign, update)
649 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600650 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000651 self.assertEqual(wrapper.__doc__, None)
652 self.assertEqual(wrapper.attr, 'This is a different test')
653 self.assertEqual(wrapper.dict_attr, f.dict_attr)
654
Nick Coghlan98876832010-08-17 06:17:18 +0000655 def test_missing_attributes(self):
656 def f():
657 pass
658 def wrapper():
659 pass
660 wrapper.dict_attr = {}
661 assign = ('attr',)
662 update = ('dict_attr',)
663 # Missing attributes on wrapped object are ignored
664 functools.update_wrapper(wrapper, f, assign, update)
665 self.assertNotIn('attr', wrapper.__dict__)
666 self.assertEqual(wrapper.dict_attr, {})
667 # Wrapper must have expected attributes for updating
668 del wrapper.dict_attr
669 with self.assertRaises(AttributeError):
670 functools.update_wrapper(wrapper, f, assign, update)
671 wrapper.dict_attr = 1
672 with self.assertRaises(AttributeError):
673 functools.update_wrapper(wrapper, f, assign, update)
674
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200675 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000676 @unittest.skipIf(sys.flags.optimize >= 2,
677 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000678 def test_builtin_update(self):
679 # Test for bug #1576241
680 def wrapper():
681 pass
682 functools.update_wrapper(wrapper, max)
683 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000684 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000685 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() scenarios using the @functools.wraps form."""

    def _default_update(self):
        # Return (wrapper, f), where wrapper was decorated with wraps(f) using
        # the default assigned/updated tuples.  The wrapped function has to
        # stay named 'f': test_default_update checks that __name__ is copied.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # f.__dict__ (copied into the wrapper by default) carries a bogus
        # __wrapped__; wraps() is expected to point __wrapped__ at f anyway
        # (presumably verified by check_wrapper on the base class).
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        for attr_name, expected in (('__name__', 'f'),
                                    ('__qualname__', wrapped.__qualname__),
                                    ('attr', 'This is also a test')):
            self.assertEqual(getattr(wrapper, attr_name), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def wrapped():
            """This is a test"""
        wrapped.attr = 'This is also a test'
        nothing = ()

        @functools.wraps(wrapped, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, wrapped, nothing, nothing)
        # With empty assigned/updated tuples the wrapper keeps its own
        # metadata and gains no extra attributes.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def wrapped():
            pass
        wrapped.attr = 'This is a different test'
        wrapped.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def with_empty_dict_attr(func):
            # The wrapper must already own a 'dict_attr' for wraps() to update.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(wrapped, assigned, updated)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, wrapped, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, wrapped.dict_attr)
748
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C-accelerated functools.reduce()."""

    if c_functools:
        # Builtin functions are not descriptors, so this stays unbound when
        # looked up through self.
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence (via __getitem__) of the squares 0, 1, 4, ... for
            # indexes below ``limit``; values are computed lazily and kept.
            def __init__(self, limit):
                self.limit = limit
                self.computed = []

            def __len__(self):
                return len(self.computed)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                for n in range(len(self.computed), i + 1):
                    self.computed.append(n * n)
                return self.computed[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.func

        # Plain folds, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        # 0 + 1 + 4 + ... + 81 == 285
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong number of arguments, or a non-iterable second argument.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the function is never called, so 42 is accepted...
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        # ...but with two items it is called and fails.
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty iterable needs an initial value.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        # Errors raised by iter() propagate unchanged.
        self.assertRaises(RuntimeError, reduce, add, FailingIter())

        # An empty iterable returns the initial value as-is.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        # Errors raised while fetching items propagate unchanged.
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    def test_iterator_usage(self):
        """reduce() consumes any iterable, including __getitem__-only
        sequences and dicts."""
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys, so reduce() concatenates them.
        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, words), "".join(words))
830
Łukasz Langa6f692512013-06-05 12:20:24 +0200831
class TestCmpToKey:
    """Implementation-independent tests for functools.cmp_to_key().

    This is a mixin: concrete subclasses combine it with unittest.TestCase
    and set ``cmp_to_key`` to the implementation under test.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            # Numeric comparison, so 4.0 and '4' compare equal.
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1  # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)  # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()  # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)  # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()  # too few args
        with self.assertRaises(TypeError):
            key(None, None)  # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def raising_cmp(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(raising_cmp)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against 0 must propagate too.  The key objects compare
        # ``mycmp(a, b)`` with 0 using the operator being evaluated, and
        # BadCmp only defines __lt__, so the check has to use '<': with '>'
        # Python would never call BadCmp.__lt__ and would raise TypeError.
        # The key must also be rebuilt from the new comparison function.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError

        def bad_result_cmp(x, y):
            return BadCmp()
        key = self.cmp_to_key(bad_result_cmp)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # The ABC lives in collections.abc; the old collections.Hashable
        # alias was deprecated and removed in Python 3.10.
        from collections.abc import Hashable

        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000913
Łukasz Langa6f692512013-06-05 12:20:24 +0200914
@unittest.skipIf(c_functools is None, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    if c_functools is not None:
        # Builtin functions are not descriptors, so no staticmethod() wrapper
        # is needed to keep this from binding to the test instance.
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919
Łukasz Langa6f692512013-06-05 12:20:24 +0200920
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""
    # py_functools.cmp_to_key is a plain Python function; staticmethod() keeps
    # it from being bound as a method when looked up via self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
923
Łukasz Langa6f692512013-06-05 12:20:24 +0200924
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering() class decorator."""

    def _check_derived_ordering(self, cls):
        # Exercise all four ordering operators on instances of ``cls`` built
        # from 1 and 2.  Each test class spells out at most one of them; the
        # rest come from total_ordering() (or, for int subclasses, from int).
        small, big, big_again = cls(1), cls(2), cls(2)
        self.assertTrue(small < big)
        self.assertTrue(big > small)
        self.assertTrue(small <= big)
        self.assertTrue(big >= small)
        self.assertTrue(big <= big_again)
        self.assertTrue(big >= big_again)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __lt__(self, other):
                return self.value < other.value

        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __le__(self, other):
                return self.value <= other.value

        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __gt__(self, other):
                return self.value > other.value

        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                return self.value == other.value

            def __ge__(self, other):
                return self.value >= other.value

        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # Methods the class already has (here: inherited from int) must not
        # be replaced by derived ones.
        @functools.total_ordering
        class A(int):
            pass

        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        class A:
            pass

        with self.assertRaises(ValueError):
            functools.total_ordering(A)

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return False
                return self.value == other.value

            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return False
                return self.value == other.value

            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return False
                return self.value == other.value

            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, ComparatorNotImplemented):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                return NotImplemented

        # Mixed-type comparisons where neither operand can answer must raise
        # TypeError rather than bounce between the derived methods.
        incomparable = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in incomparable:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even for equal values, a derived >= / <= must raise when the root
        # comparison returns NotImplemented, despite __eq__ returning True.
        for label, value, compare in (
                ("GE when equal", 8, lambda a, b: a >= b),
                ("LE when equal", 9, lambda a, b: a <= b)):
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Both the hand-written __lt__ and the derived comparison methods must
        # round-trip through pickle (which stores functions by qualified
        # name) back to the very same object, for every protocol.
        names = ('__lt__', '__gt__', '__le__', '__ge__')
        cases = [(proto, name)
                 for proto in range(pickle.HIGHEST_PROTOCOL + 1)
                 for name in names]
        for proto, name in cases:
            with self.subTest(method=name, proto=proto):
                original = getattr(Orderable_LT, name)
                restored = pickle.loads(pickle.dumps(original, proto))
                self.assertIs(restored, original)
1127
@functools.total_ordering
class Orderable_LT:
    """Module-level orderable type used by TestTotalOrdering.test_pickle.

    Only __eq__ and __lt__ are written out; total_ordering() derives the
    remaining comparison methods.  It must live at module level so its
    methods can be pickled by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1136
1137
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001138class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001139
1140 def test_lru(self):
1141 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001142 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001143 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001144 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001145 self.assertEqual(maxsize, 20)
1146 self.assertEqual(currsize, 0)
1147 self.assertEqual(hits, 0)
1148 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001149
1150 domain = range(5)
1151 for i in range(1000):
1152 x, y = choice(domain), choice(domain)
1153 actual = f(x, y)
1154 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001155 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001156 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001157 self.assertTrue(hits > misses)
1158 self.assertEqual(hits + misses, 1000)
1159 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001160
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001161 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001162 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001163 self.assertEqual(hits, 0)
1164 self.assertEqual(misses, 0)
1165 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001166 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001167 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001168 self.assertEqual(hits, 0)
1169 self.assertEqual(misses, 1)
1170 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001171
Nick Coghlan98876832010-08-17 06:17:18 +00001172 # Test bypassing the cache
1173 self.assertIs(f.__wrapped__, orig)
1174 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001175 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001176 self.assertEqual(hits, 0)
1177 self.assertEqual(misses, 1)
1178 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001179
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001180 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001181 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001182 def f():
1183 nonlocal f_cnt
1184 f_cnt += 1
1185 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001186 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001187 f_cnt = 0
1188 for i in range(5):
1189 self.assertEqual(f(), 20)
1190 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001191 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001192 self.assertEqual(hits, 0)
1193 self.assertEqual(misses, 5)
1194 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001195
1196 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001197 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001198 def f():
1199 nonlocal f_cnt
1200 f_cnt += 1
1201 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001202 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001203 f_cnt = 0
1204 for i in range(5):
1205 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001206 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001207 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001208 self.assertEqual(hits, 4)
1209 self.assertEqual(misses, 1)
1210 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001211
Raymond Hettingerf3098282010-08-15 03:30:45 +00001212 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001213 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001214 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001215 nonlocal f_cnt
1216 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001217 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001218 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001219 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001220 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1221 # * * * *
1222 self.assertEqual(f(x), x*10)
1223 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001224 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(hits, 12)
1226 self.assertEqual(misses, 4)
1227 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001228
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001229 def test_lru_reentrancy_with_len(self):
1230 # Test to make sure the LRU cache code isn't thrown-off by
1231 # caching the built-in len() function. Since len() can be
1232 # cached, we shouldn't use it inside the lru code itself.
1233 old_len = builtins.len
1234 try:
1235 builtins.len = self.module.lru_cache(4)(len)
1236 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1237 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1238 finally:
1239 builtins.len = old_len
1240
Yury Selivanov46a02db2016-11-09 18:55:45 -05001241 def test_lru_type_error(self):
1242 # Regression test for issue #28653.
1243 # lru_cache was leaking when one of the arguments
1244 # wasn't cacheable.
1245
1246 @functools.lru_cache(maxsize=None)
1247 def infinite_cache(o):
1248 pass
1249
1250 @functools.lru_cache(maxsize=10)
1251 def limited_cache(o):
1252 pass
1253
1254 with self.assertRaises(TypeError):
1255 infinite_cache([])
1256
1257 with self.assertRaises(TypeError):
1258 limited_cache([])
1259
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001260 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001261 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001262 def fib(n):
1263 if n < 2:
1264 return n
1265 return fib(n-1) + fib(n-2)
1266 self.assertEqual([fib(n) for n in range(16)],
1267 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1268 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001269 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001270 fib.cache_clear()
1271 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001272 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1273
1274 def test_lru_with_maxsize_negative(self):
1275 @self.module.lru_cache(maxsize=-10)
1276 def eq(n):
1277 return n
1278 for i in (0, 1):
1279 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1280 self.assertEqual(eq.cache_info(),
1281 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001282
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001283 def test_lru_with_exceptions(self):
1284 # Verify that user_function exceptions get passed through without
1285 # creating a hard-to-read chained exception.
1286 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001287 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001288 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001289 def func(i):
1290 return 'abc'[i]
1291 self.assertEqual(func(0), 'a')
1292 with self.assertRaises(IndexError) as cm:
1293 func(15)
1294 self.assertIsNone(cm.exception.__context__)
1295 # Verify that the previous exception did not result in a cached entry
1296 with self.assertRaises(IndexError):
1297 func(15)
1298
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001299 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001300 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001301 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001302 def square(x):
1303 return x * x
1304 self.assertEqual(square(3), 9)
1305 self.assertEqual(type(square(3)), type(9))
1306 self.assertEqual(square(3.0), 9.0)
1307 self.assertEqual(type(square(3.0)), type(9.0))
1308 self.assertEqual(square(x=3), 9)
1309 self.assertEqual(type(square(x=3)), type(9))
1310 self.assertEqual(square(x=3.0), 9.0)
1311 self.assertEqual(type(square(x=3.0)), type(9.0))
1312 self.assertEqual(square.cache_info().hits, 4)
1313 self.assertEqual(square.cache_info().misses, 4)
1314
Antoine Pitroub5b37142012-11-13 21:35:40 +01001315 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001316 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001317 def fib(n):
1318 if n < 2:
1319 return n
1320 return fib(n=n-1) + fib(n=n-2)
1321 self.assertEqual(
1322 [fib(n=number) for number in range(16)],
1323 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1324 )
1325 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001326 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001327 fib.cache_clear()
1328 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001329 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001330
1331 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001332 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001333 def fib(n):
1334 if n < 2:
1335 return n
1336 return fib(n=n-1) + fib(n=n-2)
1337 self.assertEqual([fib(n=number) for number in range(16)],
1338 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1339 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001340 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001341 fib.cache_clear()
1342 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001343 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1344
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001345 def test_kwargs_order(self):
1346 # PEP 468: Preserving Keyword Argument Order
1347 @self.module.lru_cache(maxsize=10)
1348 def f(**kwargs):
1349 return list(kwargs.items())
1350 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1351 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1352 self.assertEqual(f.cache_info(),
1353 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1354
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001355 def test_lru_cache_decoration(self):
1356 def f(zomg: 'zomg_annotation'):
1357 """f doc string"""
1358 return 42
1359 g = self.module.lru_cache()(f)
1360 for attr in self.module.WRAPPER_ASSIGNMENTS:
1361 self.assertEqual(getattr(g, attr), getattr(f, attr))
1362
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    # Stress one lru_cache from several threads at once.
    # Phase 1: n threads each fill the cache with their own key.
    # Phase 2: the same fillers race a thread that repeatedly clears it.
    n, m = 5, 11
    def orig(x, y):
        return 3 * x + y
    # maxsize=n*m is far larger than the n distinct keys used, so nothing
    # is evicted in this test.
    f = self.module.lru_cache(maxsize=n*m)(orig)
    hits, misses, maxsize, currsize = f.cache_info()
    self.assertEqual(currsize, 0)

    # Every worker blocks on this event (10 s timeout) so they start together.
    start = threading.Event()
    def full(k):
        # m calls with the same key (k, 0): expected one miss, then hits.
        # NOTE(review): an assertion failing here is raised in the worker
        # thread, not the test's thread; confirm support.start_threads
        # surfaces it.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    # Request a tiny thread switch interval to force frequent preemption;
    # the previous value is restored in the finally clause below.
    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with support.start_threads(threads):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: the pure-Python implementation may report fewer misses
            # and hits than expected here; the cause is not understood yet.
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with support.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001411
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    # n workers plus the main thread meet at each barrier (n+1 parties).
    # In round i every worker calls f(i).  f blocks on `pause` until all n
    # workers (and the main thread) reach it, so all n calls are misses
    # and the cache gains exactly one entry per round.
    n, m = 5, 7
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)
    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
    def test():
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)
    threads = [threading.Thread(target=test) for k in range(n)]
    with support.start_threads(threads):
        # The main thread drives each round through the three barriers,
        # resetting them between phases.  After round i it checks the
        # totals: no hits, n misses per round so far, i+1 cached entries.
        for i in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1439
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded3(self):
    # Five threads share a two-entry cache over three distinct keys, so
    # concurrent misses and evictions can interleave.  Each thread only
    # checks that it receives the correct value.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def check(index, value):
        with self.subTest(thread=index):
            self.assertEqual(f(value), 3 * value, index)

    keys = [1, 2, 2, 3, 2]
    threads = [threading.Thread(target=check, args=pair)
               for pair in enumerate(keys)]
    with support.start_threads(threads):
        pass
1453
Raymond Hettinger03923422013-03-04 02:52:50 -05001454 def test_need_for_rlock(self):
1455 # This will deadlock on an LRU cache that uses a regular lock
1456
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001457 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001458 def test_func(x):
1459 'Used to demonstrate a reentrant lru_cache call within a single thread'
1460 return x
1461
1462 class DoubleEq:
1463 'Demonstrate a reentrant lru_cache call within a single thread'
1464 def __init__(self, x):
1465 self.x = x
1466 def __hash__(self):
1467 return self.x
1468 def __eq__(self, other):
1469 if self.x == 2:
1470 test_func(DoubleEq(1))
1471 return self.x == other.x
1472
1473 test_func(DoubleEq(1)) # Load the cache
1474 test_func(DoubleEq(2)) # Load the cache
1475 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1476 DoubleEq(2)) # Verify the correct return value
1477
Raymond Hettinger4d588972014-08-12 12:44:52 -07001478 def test_early_detection_of_bad_call(self):
1479 # Issue #22184
1480 with self.assertRaises(TypeError):
1481 @functools.lru_cache
1482 def f():
1483 pass
1484
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001485 def test_lru_method(self):
1486 class X(int):
1487 f_cnt = 0
1488 @self.module.lru_cache(2)
1489 def f(self, x):
1490 self.f_cnt += 1
1491 return x*10+self
1492 a = X(5)
1493 b = X(5)
1494 c = X(7)
1495 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1496
1497 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1498 self.assertEqual(a.f(x), x*10 + 5)
1499 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1500 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1501
1502 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1503 self.assertEqual(b.f(x), x*10 + 5)
1504 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1505 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1506
1507 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1508 self.assertEqual(c.f(x), x*10 + 7)
1509 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1510 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1511
1512 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1513 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1514 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1515
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001516 def test_pickle(self):
1517 cls = self.__class__
1518 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1519 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1520 with self.subTest(proto=proto, func=f):
1521 f_copy = pickle.loads(pickle.dumps(f, proto))
1522 self.assertIs(f_copy, f)
1523
1524 def test_copy(self):
1525 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001526 def orig(x, y):
1527 return 3 * x + y
1528 part = self.module.partial(orig, 2)
1529 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1530 self.module.lru_cache(2)(part))
1531 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001532 with self.subTest(func=f):
1533 f_copy = copy.copy(f)
1534 self.assertIs(f_copy, f)
1535
1536 def test_deepcopy(self):
1537 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001538 def orig(x, y):
1539 return 3 * x + y
1540 part = self.module.partial(orig, 2)
1541 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1542 self.module.lru_cache(2)(part))
1543 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001544 with self.subTest(func=f):
1545 f_copy = copy.deepcopy(f)
1546 self.assertIs(f_copy, f)
1547
1548
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y, cached by py_functools.lru_cache (used by TestLRUPy)."""
    tripled = 3 * x
    return tripled + y
1552
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y, cached by c_functools.lru_cache (used by TestLRUC)."""
    tripled = 3 * x
    return tripled + y
1556
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001557
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the TestLRU tests with ``module`` set to py_functools."""
    module = py_functools
    # One-element tuple: test_pickle/test_copy read cached_func[0].
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1570
1571
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the TestLRU tests with ``module`` set to c_functools."""
    module = c_functools
    # One-element tuple: test_pickle/test_copy read cached_func[0].
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001584
Raymond Hettinger03923422013-03-04 02:52:50 -05001585
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and its private MRO helpers.

    Several tests register local classes with collections ABCs; ABC
    registrations cannot be undone, so only throwaway classes are used.
    """

    def test_simple_overloads(self):
        # register(cls, func) routes only instances of cls to func; every
        # other type falls back to the base implementation.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch follows the argument type's MRO.  D's MRO is
        # D, C, B, A, object, so D() picks the B implementation, while C()
        # (MRO C, A, object) falls through to A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) used as a decorator binds the decorated function
        # itself as the implementation for cls.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The singledispatch wrapper carries the decorated function's
        # __name__ and __doc__.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        if sys.flags.optimize < 2:
            # -OO strips docstrings, so __doc__ is only checked otherwise.
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes of the _decimal-backed decimal
        # module; a later registration for the more specific Subnormal
        # class takes precedence over DecimalException for Subnormal only.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        # Each loop therefore passes the current permutation to
        # _compose_mro, so every ordering is actually exercised.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Implementations registered for ABCs apply to builtin containers;
        # as more specific ABCs and then concrete types are registered, each
        # object moves to the most specific matching implementation.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro inserts each requested ABC right after the class that
        # introduces it (X for Callable, C for Container, B for Sized),
        # independent of the order the ABCs are given in.
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # A metaclass defining __len__ must not break dispatch for
        # instances of classes built with it.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # Explicit bases win over ABCs.  When two registered ABCs match
        # equally well (neither is more specific), dispatch raises
        # RuntimeError("Ambiguous dispatch: ...").
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Observe singledispatch's dispatch cache by replacing
        # functools.WeakKeyDictionary with a factory that returns a
        # TracingDict recording every read and write.  This relies on
        # singledispatch building its cache through
        # functools.WeakKeyDictionary.
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        # Restore the real factory even if an assertion below fails, so the
        # monkeypatch cannot leak into tests that run afterwards.
        try:
            c = collections
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
        finally:
            functools.WeakKeyDictionary = _orig_wkd
2080
2081
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()