blob: b7d648d0b15a553840c5aaf29cf42f1d1bb6d1be [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03004import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02005from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02008import sys
9from test import support
Serhiy Storchaka67796522017-01-12 18:34:33 +020010import time
Łukasz Langa6f692512013-06-05 12:20:24 +020011import unittest
12from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100013import contextlib
Serhiy Storchaka46c56112015-05-24 21:53:49 +030014try:
15 import threading
16except ImportError:
17 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# Import two independent copies of functools so both implementations can be
# tested side by side: one with the ``_functools`` accelerator blocked and
# one with ``_functools`` freshly re-imported.  Tests further down guard on
# ``c_functools`` being truthy before using it.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh copy of decimal with ``_decimal`` re-imported.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The module object that was registered under *name* on entry is put
    back when the ``with`` block exits, whether or not the block raised.

    Raises:
        KeyError: if *name* is not present in ``sys.modules`` on entry.
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, even if the body raised.
        sys.modules[name] = original_module
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """capture all positional and keyword arguments

    Returns a 2-tuple ``(args, kw)``: the positional arguments as a tuple
    and the keyword arguments as a dict.
    """
    return args, kw
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Return the signature of a partial object.

    The result is the 4-tuple ``(func, args, keywords, __dict__)`` read
    from *part*, which the tests use to compare two partial objects.
    """
    return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Trivial tuple subclass used to feed non-exact tuples to __setstate__."""
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` returns a list instead of a tuple."""

    def __add__(self, other):
        return list(self) + list(other)
50
class MyDict(dict):
    """Trivial dict subclass used to feed non-exact dicts to __setstate__."""
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Shared ``partial`` tests, mixed into concrete ``unittest.TestCase`` classes.

    Concrete subclasses are expected to provide:

    * ``partial`` -- the partial implementation under test;
    * ``AllowPickle`` -- a context-manager class entered around the
      pickling tests.
    """

    def _expected_repr_name(self):
        """Return the type name that ``repr()`` of ``self.partial`` objects uses.

        The stock partial classes report themselves as ``functools.partial``;
        anything else (e.g. a subclass) is expected to use its ``__name__``.
        """
        stock_partials = [py_functools.partial]
        # c_functools is falsy when the C accelerator is unavailable, so it
        # must be checked before its attribute is read.
        if c_functools:
            stock_partials.append(c_functools.partial)
        if self.partial in stock_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the offending values.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Rebind the only name that refers to the partial object.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then resets the state so the self-reference is removed again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the contained objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy shares none of the contained objects.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the next assertion is disabled in the original; the
        # reason is not visible in this file.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances must be stored as exact tuple / dict.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # partial whose func is itself: pickling must raise RecursionError
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # partial that contains itself in args
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # partial that contains itself in keywords
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``c_functools.partial``."""

    # Guarded so the class body can still be evaluated when c_functools is
    # falsy; the class is skipped in that case.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager used around the pickling tests."""

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Returning False lets any exception propagate.
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__
Łukasz Langa6f692512013-06-05 12:20:24 +0200404
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``py_functools.partial``."""

    partial = py_functools.partial

    class AllowPickle:
        """Context manager that installs ``py_functools`` as ``functools``.

        While active, ``sys.modules['functools']`` is ``py_functools``.
        Presumably this lets pickle resolve ``functools.partial`` to the
        class under test -- confirm against the pickling tests.
        """

        def __init__(self):
            self._cm = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._cm.__enter__()

        def __exit__(self, exc_type, exc_value, traceback):
            return self._cm.__exit__(exc_type, exc_value, traceback)
Łukasz Langa6f692512013-06-05 12:20:24 +0200415
# Only defined when c_functools is truthy.
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of ``c_functools.partial``."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100419
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of ``py_functools.partial``."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200422
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against ``CPartialSubclass``."""

    # CPartialSubclass only exists when c_functools is truthy.
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
430
class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against ``PyPartialSubclass``."""

    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200433
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod``.

    The nested class ``A`` declares one partialmethod per scenario and the
    shared instance ``a`` is used to call them.
    """

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod layered on another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered on a plain functools.partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod layered on other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # called through the class with an explicit instance
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static and cls must behave the same via the class and an instance
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # The TypeError is expected while the class body is being executed.
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # vars() is used to reach the raw descriptor stored on the class.
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # NOTE(review): Abstract subclasses abc.ABCMeta rather than using it
        # as a metaclass; kept exactly as in the original.
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
546
547
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the expected attributes of *wrapped*.

        *assigned* names attributes that must be the identical object on
        both; *updated* names mapping attributes whose entries from
        *wrapped* must all be present, by identity, on *wrapper*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Build a wrapped/wrapper pair using default arguments.

        Returns the tuple ``(wrapper, f)``.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus value; check_wrapper expects it to be replaced.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated tuples: nothing should be copied.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000660
Łukasz Langa6f692512013-06-05 12:20:24 +0200661
class TestWraps(TestUpdateWrapper):
    # Exercises the functools.wraps() decorator factory; check_wrapper()
    # comes from the TestUpdateWrapper base class.

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        # Empty "assigned" and "updated" tuples: nothing may be copied.
        def wrapper():
            pass
        wrapper = functools.wraps(f, (), ())(wrapper)

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Give the wrapper an empty dict for wraps() to update.
            func.dict_attr = {}
            return func

        assigned = ('attr',)
        updated = ('dict_attr',)

        def wrapper():
            pass
        # Same application order as stacking @wraps(...) on top of
        # @add_dict_attr: the inner decorator runs first.
        wrapper = functools.wraps(f, assigned, updated)(add_dict_attr(wrapper))

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
722
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # `func` is only bound when c_functools could be imported; when it
    # could not, the skipUnless() decorator above skips the whole class.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        reduce = self.func

        class Squares:
            # Sequence of the first `limit` squares, computed on demand
            # and remembered in `sofar`.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, index):
                if index < 0 or index >= self.limit:
                    raise IndexError
                for n in range(len(self.sofar), index + 1):
                    self.sofar.append(n * n)
                return self.sofar[index]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        nested = [['a', 'c'], [], ['d', 'w']]
        self.assertEqual(reduce(add, nested, []), ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Missing arguments, a non-iterable second argument, or a
        # non-callable that actually has to be called.
        for bad_args in ((), (42, 42), (42, 42, 42), (42, (42, 42))):
            with self.assertRaises(TypeError):
                reduce(*bad_args)

        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        # arg 2 must not be an empty sequence (or a non-iterable) when no
        # initial value is given
        for empty in ([], "", (), object()):
            with self.assertRaises(TypeError):
                reduce(add, empty)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        # With an empty iterable the initial value is returned unchanged.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        reduce = self.func

        class SequenceClass:
            # Old-style sequence (only __getitem__) yielding 0 .. n-1.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict walks its keys.
        numbers = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, numbers), "".join(numbers))
804
Łukasz Langa6f692512013-06-05 12:20:24 +0200805
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200806class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700807
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000808 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700809 def cmp1(x, y):
810 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100811 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700812 self.assertEqual(key(3), key(3))
813 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100814 self.assertGreaterEqual(key(3), key(3))
815
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700816 def cmp2(x, y):
817 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100818 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700819 self.assertEqual(key(4.0), key('4'))
820 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100821 self.assertLessEqual(key(2), key('35'))
822 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700823
824 def test_cmp_to_key_arguments(self):
825 def cmp1(x, y):
826 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100827 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700828 self.assertEqual(key(obj=3), key(obj=3))
829 self.assertGreater(key(obj=3), key(obj=1))
830 with self.assertRaises((TypeError, AttributeError)):
831 key(3) > 1 # rhs is not a K object
832 with self.assertRaises((TypeError, AttributeError)):
833 1 < key(3) # lhs is not a K object
834 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100835 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700836 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200837 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100838 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700839 with self.assertRaises(TypeError):
840 key() # too few args
841 with self.assertRaises(TypeError):
842 key(None, None) # too many args
843
844 def test_bad_cmp(self):
845 def cmp1(x, y):
846 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100847 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700848 with self.assertRaises(ZeroDivisionError):
849 key(3) > key(1)
850
851 class BadCmp:
852 def __lt__(self, other):
853 raise ZeroDivisionError
854 def cmp1(x, y):
855 return BadCmp()
856 with self.assertRaises(ZeroDivisionError):
857 key(3) > key(1)
858
859 def test_obj_field(self):
860 def cmp1(x, y):
861 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100862 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700863 self.assertEqual(key(50).obj, 50)
864
865 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000866 def mycmp(x, y):
867 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100868 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000869 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000870
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700871 def test_sort_int_str(self):
872 def mycmp(x, y):
873 x, y = int(x), int(y)
874 return (x > y) - (x < y)
875 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100876 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700877 self.assertEqual([int(value) for value in values],
878 [0, 1, 1, 2, 3, 4, 5, 7, 10])
879
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000880 def test_hash(self):
881 def mycmp(x, y):
882 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100883 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000884 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700885 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700886 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000887
Łukasz Langa6f692512013-06-05 12:20:24 +0200888
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared TestCmpToKey cases against c_functools.  The
    # attribute is only bound when that module could be imported; if it
    # could not, the decorator above skips the class.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100893
Łukasz Langa6f692512013-06-05 12:20:24 +0200894
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the shared TestCmpToKey cases against py_functools.
    # staticmethod() stops the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
897
Łukasz Langa6f692512013-06-05 12:20:24 +0200898
class TestTotalOrdering(unittest.TestCase):

    def _check_ordering(self, cls):
        # Assertions shared by the test_total_ordering_* cases: all four
        # operators must agree on 1 vs 2 and on 2 vs 2, whichever single
        # operator `cls` defines itself.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class OnlyLT:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __lt__(self, other):
                return self.value < other.value
        self._check_ordering(OnlyLT)
        self.assertFalse(OnlyLT(1) > OnlyLT(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class OnlyLE:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __le__(self, other):
                return self.value <= other.value
        self._check_ordering(OnlyLE)
        self.assertFalse(OnlyLE(1) >= OnlyLE(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class OnlyGT:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __gt__(self, other):
                return self.value > other.value
        self._check_ordering(OnlyGT)
        self.assertFalse(OnlyGT(2) < OnlyGT(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class OnlyGE:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return self.value == other.value
            def __ge__(self, other):
                return self.value >= other.value
        self._check_ordering(OnlyGE)
        self.assertFalse(OnlyGE(2) <= OnlyGE(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class IntSubclass(int):
            pass
        self._check_ordering(IntSubclass)

    def test_no_operations_defined(self):
        # A class without any ordering method is rejected.
        class Bare:
            pass
        self.assertRaises(ValueError, functools.total_ordering, Bare)

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThan)
                        and self.value == other.value)
            def __lt__(self, other):
                if not isinstance(other, ImplementsLessThan):
                    return NotImplemented
                return self.value < other.value

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThan)
                        and self.value == other.value)
            def __gt__(self, other):
                if not isinstance(other, ImplementsGreaterThan):
                    return NotImplemented
                return self.value > other.value

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsLessThanEqualTo)
                        and self.value == other.value)
            def __le__(self, other):
                if not isinstance(other, ImplementsLessThanEqualTo):
                    return NotImplemented
                return self.value <= other.value

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ImplementsGreaterThanEqualTo)
                        and self.value == other.value)
            def __ge__(self, other):
                if not isinstance(other, ImplementsGreaterThanEqualTo):
                    return NotImplemented
                return self.value >= other.value

        @functools.total_ordering
        class ComparatorNotImplemented:
            # Its only ordering method never gives an answer.
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return (isinstance(other, ComparatorNotImplemented)
                        and self.value == other.value)
            def __lt__(self, other):
                return NotImplemented

        # Each comparison mixes operands that do not know each other, so
        # every one of them has to end in TypeError.
        mismatched = (
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        )
        for label, compare in mismatched:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal operands still cannot be ordered when the one defined
        # ordering method returns NotImplemented.
        equal_cases = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for label, value, compare in equal_cases:
            with self.subTest(label):
                left = ComparatorNotImplemented(value)
                right = ComparatorNotImplemented(value)
                self.assertEqual(left, right)
                with self.assertRaises(TypeError):
                    compare(left, right)

    def test_pickle(self):
        # Every ordering method of the decorated class must survive a
        # pickle round trip as the very same object, for each protocol.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    payload = pickle.dumps(original, proto)
                    self.assertIs(pickle.loads(payload), original)
1101
@functools.total_ordering
class Orderable_LT:
    # Value wrapper that defines only __eq__ and __lt__; total_ordering
    # supplies the remaining comparisons.  Presumably kept at module level
    # so that test_pickle can pickle its methods by reference.
    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1110
1111
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001112class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001113
1114 def test_lru(self):
1115 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001116 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001117 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001118 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001119 self.assertEqual(maxsize, 20)
1120 self.assertEqual(currsize, 0)
1121 self.assertEqual(hits, 0)
1122 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001123
1124 domain = range(5)
1125 for i in range(1000):
1126 x, y = choice(domain), choice(domain)
1127 actual = f(x, y)
1128 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001129 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001130 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001131 self.assertTrue(hits > misses)
1132 self.assertEqual(hits + misses, 1000)
1133 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001134
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001135 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001136 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001137 self.assertEqual(hits, 0)
1138 self.assertEqual(misses, 0)
1139 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001140 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001141 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001142 self.assertEqual(hits, 0)
1143 self.assertEqual(misses, 1)
1144 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001145
Nick Coghlan98876832010-08-17 06:17:18 +00001146 # Test bypassing the cache
1147 self.assertIs(f.__wrapped__, orig)
1148 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001149 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001150 self.assertEqual(hits, 0)
1151 self.assertEqual(misses, 1)
1152 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001153
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001154 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001155 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001156 def f():
1157 nonlocal f_cnt
1158 f_cnt += 1
1159 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001160 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001161 f_cnt = 0
1162 for i in range(5):
1163 self.assertEqual(f(), 20)
1164 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001165 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001166 self.assertEqual(hits, 0)
1167 self.assertEqual(misses, 5)
1168 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001169
1170 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001171 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001172 def f():
1173 nonlocal f_cnt
1174 f_cnt += 1
1175 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001176 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001177 f_cnt = 0
1178 for i in range(5):
1179 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001180 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001181 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001182 self.assertEqual(hits, 4)
1183 self.assertEqual(misses, 1)
1184 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001185
Raymond Hettingerf3098282010-08-15 03:30:45 +00001186 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001187 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001188 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001189 nonlocal f_cnt
1190 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001191 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001192 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001193 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001194 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1195 # * * * *
1196 self.assertEqual(f(x), x*10)
1197 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001198 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001199 self.assertEqual(hits, 12)
1200 self.assertEqual(misses, 4)
1201 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001202
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001203 def test_lru_reentrancy_with_len(self):
1204 # Test to make sure the LRU cache code isn't thrown-off by
1205 # caching the built-in len() function. Since len() can be
1206 # cached, we shouldn't use it inside the lru code itself.
1207 old_len = builtins.len
1208 try:
1209 builtins.len = self.module.lru_cache(4)(len)
1210 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1211 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1212 finally:
1213 builtins.len = old_len
1214
Yury Selivanov46a02db2016-11-09 18:55:45 -05001215 def test_lru_type_error(self):
1216 # Regression test for issue #28653.
1217 # lru_cache was leaking when one of the arguments
1218 # wasn't cacheable.
1219
1220 @functools.lru_cache(maxsize=None)
1221 def infinite_cache(o):
1222 pass
1223
1224 @functools.lru_cache(maxsize=10)
1225 def limited_cache(o):
1226 pass
1227
1228 with self.assertRaises(TypeError):
1229 infinite_cache([])
1230
1231 with self.assertRaises(TypeError):
1232 limited_cache([])
1233
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001234 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001235 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001236 def fib(n):
1237 if n < 2:
1238 return n
1239 return fib(n-1) + fib(n-2)
1240 self.assertEqual([fib(n) for n in range(16)],
1241 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1242 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001243 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001244 fib.cache_clear()
1245 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001246 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1247
1248 def test_lru_with_maxsize_negative(self):
1249 @self.module.lru_cache(maxsize=-10)
1250 def eq(n):
1251 return n
1252 for i in (0, 1):
1253 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1254 self.assertEqual(eq.cache_info(),
1255 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001256
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001257 def test_lru_with_exceptions(self):
1258 # Verify that user_function exceptions get passed through without
1259 # creating a hard-to-read chained exception.
1260 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001261 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001262 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001263 def func(i):
1264 return 'abc'[i]
1265 self.assertEqual(func(0), 'a')
1266 with self.assertRaises(IndexError) as cm:
1267 func(15)
1268 self.assertIsNone(cm.exception.__context__)
1269 # Verify that the previous exception did not result in a cached entry
1270 with self.assertRaises(IndexError):
1271 func(15)
1272
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001273 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001274 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001275 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001276 def square(x):
1277 return x * x
1278 self.assertEqual(square(3), 9)
1279 self.assertEqual(type(square(3)), type(9))
1280 self.assertEqual(square(3.0), 9.0)
1281 self.assertEqual(type(square(3.0)), type(9.0))
1282 self.assertEqual(square(x=3), 9)
1283 self.assertEqual(type(square(x=3)), type(9))
1284 self.assertEqual(square(x=3.0), 9.0)
1285 self.assertEqual(type(square(x=3.0)), type(9.0))
1286 self.assertEqual(square.cache_info().hits, 4)
1287 self.assertEqual(square.cache_info().misses, 4)
1288
Antoine Pitroub5b37142012-11-13 21:35:40 +01001289 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001290 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001291 def fib(n):
1292 if n < 2:
1293 return n
1294 return fib(n=n-1) + fib(n=n-2)
1295 self.assertEqual(
1296 [fib(n=number) for number in range(16)],
1297 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1298 )
1299 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001300 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001301 fib.cache_clear()
1302 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001303 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001304
1305 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001306 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001307 def fib(n):
1308 if n < 2:
1309 return n
1310 return fib(n=n-1) + fib(n=n-2)
1311 self.assertEqual([fib(n=number) for number in range(16)],
1312 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1313 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001314 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001315 fib.cache_clear()
1316 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001317 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1318
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001319 def test_kwargs_order(self):
1320 # PEP 468: Preserving Keyword Argument Order
1321 @self.module.lru_cache(maxsize=10)
1322 def f(**kwargs):
1323 return list(kwargs.items())
1324 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1325 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1326 self.assertEqual(f.cache_info(),
1327 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1328
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001329 def test_lru_cache_decoration(self):
1330 def f(zomg: 'zomg_annotation'):
1331 """f doc string"""
1332 return 42
1333 g = self.module.lru_cache()(f)
1334 for attr in self.module.WRAPPER_ASSIGNMENTS:
1335 self.assertEqual(getattr(g, attr), getattr(f, attr))
1336
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded(self):
    # Call one cached function from several threads at once and check the
    # cache statistics afterwards; then repeat while an extra thread keeps
    # clearing the cache, to make sure nothing breaks under that race.
    n, m = 5, 11    # n worker threads, m calls per worker

    def orig(x, y):
        return 3 * x + y
    f = self.module.lru_cache(maxsize=n*m)(orig)
    # A freshly created cache holds no entries.
    self.assertEqual(f.cache_info().currsize, 0)

    # All workers block on this event so that they start together.
    start = threading.Event()

    def full(k):
        # Worker: call f() m times with a per-thread key.
        start.wait(10)
        for _ in range(m):
            self.assertEqual(f(k, 0), orig(k, 0))

    def clear():
        # Worker: repeatedly empty the cache while the others fill it.
        start.wait(10)
        for _ in range(2*m):
            f.cache_clear()

    # Use a tiny switch interval to provoke as many thread switches as
    # possible; the original value is restored in the finally clause.
    orig_si = sys.getswitchinterval()
    support.setswitchinterval(1e-6)
    try:
        # create n threads in order to fill cache
        threads = [threading.Thread(target=full, args=[k])
                   for k in range(n)]
        with support.start_threads(threads):
            start.set()

        hits, misses, maxsize, currsize = f.cache_info()
        if self.module is py_functools:
            # XXX: Why can these be not equal?
            self.assertLessEqual(misses, n)
            self.assertLessEqual(hits, m*n - misses)
        else:
            self.assertEqual(misses, n)
            self.assertEqual(hits, m*n - misses)
        self.assertEqual(currsize, n)

        # create n threads in order to fill cache and 1 to clear it
        threads = [threading.Thread(target=clear)]
        threads += [threading.Thread(target=full, args=[k])
                    for k in range(n)]
        start.clear()
        with support.start_threads(threads):
            start.set()
    finally:
        sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001385
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded2(self):
    # Simultaneous call with the same arguments
    #
    # n worker threads all call f(i) for the same i at the same moment.
    # Three barriers (each shared by the n workers plus this thread) keep
    # the rounds in lock step: `start` releases the callers, `pause` holds
    # every call inside f() until all of them have entered it, and `stop`
    # marks the end of a round.
    n, m = 5, 7     # n worker threads, m rounds
    start = threading.Barrier(n+1)
    pause = threading.Barrier(n+1)
    stop = threading.Barrier(n+1)

    @self.module.lru_cache(maxsize=m*n)
    def f(x):
        pause.wait(10)
        return 3 * x
    self.assertEqual(f.cache_info(), (0, 0, m*n, 0))

    def test():
        # Worker: one synchronized call per round.
        for i in range(m):
            start.wait(10)
            self.assertEqual(f(i), 3 * i)
            stop.wait(10)

    threads = [threading.Thread(target=test) for _ in range(n)]
    with support.start_threads(threads):
        for i in range(m):
            start.wait(10)
            stop.reset()
            pause.wait(10)
            start.reset()
            stop.wait(10)
            pause.reset()
            # Every worker missed (the calls overlapped), but only one
            # entry per distinct argument was stored.
            self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1413
@unittest.skipUnless(threading, 'This test requires threading.')
def test_lru_cache_threaded3(self):
    # Five threads call a slow function through a cache that holds only
    # two entries, with a repeated argument among the calls; every thread
    # must still get the right result.
    @self.module.lru_cache(maxsize=2)
    def f(x):
        time.sleep(.01)
        return 3 * x

    def test(i, x):
        # Worker: report a failure under the index of its own thread.
        with self.subTest(thread=i):
            self.assertEqual(f(x), 3 * x, i)

    threads = [threading.Thread(target=test, args=(i, v))
               for i, v in enumerate([1, 2, 2, 3, 2])]
    # The context manager is used only to start the threads and to wait
    # for them; there is nothing to do in between.
    with support.start_threads(threads):
        pass
1427
Raymond Hettinger03923422013-03-04 02:52:50 -05001428 def test_need_for_rlock(self):
1429 # This will deadlock on an LRU cache that uses a regular lock
1430
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001431 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001432 def test_func(x):
1433 'Used to demonstrate a reentrant lru_cache call within a single thread'
1434 return x
1435
1436 class DoubleEq:
1437 'Demonstrate a reentrant lru_cache call within a single thread'
1438 def __init__(self, x):
1439 self.x = x
1440 def __hash__(self):
1441 return self.x
1442 def __eq__(self, other):
1443 if self.x == 2:
1444 test_func(DoubleEq(1))
1445 return self.x == other.x
1446
1447 test_func(DoubleEq(1)) # Load the cache
1448 test_func(DoubleEq(2)) # Load the cache
1449 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1450 DoubleEq(2)) # Verify the correct return value
1451
Raymond Hettinger4d588972014-08-12 12:44:52 -07001452 def test_early_detection_of_bad_call(self):
1453 # Issue #22184
1454 with self.assertRaises(TypeError):
1455 @functools.lru_cache
1456 def f():
1457 pass
1458
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001459 def test_lru_method(self):
1460 class X(int):
1461 f_cnt = 0
1462 @self.module.lru_cache(2)
1463 def f(self, x):
1464 self.f_cnt += 1
1465 return x*10+self
1466 a = X(5)
1467 b = X(5)
1468 c = X(7)
1469 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1470
1471 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1472 self.assertEqual(a.f(x), x*10 + 5)
1473 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1474 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1475
1476 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1477 self.assertEqual(b.f(x), x*10 + 5)
1478 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1479 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1480
1481 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1482 self.assertEqual(c.f(x), x*10 + 7)
1483 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1484 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1485
1486 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1487 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1488 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1489
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001490 def test_pickle(self):
1491 cls = self.__class__
1492 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1493 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1494 with self.subTest(proto=proto, func=f):
1495 f_copy = pickle.loads(pickle.dumps(f, proto))
1496 self.assertIs(f_copy, f)
1497
1498 def test_copy(self):
1499 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001500 def orig(x, y):
1501 return 3 * x + y
1502 part = self.module.partial(orig, 2)
1503 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1504 self.module.lru_cache(2)(part))
1505 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001506 with self.subTest(func=f):
1507 f_copy = copy.copy(f)
1508 self.assertIs(f_copy, f)
1509
1510 def test_deepcopy(self):
1511 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001512 def orig(x, y):
1513 return 3 * x + y
1514 part = self.module.partial(orig, 2)
1515 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1516 self.module.lru_cache(2)(part))
1517 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001518 with self.subTest(func=f):
1519 f_copy = copy.deepcopy(f)
1520 self.assertIs(f_copy, f)
1521
1522
# Module-level function cached with the ``py_functools`` copy of lru_cache;
# TestLRUPy refers to it through its ``cached_func`` attribute.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y
1526
# Module-level function cached with the ``c_functools`` copy of lru_cache;
# TestLRUC refers to it through its ``cached_func`` attribute.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y
1530
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001531
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the TestLRU tests against ``py_functools``.
    module = py_functools
    # Stored in a 1-tuple; the tests read it back as ``cached_func[0]``.
    cached_func = py_cached_func,

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1544
1545
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the TestLRU tests against ``c_functools``.
    module = c_functools
    # Stored in a 1-tuple; the tests read it back as ``cached_func[0]``.
    cached_func = c_cached_func,

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001558
Raymond Hettinger03923422013-03-04 02:52:50 -05001559
class TestSingleDispatch(unittest.TestCase):
    # Tests for functools.singledispatch and for the private helpers
    # (_compose_mro, _c3_mro, _find_impl) that it uses to pick an
    # implementation.

    def test_simple_overloads(self):
        # An implementation registered for int is used for ints; every
        # other type falls back to the base function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch follows the MRO of the argument's class: D(C, B) has no
        # implementation of its own, C has none either, so B wins over A.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) can be used as a decorator; it hands back the
        # decorated function unchanged.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The generic function keeps the name and the docstring of the
        # function it wraps.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes of the decimal module, first with
        # only their common base registered, then with a more specific
        # class registered as well.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            # NOTE(review): `haystack` is not used here; the same fixed
            # list is checked on every iteration. Confirm whether the
            # permutation was meant to be passed instead.
            m = mro(c.defaultdict, [c.Sized, c.Container, str])
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            # NOTE(review): `haystack` is not used here either; `bases` is
            # passed in its original order on every iteration.
            m = mro(D, bases)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                                 c.defaultdict, dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        # Register implementations for increasingly specific ABCs and then
        # for concrete types, checking after each registration which
        # implementation is chosen for a dict, a list, a set, a frozenset
        # and a tuple.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro() must insert each relevant ABC right after the class
        # that implements it, whatever order the ABCs are given in.
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        #
        # The metaclass defines __len__() returning 0, which makes the
        # classes A and AA themselves falsy; dispatch must still find the
        # implementation registered for A.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # Dispatch on classes that match several registered ABCs at once:
        # either one of them takes precedence, or the dispatch is reported
        # as ambiguous with a RuntimeError.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two classes may be named in either order in the message.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Replace the dictionary used as the dispatch cache with one that
        # records every read and write, and follow how the cache is filled,
        # used and emptied as implementations and ABCs get registered.
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        # Undo the monkeypatch even when one of the assertions below fails,
        # so that it cannot leak into other tests.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
2054
2055
# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()