blob: 612ca17a60bc8cd90cbea0f32f059151a3c5e600 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080012import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020013from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100014import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030015try:
16 import threading
17except ImportError:
18 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000019
Antoine Pitroub5b37142012-11-13 21:35:40 +010020import functools
21
# Two independently imported copies of functools: one with the _functools
# accelerator blocked (pure-Python implementation) and one that freshly
# imports the _functools accelerator.  c_functools can be falsy when the
# accelerator is unavailable, which is why the C test classes are guarded
# with skipUnless(c_functools, ...) / "if c_functools:".
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh decimal module with the _decimal accelerator; presumably used by
# tests later in the file (not referenced in this section) -- TODO confirm.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
26
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered under *name* on entry is put back on
    exit, whether the ``with`` body finishes normally or raises.  Nothing
    is bound by ``with ... as``.
    """
    previous = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the with-body raised.
        sys.modules[name] = previous
Łukasz Langa6f692512013-06-05 12:20:24 +020035
def capture(*args, **kw):
    """Echo back how this function was called.

    Returns a pair ``(args, kw)``: the tuple of positional arguments and
    the dict of keyword arguments.
    """
    captured = (args, kw)
    return captured
39
Łukasz Langa6f692512013-06-05 12:20:24 +020040
def signature(part):
    """Describe partial object *part* as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000044
class MyTuple(tuple):
    """Trivial ``tuple`` subclass; the ``__setstate__`` tests check that
    partial stores such an argument tuple as a plain ``tuple``."""
47
class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` yields a ``list`` rather than a tuple.

    Used by the ``__setstate__`` tests, which expect partial to keep
    storing and passing plain tuples regardless.
    """

    def __add__(self, other):
        # Concatenate into a list on purpose: a misbehaving __add__.
        return [*self, *other]
51
class MyDict(dict):
    """Trivial ``dict`` subclass; the ``__setstate__`` tests check that
    partial stores such a keywords mapping as a plain ``dict``."""
54
Łukasz Langa6f692512013-06-05 12:20:24 +020055
class TestPartial:
    """Tests shared by every ``partial`` implementation under test.

    This is a mixin rather than a ``TestCase``: concrete subclasses combine
    it with ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial type being exercised;
    * ``AllowPickle`` -- a context manager entered around pickling tests.
    """

    def _repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``.

        The stock C and pure-Python classes both report themselves as
        ``functools.partial``; any other type (e.g. a subclass) is expected
        to use its own ``__name__``.
        """
        # c_functools is falsy when the C accelerator is unavailable, so only
        # consult the implementations that actually loaded.
        stock = [module.partial
                 for module in (c_functools, py_functools) if module]
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The keyword order in the repr is not pinned down; accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            # Break the self-reference set up above.
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this signature() check was left disabled upstream;
        # confirm both implementations report keywords == {} for a None
        # keywords state before re-enabling it.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000379
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing is patched for the C partial."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, exc_tb):
            # Returning False lets any exception from the body propagate.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # ... and the instance __dict__ must not be deletable
        p = self.partial(hex)
        with self.assertRaises(
                TypeError, msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200405
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """While active, ``sys.modules['functools']`` is the pure-Python copy.

        Presumably needed because pickle looks classes up by module and
        name, and ``functools.partial`` would otherwise resolve to the C
        type -- TODO confirm.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc_value, exc_tb):
            return self._swap.__exit__(exc_type, exc_value, exc_tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200416
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Plain subclass of the C ``partial`` type, with no overrides."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100420
class PyPartialSubclass(py_functools.partial):
    """Plain subclass of the pure-Python ``partial``, with no overrides."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200423
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a trivial subclass of it."""

    if c_functools:
        partial = CPartialSubclass

    # Nested calls are not flattened for partial subclasses, so the
    # inherited optimisation test does not apply and is disabled here.
    test_nested_optimization = None
431
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a trivial subclass."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200434
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be an abstract *class* built by ABCMeta.  Subclassing
        # abc.ABCMeta instead would make Abstract a metaclass itself and
        # leave the ABC machinery completely unexercised.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must register both names as abstract, blocking instantiation.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
547
548
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    ``_default_update`` builds the standard wrapper/wrapped pair used by
    the default-update tests; subclasses may override it.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        * every attribute named in *assigned* is the identical object on both;
        * every mapping named in *updated* on *wrapper* holds each entry of
          the same mapping on *wrapped*, except ``__dict__['__wrapped__']``,
          which the update code overwrites;
        * ``wrapper.__wrapped__`` is *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            checked_keys = (
                key for key in source
                if not (attr_name == "__dict__" and key == "__wrapped__"))
            for key in checked_keys:
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Apply update_wrapper with default settings; return (wrapper, f)."""
        def f(a: 'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        functools.update_wrapper(wrapper, f, assigned=(), updated=())
        self.check_wrapper(wrapper, f, assigned=(), updated=())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assigned = ('attr',)
        updated = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned=assigned, updated=updated)
        self.check_wrapper(wrapper, f, assigned=assigned, updated=updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assigned = ('attr',)
        updated = ('dict_attr',)
        # Attributes missing on the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must already have every attribute to update...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)
        # ...and a non-mapping value there is rejected as well.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000661
Łukasz Langa6f692512013-06-05 12:20:24 +0200662
class TestWraps(TestUpdateWrapper):
    # Re-runs the TestUpdateWrapper scenarios through the functools.wraps
    # decorator instead of calling functools.update_wrapper directly.

    def _default_update(self):
        # Decorate a fresh wrapper with wraps(f) using the default
        # assigned/updated tuples; returns (wrapper, f).
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # A bogus __wrapped__ on f; presumably check_wrapper verifies that
        # wrapper.__wrapped__ ends up pointing at f itself -- TODO confirm.
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expected = {
            '__name__': 'f',
            '__qualname__': f.__qualname__,
            'attr': 'This is also a test',
        }
        for attr_name, value in expected.items():
            self.assertEqual(getattr(wrapper, attr_name), value)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapped_pair = self._default_update()
        self.assertEqual(wrapped_pair[0].__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: nothing may be copied from f.
        nothing = ()

        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # Only 'attr' is assigned and only 'dict_attr' is merged; the
        # wrapper's name, qualname and docstring stay its own.
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
723
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # Tests for the C accelerator's reduce(); the attribute is only bound
    # when _functools could be imported (otherwise the class is skipped).
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily materialised sequence of the squares 0, 1, 4, ...
            # below index *limit*.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                for k in range(len(self.sofar), i + 1):
                    self.sofar.append(k * k)
                return self.sofar[i]

        def add(x, y):
            return x + y

        # Successful reductions: (positional args, expected result).
        expectations = [
            ((add, ['a', 'b', 'c'], ''), 'abc'),
            ((add, [['a', 'c'], [], ['d', 'w']], []), ['a', 'c', 'd', 'w']),
            ((lambda x, y: x*y, range(2, 8), 1), 5040),
            ((lambda x, y: x*y, range(2, 21), 1), 2432902008176640000),
            ((add, Squares(10)), 285),
            ((add, Squares(10), 0), 285),
            ((add, Squares(0), 0), 0),
            # func is never called with one item
            ((42, "1"), "1"),
            ((42, "", "1"), "1"),
            # an empty iterable just yields the initial value
            ((add, [], None), None),
            ((add, [], 42), 42),
        ]
        for args, expected in expectations:
            self.assertEqual(self.func(*args), expected)

        # Wrong arity, non-callable func, and an empty sequence with no
        # initial value must all raise TypeError.
        bad_calls = [
            (),
            (42, 42),
            (42, 42, 42),
            (42, (42, 42)),
            (add, []),
            (add, ""),
            (add, ()),
            (add, object()),
        ]
        for args in bad_calls:
            self.assertRaises(TypeError, self.func, *args)

        # Exceptions raised while iterating propagate unchanged.
        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, FailingIter())

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: index i maps to i for 0 <= i < n.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add as plus
        self.assertEqual(self.func(plus, SequenceClass(5)), 10)
        self.assertEqual(self.func(plus, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, plus, SequenceClass(0))
        self.assertEqual(self.func(plus, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(plus, SequenceClass(1)), 0)
        self.assertEqual(self.func(plus, SequenceClass(1), 42), 42)

        # Reducing a dict iterates over (and here concatenates) its keys.
        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(plus, words), "".join(words))
805
Łukasz Langa6f692512013-06-05 12:20:24 +0200806
class TestCmpToKey:
    # Implementation-agnostic cmp_to_key tests.  Concrete subclasses mix in
    # unittest.TestCase and set ``cmp_to_key`` to the C or pure-Python
    # implementation under test.

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1  # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)  # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()  # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)  # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()  # too few args
        with self.assertRaises(TypeError):
            key(None, None)  # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against zero also propagates.  The key must be rebuilt
        # around the new function (previously the stale key was reused, so
        # this half silently re-tested cmp1).  '<' is used because BadCmp
        # only defines __lt__; with '>' Python would raise TypeError first.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # collections.Hashable was only a deprecated alias (removed in
        # Python 3.10); the ABC lives in collections.abc.
        from collections.abc import Hashable

        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000888
Łukasz Langa6f692512013-06-05 12:20:24 +0200889
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Run the shared cmp_to_key tests against the C accelerator.  The
    # attribute is bound only when _functools was importable; otherwise
    # the skip decorator keeps the tests from running.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100894
Łukasz Langa6f692512013-06-05 12:20:24 +0200895
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Run the shared cmp_to_key tests against the pure-Python implementation
    # (functools imported with _functools blocked).  staticmethod stops the
    # plain Python function from being bound as a method of the test case.
    cmp_to_key = staticmethod(
        py_functools.cmp_to_key
    )
898
Łukasz Langa6f692512013-06-05 12:20:24 +0200899
class TestTotalOrdering(unittest.TestCase):
    # Exercises functools.total_ordering with each possible root comparison,
    # NotImplemented propagation, and pickling of the generated methods.

    def _check_derived_ordering(self, cls):
        # Comparisons that must all hold once the ordering is completed,
        # whichever root method *cls* defines.  Fresh instances are built
        # for every comparison (including the reflexive <= and >= checks).
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Mixed-type comparisons: every one must end in TypeError rather
        # than infinite recursion.
        mismatched = [
            ("LT < 1", lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal instances whose root returns NotImplemented: equality holds,
        # yet the derived >= / <= must still raise TypeError.
        equal_cases = [
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        ]
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Both the user-defined root and the generated methods must
        # round-trip through pickle to the very same object.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(original, proto))
                    self.assertIs(restored, original)
1102
@functools.total_ordering
class Orderable_LT:
    # Module-level helper for TestTotalOrdering.test_pickle: its only
    # ordering root is __lt__, so total_ordering supplies __gt__, __le__
    # and __ge__.  Comparison is by the wrapped ``value``.
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1111
1112
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001113class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001114
1115 def test_lru(self):
1116 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001117 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001118 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001119 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001120 self.assertEqual(maxsize, 20)
1121 self.assertEqual(currsize, 0)
1122 self.assertEqual(hits, 0)
1123 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001124
1125 domain = range(5)
1126 for i in range(1000):
1127 x, y = choice(domain), choice(domain)
1128 actual = f(x, y)
1129 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001130 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001131 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001132 self.assertTrue(hits > misses)
1133 self.assertEqual(hits + misses, 1000)
1134 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001135
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001136 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001137 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001138 self.assertEqual(hits, 0)
1139 self.assertEqual(misses, 0)
1140 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001141 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001142 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001143 self.assertEqual(hits, 0)
1144 self.assertEqual(misses, 1)
1145 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001146
Nick Coghlan98876832010-08-17 06:17:18 +00001147 # Test bypassing the cache
1148 self.assertIs(f.__wrapped__, orig)
1149 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001150 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001151 self.assertEqual(hits, 0)
1152 self.assertEqual(misses, 1)
1153 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001154
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001155 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001156 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001157 def f():
1158 nonlocal f_cnt
1159 f_cnt += 1
1160 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001161 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001162 f_cnt = 0
1163 for i in range(5):
1164 self.assertEqual(f(), 20)
1165 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001166 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001167 self.assertEqual(hits, 0)
1168 self.assertEqual(misses, 5)
1169 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001170
1171 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001172 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001173 def f():
1174 nonlocal f_cnt
1175 f_cnt += 1
1176 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001177 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001178 f_cnt = 0
1179 for i in range(5):
1180 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001181 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001182 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001183 self.assertEqual(hits, 4)
1184 self.assertEqual(misses, 1)
1185 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001186
Raymond Hettingerf3098282010-08-15 03:30:45 +00001187 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001188 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001189 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001190 nonlocal f_cnt
1191 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001192 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001194 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001195 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1196 # * * * *
1197 self.assertEqual(f(x), x*10)
1198 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001199 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001200 self.assertEqual(hits, 12)
1201 self.assertEqual(misses, 4)
1202 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001203
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001204 def test_lru_hash_only_once(self):
1205 # To protect against weird reentrancy bugs and to improve
1206 # efficiency when faced with slow __hash__ methods, the
1207 # LRU cache guarantees that it will only call __hash__
1208 # only once per use as an argument to the cached function.
1209
1210 @self.module.lru_cache(maxsize=1)
1211 def f(x, y):
1212 return x * 3 + y
1213
1214 # Simulate the integer 5
1215 mock_int = unittest.mock.Mock()
1216 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1217 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1218
1219 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001220 self.assertEqual(f(mock_int, 1), 16)
1221 self.assertEqual(mock_int.__hash__.call_count, 1)
1222 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001223
1224 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001225 self.assertEqual(f(mock_int, 1), 16)
1226 self.assertEqual(mock_int.__hash__.call_count, 2)
1227 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001228
1229 # Cache eviction: No use as an argument gives no additonal call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001230 self.assertEqual(f(6, 2), 20)
1231 self.assertEqual(mock_int.__hash__.call_count, 2)
1232 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001233
1234 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001235 self.assertEqual(f(mock_int, 1), 16)
1236 self.assertEqual(mock_int.__hash__.call_count, 3)
1237 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001238
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001239 def test_lru_reentrancy_with_len(self):
1240 # Test to make sure the LRU cache code isn't thrown-off by
1241 # caching the built-in len() function. Since len() can be
1242 # cached, we shouldn't use it inside the lru code itself.
1243 old_len = builtins.len
1244 try:
1245 builtins.len = self.module.lru_cache(4)(len)
1246 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1247 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1248 finally:
1249 builtins.len = old_len
1250
Raymond Hettinger605a4472017-01-09 07:50:19 -08001251 def test_lru_star_arg_handling(self):
1252 # Test regression that arose in ea064ff3c10f
1253 @functools.lru_cache()
1254 def f(*args):
1255 return args
1256
1257 self.assertEqual(f(1, 2), (1, 2))
1258 self.assertEqual(f((1, 2)), ((1, 2),))
1259
Yury Selivanov46a02db2016-11-09 18:55:45 -05001260 def test_lru_type_error(self):
1261 # Regression test for issue #28653.
1262 # lru_cache was leaking when one of the arguments
1263 # wasn't cacheable.
1264
1265 @functools.lru_cache(maxsize=None)
1266 def infinite_cache(o):
1267 pass
1268
1269 @functools.lru_cache(maxsize=10)
1270 def limited_cache(o):
1271 pass
1272
1273 with self.assertRaises(TypeError):
1274 infinite_cache([])
1275
1276 with self.assertRaises(TypeError):
1277 limited_cache([])
1278
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001279 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001280 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001281 def fib(n):
1282 if n < 2:
1283 return n
1284 return fib(n-1) + fib(n-2)
1285 self.assertEqual([fib(n) for n in range(16)],
1286 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1287 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001288 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001289 fib.cache_clear()
1290 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001291 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1292
1293 def test_lru_with_maxsize_negative(self):
1294 @self.module.lru_cache(maxsize=-10)
1295 def eq(n):
1296 return n
1297 for i in (0, 1):
1298 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1299 self.assertEqual(eq.cache_info(),
1300 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001301
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001302 def test_lru_with_exceptions(self):
1303 # Verify that user_function exceptions get passed through without
1304 # creating a hard-to-read chained exception.
1305 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001306 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001307 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001308 def func(i):
1309 return 'abc'[i]
1310 self.assertEqual(func(0), 'a')
1311 with self.assertRaises(IndexError) as cm:
1312 func(15)
1313 self.assertIsNone(cm.exception.__context__)
1314 # Verify that the previous exception did not result in a cached entry
1315 with self.assertRaises(IndexError):
1316 func(15)
1317
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001318 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001319 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001320 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001321 def square(x):
1322 return x * x
1323 self.assertEqual(square(3), 9)
1324 self.assertEqual(type(square(3)), type(9))
1325 self.assertEqual(square(3.0), 9.0)
1326 self.assertEqual(type(square(3.0)), type(9.0))
1327 self.assertEqual(square(x=3), 9)
1328 self.assertEqual(type(square(x=3)), type(9))
1329 self.assertEqual(square(x=3.0), 9.0)
1330 self.assertEqual(type(square(x=3.0)), type(9.0))
1331 self.assertEqual(square.cache_info().hits, 4)
1332 self.assertEqual(square.cache_info().misses, 4)
1333
Antoine Pitroub5b37142012-11-13 21:35:40 +01001334 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001335 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001336 def fib(n):
1337 if n < 2:
1338 return n
1339 return fib(n=n-1) + fib(n=n-2)
1340 self.assertEqual(
1341 [fib(n=number) for number in range(16)],
1342 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1343 )
1344 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001345 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001346 fib.cache_clear()
1347 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001348 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001349
1350 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001351 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001352 def fib(n):
1353 if n < 2:
1354 return n
1355 return fib(n=n-1) + fib(n=n-2)
1356 self.assertEqual([fib(n=number) for number in range(16)],
1357 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1358 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001359 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001360 fib.cache_clear()
1361 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001362 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1363
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001364 def test_kwargs_order(self):
1365 # PEP 468: Preserving Keyword Argument Order
1366 @self.module.lru_cache(maxsize=10)
1367 def f(**kwargs):
1368 return list(kwargs.items())
1369 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1370 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1371 self.assertEqual(f.cache_info(),
1372 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1373
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001374 def test_lru_cache_decoration(self):
1375 def f(zomg: 'zomg_annotation'):
1376 """f doc string"""
1377 return 42
1378 g = self.module.lru_cache()(f)
1379 for attr in self.module.WRAPPER_ASSIGNMENTS:
1380 self.assertEqual(getattr(g, attr), getattr(f, attr))
1381
    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        # Stress one lru_cache from several threads: first only readers that
        # fill it, then the same readers racing a thread that keeps clearing it.
        # Each of the n reader threads uses its own key and calls f m times.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # maxsize=n*m is large enough to hold all n keys, so nothing is evicted.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # start.set() below releases all workers at once so they run
        # concurrently.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval makes thread switches (and thus races inside
        # the cache wrapper) much more frequent; it is restored in `finally`.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can these counts differ from the exact values?
                # For the pure-Python implementation only upper bounds are
                # checked.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # One miss per distinct key; every other call is a hit.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            # NOTE(review): this phase has no assertions in the main thread;
            # an AssertionError raised inside a worker thread is not reported
            # through this test's result unless support.start_threads
            # surfaces it -- confirm.
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001430
    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # n worker threads call f(i) at the same time, for i in range(m).
        n, m = 5, 7
        # Each barrier has n+1 parties: the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Block inside the cached call until the main thread also reaches
            # `pause`, so all n workers are inside f(i) simultaneously.
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n concurrent callers with the same argument count as
                # misses (no hits), yet only one entry per argument is stored.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1458
Serhiy Storchaka67796522017-01-12 18:34:33 +02001459 @unittest.skipUnless(threading, 'This test requires threading.')
1460 def test_lru_cache_threaded3(self):
1461 @self.module.lru_cache(maxsize=2)
1462 def f(x):
1463 time.sleep(.01)
1464 return 3 * x
1465 def test(i, x):
1466 with self.subTest(thread=i):
1467 self.assertEqual(f(x), 3 * x, i)
1468 threads = [threading.Thread(target=test, args=(i, v))
1469 for i, v in enumerate([1, 2, 2, 3, 2])]
1470 with support.start_threads(threads):
1471 pass
1472
Raymond Hettinger03923422013-03-04 02:52:50 -05001473 def test_need_for_rlock(self):
1474 # This will deadlock on an LRU cache that uses a regular lock
1475
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001476 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001477 def test_func(x):
1478 'Used to demonstrate a reentrant lru_cache call within a single thread'
1479 return x
1480
1481 class DoubleEq:
1482 'Demonstrate a reentrant lru_cache call within a single thread'
1483 def __init__(self, x):
1484 self.x = x
1485 def __hash__(self):
1486 return self.x
1487 def __eq__(self, other):
1488 if self.x == 2:
1489 test_func(DoubleEq(1))
1490 return self.x == other.x
1491
1492 test_func(DoubleEq(1)) # Load the cache
1493 test_func(DoubleEq(2)) # Load the cache
1494 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1495 DoubleEq(2)) # Verify the correct return value
1496
Raymond Hettinger4d588972014-08-12 12:44:52 -07001497 def test_early_detection_of_bad_call(self):
1498 # Issue #22184
1499 with self.assertRaises(TypeError):
1500 @functools.lru_cache
1501 def f():
1502 pass
1503
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001504 def test_lru_method(self):
1505 class X(int):
1506 f_cnt = 0
1507 @self.module.lru_cache(2)
1508 def f(self, x):
1509 self.f_cnt += 1
1510 return x*10+self
1511 a = X(5)
1512 b = X(5)
1513 c = X(7)
1514 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1515
1516 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1517 self.assertEqual(a.f(x), x*10 + 5)
1518 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1519 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1520
1521 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1522 self.assertEqual(b.f(x), x*10 + 5)
1523 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1524 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1525
1526 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1527 self.assertEqual(c.f(x), x*10 + 7)
1528 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1529 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1530
1531 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1532 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1533 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1534
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001535 def test_pickle(self):
1536 cls = self.__class__
1537 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1538 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1539 with self.subTest(proto=proto, func=f):
1540 f_copy = pickle.loads(pickle.dumps(f, proto))
1541 self.assertIs(f_copy, f)
1542
1543 def test_copy(self):
1544 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001545 def orig(x, y):
1546 return 3 * x + y
1547 part = self.module.partial(orig, 2)
1548 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1549 self.module.lru_cache(2)(part))
1550 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001551 with self.subTest(func=f):
1552 f_copy = copy.copy(f)
1553 self.assertIs(f_copy, f)
1554
1555 def test_deepcopy(self):
1556 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001557 def orig(x, y):
1558 return 3 * x + y
1559 part = self.module.partial(orig, 2)
1560 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1561 self.module.lru_cache(2)(part))
1562 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001563 with self.subTest(func=f):
1564 f_copy = copy.deepcopy(f)
1565 self.assertIs(f_copy, f)
1566
1567
@py_functools.lru_cache()
def py_cached_func(x, y):
    # Module-level fixture for the pickle/copy tests (pure-Python cache);
    # pickle locates it by its module-level name.
    scaled = 3 * x
    return scaled + y
1571
@c_functools.lru_cache()
def c_cached_func(x, y):
    # Module-level fixture for the pickle/copy tests (C cache); pickle
    # locates it by its module-level name.
    scaled = 3 * x
    return scaled + y
1575
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001576
class TestLRUPy(TestLRU, unittest.TestCase):
    # Run the shared TestLRU suite against the pure-Python implementation.
    module = py_functools
    # Stored in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up through the class (tuples aren't descriptors).
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1589
1590
class TestLRUC(TestLRU, unittest.TestCase):
    # Run the shared TestLRU suite against the C-accelerated implementation.
    module = c_functools
    # Stored in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up through the class (tuples aren't descriptors).
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001603
Raymond Hettinger03923422013-03-04 02:52:50 -05001604
class TestSingleDispatch(unittest.TestCase):
    # Tests for functools.singledispatch: registration, MRO-based dispatch
    # (including implicit ABC relationships) and dispatch-cache invalidation.
    # Abstract collection classes come from collections.abc; the aliases on
    # the top-level collections module are deprecated.

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering, so
        # every permutation of the candidate ABCs is actually passed in.
        from collections import abc as c
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping,
                 collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping,
                                 c.Mapping, c.Collection, c.Sized,
                                 c.Iterable, c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 collections.defaultdict, dict,
                                 c.MutableMapping, c.Mapping, c.Collection,
                                 c.Sized, c.Iterable, c.Container, object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict,
                                 c.Mapping, c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        from collections import abc as c
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        from collections import abc as c
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        from collections import abc as c
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        from collections import abc as c

        class TracingDict(UserDict):
            # Records every key read and written, so the test can observe
            # exactly when singledispatch consults or fills its cache.
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()

        td = TracingDict()
        # singledispatch creates its dispatch cache when it decorates g, so
        # the WeakKeyDictionary factory only needs replacing for that moment.
        # swap_attr restores the original even on failure, so the patch can
        # never leak into other tests.
        with support.swap_attr(functools, 'WeakKeyDictionary', lambda: td):
            @functools.singledispatch
            def g(arg):
                return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)   # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2098 functools.WeakKeyDictionary = _orig_wkd
2099
2100
if __name__ == '__main__':
    # Run every test case in this module when it is executed as a script.
    unittest.main(module='__main__')