blob: f7a11666133d426f17704bac1b9b9bc3c5bc9382 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langa6f692512013-06-05 12:20:24 +020013import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080014import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020015from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100016import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000017
Antoine Pitroub5b37142012-11-13 21:35:40 +010018import functools
19
Antoine Pitroub5b37142012-11-13 21:35:40 +010020py_functools = support.import_fresh_module('functools', blocked=['_functools'])
21c_functools = support.import_fresh_module('functools', fresh=['_functools'])
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
24
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily register *replacement* as ``sys.modules[name]``.

    The module that was registered before entry is reinstated on exit, even
    when the body of the ``with`` statement raises.  *name* must already be
    present in ``sys.modules``; otherwise KeyError is raised before anything
    is changed.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always put the original module back, success or failure.
        sys.modules[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020033
def capture(*args, **kw):
    """Return ``(args, kw)``: every positional and keyword argument received."""
    positional, keywords = args, kw
    return (positional, keywords)
37
Łukasz Langa6f692512013-06-05 12:20:24 +020038
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for the partial object *part*."""
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000042
class MyTuple(tuple):
    """Trivial ``tuple`` subclass, used to check that partial converts its
    stored positional arguments back to an exact ``tuple``."""
45
class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` returns a list instead of a tuple."""

    def __add__(self, other):
        combined = list(self)
        combined.extend(other)
        return combined
49
class MyDict(dict):
    """Trivial ``dict`` subclass, used to check that partial converts its
    stored keyword arguments back to an exact ``dict``."""
52
Łukasz Langa6f692512013-06-05 12:20:24 +020053
class TestPartial:
    """Behavioral tests shared by every ``partial`` implementation.

    This is a mixin, not a ``TestCase``.  Concrete subclasses combine it with
    ``unittest.TestCase`` and provide two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around the pickling
      tests (the pure-Python variant uses it to make
      ``sys.modules['functools']`` resolve to the module under test).
    """

    def _expected_repr_name(self):
        """Return the type name that ``repr(self.partial(...))`` should use.

        The stdlib implementations report themselves as ``functools.partial``;
        subclasses report their own ``__name__``.  ``c_functools`` is None
        when the C accelerator is unavailable, so it is filtered out rather
        than dereferenced.
        """
        stdlib_partials = [mod.partial for mod in (c_functools, py_functools)
                           if mod is not None]
        if self.partial in stdlib_partials:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected; the check may
        # happen at construction or at call time, so both are in the block.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertEqual calls report the mismatching values
                # instead of a bare "False is not true".
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial wrapping a partial should be flattened into one level.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned, so accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case makes f refer to itself, then restores a non-recursive
        # state in ``finally`` so the self-reference cycle is broken.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state must be stored as exact types.
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        # BadTuple.__add__ returns a list; calling must still yield a tuple.
        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # A self-reference inside args survives a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # A self-reference inside keywords survives a round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000377
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C implementation."""
    # The guard keeps the class body importable when _functools is missing;
    # skipUnless then skips the whole class.
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: nothing is patched for the C implementation."""
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
427 self.assertIn("['sth']", r)
428
429
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200430class TestPartialPy(TestPartial, unittest.TestCase):
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000431 partial = py_functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000432
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000433 class AllowPickle:
434 def __init__(self):
435 self._cm = replaced_module("functools", py_functools)
436 def __enter__(self):
437 return self._cm.__enter__()
438 def __exit__(self, type, value, tb):
439 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200440
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200441if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000442 class CPartialSubclass(c_functools.partial):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200443 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100444
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000445class PyPartialSubclass(py_functools.partial):
446 pass
Łukasz Langa6f692512013-06-05 12:20:24 +0200447
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200448@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200449class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200450 if c_functools:
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000451 partial = CPartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000452
Berker Peksag9b93c6b2015-09-22 13:08:16 +0300453 # partial subclasses are not optimized for nested calls
454 test_nested_optimization = None
455
Nick Coghlan457fc9a2016-09-10 20:00:02 +1000456class TestPartialPySubclass(TestPartialPy):
457 partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200458
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000459class TestPartialMethod(unittest.TestCase):
460
461 class A(object):
462 nothing = functools.partialmethod(capture)
463 positional = functools.partialmethod(capture, 1)
464 keywords = functools.partialmethod(capture, a=2)
465 both = functools.partialmethod(capture, 3, b=4)
466
467 nested = functools.partialmethod(positional, 5)
468
469 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
470
471 static = functools.partialmethod(staticmethod(capture), 8)
472 cls = functools.partialmethod(classmethod(capture), d=9)
473
474 a = A()
475
476 def test_arg_combinations(self):
477 self.assertEqual(self.a.nothing(), ((self.a,), {}))
478 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
479 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
480 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
481
482 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
483 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
484 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
485 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
486
487 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
488 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
489 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
490 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
491
492 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
493 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
494 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
495 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
496
497 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
498
499 def test_nested(self):
500 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
501 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
502 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
503 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
504
505 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
506
507 def test_over_partial(self):
508 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
509 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
510 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
511 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
512
513 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
514
515 def test_bound_method_introspection(self):
516 obj = self.a
517 self.assertIs(obj.both.__self__, obj)
518 self.assertIs(obj.nested.__self__, obj)
519 self.assertIs(obj.over_partial.__self__, obj)
520 self.assertIs(obj.cls.__self__, self.A)
521 self.assertIs(self.A.cls.__self__, self.A)
522
523 def test_unbound_method_retrieval(self):
524 obj = self.A
525 self.assertFalse(hasattr(obj.both, "__self__"))
526 self.assertFalse(hasattr(obj.nested, "__self__"))
527 self.assertFalse(hasattr(obj.over_partial, "__self__"))
528 self.assertFalse(hasattr(obj.static, "__self__"))
529 self.assertFalse(hasattr(self.a.static, "__self__"))
530
531 def test_descriptors(self):
532 for obj in [self.A, self.a]:
533 with self.subTest(obj=obj):
534 self.assertEqual(obj.static(), ((8,), {}))
535 self.assertEqual(obj.static(5), ((8, 5), {}))
536 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
537 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
538
539 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
540 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
541 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
542 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
543
544 def test_overriding_keywords(self):
545 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
546 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
547
548 def test_invalid_args(self):
549 with self.assertRaises(TypeError):
550 class B(object):
551 method = functools.partialmethod(None, 1)
552
553 def test_repr(self):
554 self.assertEqual(repr(vars(self.A)['both']),
555 'functools.partialmethod({}, 3, b=4)'.format(capture))
556
557 def test_abstract(self):
558 class Abstract(abc.ABCMeta):
559
560 @abc.abstractmethod
561 def add(self, x, y):
562 pass
563
564 add5 = functools.partialmethod(add, 5)
565
566 self.assertTrue(Abstract.add.__isabstractmethod__)
567 self.assertTrue(Abstract.add5.__isabstractmethod__)
568
569 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
570 self.assertFalse(getattr(func, '__isabstractmethod__', False))
571
572
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000573class TestUpdateWrapper(unittest.TestCase):
574
575 def check_wrapper(self, wrapper, wrapped,
576 assigned=functools.WRAPPER_ASSIGNMENTS,
577 updated=functools.WRAPPER_UPDATES):
578 # Check attributes were assigned
579 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000580 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000581 # Check attributes were updated
582 for name in updated:
583 wrapper_attr = getattr(wrapper, name)
584 wrapped_attr = getattr(wrapped, name)
585 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000586 if name == "__dict__" and key == "__wrapped__":
587 # __wrapped__ is overwritten by the update code
588 continue
589 self.assertIs(wrapped_attr[key], wrapper_attr[key])
590 # Check __wrapped__
591 self.assertIs(wrapper.__wrapped__, wrapped)
592
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000593
R. David Murray378c0cf2010-02-24 01:46:21 +0000594 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000595 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000596 """This is a test"""
597 pass
598 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000599 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000600 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000601 pass
602 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000603 return wrapper, f
604
605 def test_default_update(self):
606 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000607 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000608 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000609 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600610 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000611 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000612 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
613 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000614
R. David Murray378c0cf2010-02-24 01:46:21 +0000615 @unittest.skipIf(sys.flags.optimize >= 2,
616 "Docstrings are omitted with -O2 and above")
617 def test_default_update_doc(self):
618 wrapper, f = self._default_update()
619 self.assertEqual(wrapper.__doc__, 'This is a test')
620
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000621 def test_no_update(self):
622 def f():
623 """This is a test"""
624 pass
625 f.attr = 'This is also a test'
626 def wrapper():
627 pass
628 functools.update_wrapper(wrapper, f, (), ())
629 self.check_wrapper(wrapper, f, (), ())
630 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600631 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000632 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000633 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000634 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000635
636 def test_selective_update(self):
637 def f():
638 pass
639 f.attr = 'This is a different test'
640 f.dict_attr = dict(a=1, b=2, c=3)
641 def wrapper():
642 pass
643 wrapper.dict_attr = {}
644 assign = ('attr',)
645 update = ('dict_attr',)
646 functools.update_wrapper(wrapper, f, assign, update)
647 self.check_wrapper(wrapper, f, assign, update)
648 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600649 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000650 self.assertEqual(wrapper.__doc__, None)
651 self.assertEqual(wrapper.attr, 'This is a different test')
652 self.assertEqual(wrapper.dict_attr, f.dict_attr)
653
Nick Coghlan98876832010-08-17 06:17:18 +0000654 def test_missing_attributes(self):
655 def f():
656 pass
657 def wrapper():
658 pass
659 wrapper.dict_attr = {}
660 assign = ('attr',)
661 update = ('dict_attr',)
662 # Missing attributes on wrapped object are ignored
663 functools.update_wrapper(wrapper, f, assign, update)
664 self.assertNotIn('attr', wrapper.__dict__)
665 self.assertEqual(wrapper.dict_attr, {})
666 # Wrapper must have expected attributes for updating
667 del wrapper.dict_attr
668 with self.assertRaises(AttributeError):
669 functools.update_wrapper(wrapper, f, assign, update)
670 wrapper.dict_attr = 1
671 with self.assertRaises(AttributeError):
672 functools.update_wrapper(wrapper, f, assign, update)
673
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200674 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000675 @unittest.skipIf(sys.flags.optimize >= 2,
676 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000677 def test_builtin_update(self):
678 # Test for bug #1576241
679 def wrapper():
680 pass
681 functools.update_wrapper(wrapper, max)
682 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000683 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000684 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000685
Łukasz Langa6f692512013-06-05 12:20:24 +0200686
class TestWraps(TestUpdateWrapper):
    # Repeats the update_wrapper() scenarios through the functools.wraps()
    # decorator.  Tests that are not overridden here are inherited as-is.

    def _default_update(self):
        # Return a (wrapper, f) pair produced by wraps() with its default
        # arguments.  f carries a decoy __wrapped__ value of its own;
        # NOTE(review): presumably check_wrapper() verifies the wrapper
        # points at f rather than inheriting the decoy -- confirm.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        # Name, qualified name and custom attributes all come from f.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        # Empty assigned/updated tuples: nothing is copied from f.
        @functools.wraps(f, assigned=(), updated=())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(func):
            # Attributes listed in ``updated`` must already exist on the
            # wrapper, so give it an empty dict before wraps() runs.
            func.dict_attr = {}
            return func

        assigned, updated = ('attr',), ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
747
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    # Tests for the C implementation of reduce(); the whole class is skipped
    # when the _functools accelerator could not be imported.
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Old-style sequence: squares are computed lazily on demand, and
            # IndexError past ``limit`` items marks the end of iteration.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                start = len(self.sofar)
                self.sofar.extend(n * n for n in range(start, i + 1))
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Folding strings, lists and numbers with an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)

        # Objects implementing only __getitem__.
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Bad argument counts and types.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # With a single item the function is never called, so 42 is fine.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, FailingIter())

        # An empty iterable simply yields the initial value.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    def test_iterator_usage(self):
        # reduce() accepts any iterable: legacy __getitem__/IndexError
        # sequences as well as dicts (which iterate over their keys).
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
829
Łukasz Langa6f692512013-06-05 12:20:24 +0200830
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200831class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700832
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000833 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700834 def cmp1(x, y):
835 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100836 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700837 self.assertEqual(key(3), key(3))
838 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100839 self.assertGreaterEqual(key(3), key(3))
840
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700841 def cmp2(x, y):
842 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100843 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700844 self.assertEqual(key(4.0), key('4'))
845 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100846 self.assertLessEqual(key(2), key('35'))
847 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700848
849 def test_cmp_to_key_arguments(self):
850 def cmp1(x, y):
851 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100852 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700853 self.assertEqual(key(obj=3), key(obj=3))
854 self.assertGreater(key(obj=3), key(obj=1))
855 with self.assertRaises((TypeError, AttributeError)):
856 key(3) > 1 # rhs is not a K object
857 with self.assertRaises((TypeError, AttributeError)):
858 1 < key(3) # lhs is not a K object
859 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100860 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700861 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200862 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100863 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700864 with self.assertRaises(TypeError):
865 key() # too few args
866 with self.assertRaises(TypeError):
867 key(None, None) # too many args
868
869 def test_bad_cmp(self):
870 def cmp1(x, y):
871 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100872 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700873 with self.assertRaises(ZeroDivisionError):
874 key(3) > key(1)
875
876 class BadCmp:
877 def __lt__(self, other):
878 raise ZeroDivisionError
879 def cmp1(x, y):
880 return BadCmp()
881 with self.assertRaises(ZeroDivisionError):
882 key(3) > key(1)
883
884 def test_obj_field(self):
885 def cmp1(x, y):
886 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100887 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700888 self.assertEqual(key(50).obj, 50)
889
890 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000891 def mycmp(x, y):
892 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100893 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000894 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000895
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700896 def test_sort_int_str(self):
897 def mycmp(x, y):
898 x, y = int(x), int(y)
899 return (x > y) - (x < y)
900 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100901 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700902 self.assertEqual([int(value) for value in values],
903 [0, 1, 1, 2, 3, 4, 5, 7, 10])
904
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000905 def test_hash(self):
906 def mycmp(x, y):
907 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100908 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000909 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700910 self.assertRaises(TypeError, hash, k)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +0300911 self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000912
Łukasz Langa6f692512013-06-05 12:20:24 +0200913
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key tests against the C accelerator.  Unlike
    # the pure-Python variant, no staticmethod() wrapper is needed: builtin
    # functions are not descriptors, so ``self.cmp_to_key`` does not bind.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100918
Łukasz Langa6f692512013-06-05 12:20:24 +0200919
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key tests against the pure-Python version.  It
    # is a plain function, so staticmethod() keeps ``self.cmp_to_key`` from
    # passing the test case instance as the comparison function.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
922
Łukasz Langa6f692512013-06-05 12:20:24 +0200923
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000924class TestTotalOrdering(unittest.TestCase):
925
926 def test_total_ordering_lt(self):
927 @functools.total_ordering
928 class A:
929 def __init__(self, value):
930 self.value = value
931 def __lt__(self, other):
932 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000933 def __eq__(self, other):
934 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000935 self.assertTrue(A(1) < A(2))
936 self.assertTrue(A(2) > A(1))
937 self.assertTrue(A(1) <= A(2))
938 self.assertTrue(A(2) >= A(1))
939 self.assertTrue(A(2) <= A(2))
940 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000941 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000942
943 def test_total_ordering_le(self):
944 @functools.total_ordering
945 class A:
946 def __init__(self, value):
947 self.value = value
948 def __le__(self, other):
949 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000950 def __eq__(self, other):
951 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000952 self.assertTrue(A(1) < A(2))
953 self.assertTrue(A(2) > A(1))
954 self.assertTrue(A(1) <= A(2))
955 self.assertTrue(A(2) >= A(1))
956 self.assertTrue(A(2) <= A(2))
957 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000958 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000959
960 def test_total_ordering_gt(self):
961 @functools.total_ordering
962 class A:
963 def __init__(self, value):
964 self.value = value
965 def __gt__(self, other):
966 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000967 def __eq__(self, other):
968 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000969 self.assertTrue(A(1) < A(2))
970 self.assertTrue(A(2) > A(1))
971 self.assertTrue(A(1) <= A(2))
972 self.assertTrue(A(2) >= A(1))
973 self.assertTrue(A(2) <= A(2))
974 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000975 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000976
977 def test_total_ordering_ge(self):
978 @functools.total_ordering
979 class A:
980 def __init__(self, value):
981 self.value = value
982 def __ge__(self, other):
983 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000984 def __eq__(self, other):
985 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000986 self.assertTrue(A(1) < A(2))
987 self.assertTrue(A(2) > A(1))
988 self.assertTrue(A(1) <= A(2))
989 self.assertTrue(A(2) >= A(1))
990 self.assertTrue(A(2) <= A(2))
991 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000992 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000993
994 def test_total_ordering_no_overwrite(self):
995 # new methods should not overwrite existing
996 @functools.total_ordering
997 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000998 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000999 self.assertTrue(A(1) < A(2))
1000 self.assertTrue(A(2) > A(1))
1001 self.assertTrue(A(1) <= A(2))
1002 self.assertTrue(A(2) >= A(1))
1003 self.assertTrue(A(2) <= A(2))
1004 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001005
Benjamin Peterson42ebee32010-04-11 01:43:16 +00001006 def test_no_operations_defined(self):
1007 with self.assertRaises(ValueError):
1008 @functools.total_ordering
1009 class A:
1010 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001011
Nick Coghlanf05d9812013-10-02 00:02:03 +10001012 def test_type_error_when_not_implemented(self):
1013 # bug 10042; ensure stack overflow does not occur
1014 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001015 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +10001016 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001017 def __init__(self, value):
1018 self.value = value
1019 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001020 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001021 return self.value == other.value
1022 return False
1023 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +10001024 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001025 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +10001026 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +00001027
Nick Coghlanf05d9812013-10-02 00:02:03 +10001028 @functools.total_ordering
1029 class ImplementsGreaterThan:
1030 def __init__(self, value):
1031 self.value = value
1032 def __eq__(self, other):
1033 if isinstance(other, ImplementsGreaterThan):
1034 return self.value == other.value
1035 return False
1036 def __gt__(self, other):
1037 if isinstance(other, ImplementsGreaterThan):
1038 return self.value > other.value
1039 return NotImplemented
1040
1041 @functools.total_ordering
1042 class ImplementsLessThanEqualTo:
1043 def __init__(self, value):
1044 self.value = value
1045 def __eq__(self, other):
1046 if isinstance(other, ImplementsLessThanEqualTo):
1047 return self.value == other.value
1048 return False
1049 def __le__(self, other):
1050 if isinstance(other, ImplementsLessThanEqualTo):
1051 return self.value <= other.value
1052 return NotImplemented
1053
1054 @functools.total_ordering
1055 class ImplementsGreaterThanEqualTo:
1056 def __init__(self, value):
1057 self.value = value
1058 def __eq__(self, other):
1059 if isinstance(other, ImplementsGreaterThanEqualTo):
1060 return self.value == other.value
1061 return False
1062 def __ge__(self, other):
1063 if isinstance(other, ImplementsGreaterThanEqualTo):
1064 return self.value >= other.value
1065 return NotImplemented
1066
1067 @functools.total_ordering
1068 class ComparatorNotImplemented:
1069 def __init__(self, value):
1070 self.value = value
1071 def __eq__(self, other):
1072 if isinstance(other, ComparatorNotImplemented):
1073 return self.value == other.value
1074 return False
1075 def __lt__(self, other):
1076 return NotImplemented
1077
1078 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1079 ImplementsLessThan(-1) < 1
1080
1081 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1082 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1083
1084 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1085 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1086
1087 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1088 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1089
1090 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1091 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1092
1093 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1094 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1095
1096 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1097 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1098
1099 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1100 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1101
1102 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1103 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1104
1105 with self.subTest("GE when equal"):
1106 a = ComparatorNotImplemented(8)
1107 b = ComparatorNotImplemented(8)
1108 self.assertEqual(a, b)
1109 with self.assertRaises(TypeError):
1110 a >= b
1111
1112 with self.subTest("LE when equal"):
1113 a = ComparatorNotImplemented(9)
1114 b = ComparatorNotImplemented(9)
1115 self.assertEqual(a, b)
1116 with self.assertRaises(TypeError):
1117 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001118
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001119 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001120 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001121 for name in '__lt__', '__gt__', '__le__', '__ge__':
1122 with self.subTest(method=name, proto=proto):
1123 method = getattr(Orderable_LT, name)
1124 method_copy = pickle.loads(pickle.dumps(method, proto))
1125 self.assertIs(method_copy, method)
1126
@functools.total_ordering
class Orderable_LT:
    # Helper for TestTotalOrdering.test_pickle.  Only __lt__ and __eq__ are
    # written by hand; total_ordering fills in __gt__, __le__ and __ge__.

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1135
1136
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001137class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001138
1139 def test_lru(self):
1140 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001141 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001142 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001143 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001144 self.assertEqual(maxsize, 20)
1145 self.assertEqual(currsize, 0)
1146 self.assertEqual(hits, 0)
1147 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001148
1149 domain = range(5)
1150 for i in range(1000):
1151 x, y = choice(domain), choice(domain)
1152 actual = f(x, y)
1153 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001154 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001155 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001156 self.assertTrue(hits > misses)
1157 self.assertEqual(hits + misses, 1000)
1158 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001159
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001160 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001161 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001162 self.assertEqual(hits, 0)
1163 self.assertEqual(misses, 0)
1164 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001165 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001166 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001167 self.assertEqual(hits, 0)
1168 self.assertEqual(misses, 1)
1169 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001170
Nick Coghlan98876832010-08-17 06:17:18 +00001171 # Test bypassing the cache
1172 self.assertIs(f.__wrapped__, orig)
1173 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001174 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001175 self.assertEqual(hits, 0)
1176 self.assertEqual(misses, 1)
1177 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001178
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001179 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001180 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001181 def f():
1182 nonlocal f_cnt
1183 f_cnt += 1
1184 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001185 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001186 f_cnt = 0
1187 for i in range(5):
1188 self.assertEqual(f(), 20)
1189 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001190 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001191 self.assertEqual(hits, 0)
1192 self.assertEqual(misses, 5)
1193 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001194
1195 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001196 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001197 def f():
1198 nonlocal f_cnt
1199 f_cnt += 1
1200 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001201 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001202 f_cnt = 0
1203 for i in range(5):
1204 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001206 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001207 self.assertEqual(hits, 4)
1208 self.assertEqual(misses, 1)
1209 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001210
Raymond Hettingerf3098282010-08-15 03:30:45 +00001211 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001212 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001213 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001214 nonlocal f_cnt
1215 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001216 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001217 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001218 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001219 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1220 # * * * *
1221 self.assertEqual(f(x), x*10)
1222 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001223 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001224 self.assertEqual(hits, 12)
1225 self.assertEqual(misses, 4)
1226 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001227
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001228 def test_lru_hash_only_once(self):
1229 # To protect against weird reentrancy bugs and to improve
1230 # efficiency when faced with slow __hash__ methods, the
1231 # LRU cache guarantees that it will only call __hash__
1232 # only once per use as an argument to the cached function.
1233
1234 @self.module.lru_cache(maxsize=1)
1235 def f(x, y):
1236 return x * 3 + y
1237
1238 # Simulate the integer 5
1239 mock_int = unittest.mock.Mock()
1240 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1241 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1242
1243 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001244 self.assertEqual(f(mock_int, 1), 16)
1245 self.assertEqual(mock_int.__hash__.call_count, 1)
1246 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001247
1248 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001249 self.assertEqual(f(mock_int, 1), 16)
1250 self.assertEqual(mock_int.__hash__.call_count, 2)
1251 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001252
Ville Skyttä49b27342017-08-03 09:00:59 +03001253 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001254 self.assertEqual(f(6, 2), 20)
1255 self.assertEqual(mock_int.__hash__.call_count, 2)
1256 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001257
1258 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001259 self.assertEqual(f(mock_int, 1), 16)
1260 self.assertEqual(mock_int.__hash__.call_count, 3)
1261 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001262
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001263 def test_lru_reentrancy_with_len(self):
1264 # Test to make sure the LRU cache code isn't thrown-off by
1265 # caching the built-in len() function. Since len() can be
1266 # cached, we shouldn't use it inside the lru code itself.
1267 old_len = builtins.len
1268 try:
1269 builtins.len = self.module.lru_cache(4)(len)
1270 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1271 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1272 finally:
1273 builtins.len = old_len
1274
Raymond Hettinger605a4472017-01-09 07:50:19 -08001275 def test_lru_star_arg_handling(self):
1276 # Test regression that arose in ea064ff3c10f
1277 @functools.lru_cache()
1278 def f(*args):
1279 return args
1280
1281 self.assertEqual(f(1, 2), (1, 2))
1282 self.assertEqual(f((1, 2)), ((1, 2),))
1283
Yury Selivanov46a02db2016-11-09 18:55:45 -05001284 def test_lru_type_error(self):
1285 # Regression test for issue #28653.
1286 # lru_cache was leaking when one of the arguments
1287 # wasn't cacheable.
1288
1289 @functools.lru_cache(maxsize=None)
1290 def infinite_cache(o):
1291 pass
1292
1293 @functools.lru_cache(maxsize=10)
1294 def limited_cache(o):
1295 pass
1296
1297 with self.assertRaises(TypeError):
1298 infinite_cache([])
1299
1300 with self.assertRaises(TypeError):
1301 limited_cache([])
1302
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001303 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001304 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001305 def fib(n):
1306 if n < 2:
1307 return n
1308 return fib(n-1) + fib(n-2)
1309 self.assertEqual([fib(n) for n in range(16)],
1310 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1311 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001312 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001313 fib.cache_clear()
1314 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001315 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1316
1317 def test_lru_with_maxsize_negative(self):
1318 @self.module.lru_cache(maxsize=-10)
1319 def eq(n):
1320 return n
1321 for i in (0, 1):
1322 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1323 self.assertEqual(eq.cache_info(),
1324 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001325
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001326 def test_lru_with_exceptions(self):
1327 # Verify that user_function exceptions get passed through without
1328 # creating a hard-to-read chained exception.
1329 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001330 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001331 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001332 def func(i):
1333 return 'abc'[i]
1334 self.assertEqual(func(0), 'a')
1335 with self.assertRaises(IndexError) as cm:
1336 func(15)
1337 self.assertIsNone(cm.exception.__context__)
1338 # Verify that the previous exception did not result in a cached entry
1339 with self.assertRaises(IndexError):
1340 func(15)
1341
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001342 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001343 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001344 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001345 def square(x):
1346 return x * x
1347 self.assertEqual(square(3), 9)
1348 self.assertEqual(type(square(3)), type(9))
1349 self.assertEqual(square(3.0), 9.0)
1350 self.assertEqual(type(square(3.0)), type(9.0))
1351 self.assertEqual(square(x=3), 9)
1352 self.assertEqual(type(square(x=3)), type(9))
1353 self.assertEqual(square(x=3.0), 9.0)
1354 self.assertEqual(type(square(x=3.0)), type(9.0))
1355 self.assertEqual(square.cache_info().hits, 4)
1356 self.assertEqual(square.cache_info().misses, 4)
1357
Antoine Pitroub5b37142012-11-13 21:35:40 +01001358 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001359 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001360 def fib(n):
1361 if n < 2:
1362 return n
1363 return fib(n=n-1) + fib(n=n-2)
1364 self.assertEqual(
1365 [fib(n=number) for number in range(16)],
1366 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1367 )
1368 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001369 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001370 fib.cache_clear()
1371 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001372 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001373
1374 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001375 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001376 def fib(n):
1377 if n < 2:
1378 return n
1379 return fib(n=n-1) + fib(n=n-2)
1380 self.assertEqual([fib(n=number) for number in range(16)],
1381 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1382 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001383 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001384 fib.cache_clear()
1385 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001386 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1387
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001388 def test_kwargs_order(self):
1389 # PEP 468: Preserving Keyword Argument Order
1390 @self.module.lru_cache(maxsize=10)
1391 def f(**kwargs):
1392 return list(kwargs.items())
1393 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1394 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1395 self.assertEqual(f.cache_info(),
1396 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1397
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001398 def test_lru_cache_decoration(self):
1399 def f(zomg: 'zomg_annotation'):
1400 """f doc string"""
1401 return 42
1402 g = self.module.lru_cache()(f)
1403 for attr in self.module.WRAPPER_ASSIGNMENTS:
1404 self.assertEqual(getattr(g, attr), getattr(f, attr))
1405
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001406 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001407 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001408 def orig(x, y):
1409 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001410 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001411 hits, misses, maxsize, currsize = f.cache_info()
1412 self.assertEqual(currsize, 0)
1413
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001414 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001415 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001416 start.wait(10)
1417 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001418 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001419
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001420 def clear():
1421 start.wait(10)
1422 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001423 f.cache_clear()
1424
1425 orig_si = sys.getswitchinterval()
Xavier de Gaye7522ef42016-12-08 11:06:56 +01001426 support.setswitchinterval(1e-6)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001427 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001428 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001429 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001430 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001431 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001432 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001433
1434 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001435 if self.module is py_functools:
1436 # XXX: Why can be not equal?
1437 self.assertLessEqual(misses, n)
1438 self.assertLessEqual(hits, m*n - misses)
1439 else:
1440 self.assertEqual(misses, n)
1441 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001442 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001443
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001444 # create n threads in order to fill cache and 1 to clear it
1445 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001446 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001447 for k in range(n)]
1448 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001449 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001450 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001451 finally:
1452 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001453
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001454 def test_lru_cache_threaded2(self):
1455 # Simultaneous call with the same arguments
1456 n, m = 5, 7
1457 start = threading.Barrier(n+1)
1458 pause = threading.Barrier(n+1)
1459 stop = threading.Barrier(n+1)
1460 @self.module.lru_cache(maxsize=m*n)
1461 def f(x):
1462 pause.wait(10)
1463 return 3 * x
1464 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1465 def test():
1466 for i in range(m):
1467 start.wait(10)
1468 self.assertEqual(f(i), 3 * i)
1469 stop.wait(10)
1470 threads = [threading.Thread(target=test) for k in range(n)]
1471 with support.start_threads(threads):
1472 for i in range(m):
1473 start.wait(10)
1474 stop.reset()
1475 pause.wait(10)
1476 start.reset()
1477 stop.wait(10)
1478 pause.reset()
1479 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1480
Serhiy Storchaka67796522017-01-12 18:34:33 +02001481 def test_lru_cache_threaded3(self):
1482 @self.module.lru_cache(maxsize=2)
1483 def f(x):
1484 time.sleep(.01)
1485 return 3 * x
1486 def test(i, x):
1487 with self.subTest(thread=i):
1488 self.assertEqual(f(x), 3 * x, i)
1489 threads = [threading.Thread(target=test, args=(i, v))
1490 for i, v in enumerate([1, 2, 2, 3, 2])]
1491 with support.start_threads(threads):
1492 pass
1493
Raymond Hettinger03923422013-03-04 02:52:50 -05001494 def test_need_for_rlock(self):
1495 # This will deadlock on an LRU cache that uses a regular lock
1496
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001497 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001498 def test_func(x):
1499 'Used to demonstrate a reentrant lru_cache call within a single thread'
1500 return x
1501
1502 class DoubleEq:
1503 'Demonstrate a reentrant lru_cache call within a single thread'
1504 def __init__(self, x):
1505 self.x = x
1506 def __hash__(self):
1507 return self.x
1508 def __eq__(self, other):
1509 if self.x == 2:
1510 test_func(DoubleEq(1))
1511 return self.x == other.x
1512
1513 test_func(DoubleEq(1)) # Load the cache
1514 test_func(DoubleEq(2)) # Load the cache
1515 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1516 DoubleEq(2)) # Verify the correct return value
1517
Raymond Hettinger4d588972014-08-12 12:44:52 -07001518 def test_early_detection_of_bad_call(self):
1519 # Issue #22184
1520 with self.assertRaises(TypeError):
1521 @functools.lru_cache
1522 def f():
1523 pass
1524
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001525 def test_lru_method(self):
1526 class X(int):
1527 f_cnt = 0
1528 @self.module.lru_cache(2)
1529 def f(self, x):
1530 self.f_cnt += 1
1531 return x*10+self
1532 a = X(5)
1533 b = X(5)
1534 c = X(7)
1535 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1536
1537 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1538 self.assertEqual(a.f(x), x*10 + 5)
1539 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1540 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1541
1542 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1543 self.assertEqual(b.f(x), x*10 + 5)
1544 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1545 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1546
1547 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1548 self.assertEqual(c.f(x), x*10 + 7)
1549 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1550 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1551
1552 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1553 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1554 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1555
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001556 def test_pickle(self):
1557 cls = self.__class__
1558 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1559 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1560 with self.subTest(proto=proto, func=f):
1561 f_copy = pickle.loads(pickle.dumps(f, proto))
1562 self.assertIs(f_copy, f)
1563
1564 def test_copy(self):
1565 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001566 def orig(x, y):
1567 return 3 * x + y
1568 part = self.module.partial(orig, 2)
1569 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1570 self.module.lru_cache(2)(part))
1571 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001572 with self.subTest(func=f):
1573 f_copy = copy.copy(f)
1574 self.assertIs(f_copy, f)
1575
1576 def test_deepcopy(self):
1577 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001578 def orig(x, y):
1579 return 3 * x + y
1580 part = self.module.partial(orig, 2)
1581 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1582 self.module.lru_cache(2)(part))
1583 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001584 with self.subTest(func=f):
1585 f_copy = copy.deepcopy(f)
1586 self.assertIs(f_copy, f)
1587
1588
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3 * x + y, cached with the pure-Python lru_cache.

    Defined at module level so TestLRUPy's pickle/copy tests can use it.
    """
    result = 3 * x + y
    return result
1592
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3 * x + y, cached with the C-accelerated lru_cache.

    Defined at module level so TestLRUC's pickle/copy tests can use it.
    """
    result = 3 * x + y
    return result
1596
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001597
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU tests against the pure-Python functools."""

    module = py_functools
    # Held in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up on the class or an instance.
    cached_func = (py_cached_func,)

    # Cached method and staticmethod used by the pickle/copy tests.
    @py_functools.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @py_functools.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1610
1611
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU tests against the C-accelerated functools."""

    module = c_functools
    # Held in a 1-tuple, presumably so the cached function is never bound
    # as a method when looked up on the class or an instance.
    cached_func = (c_cached_func,)

    # Cached method and staticmethod used by the pickle/copy tests.
    @c_functools.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @c_functools.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001624
Raymond Hettinger03923422013-03-04 02:52:50 -05001625
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and the MRO helpers behind it
    (functools._compose_mro and functools._c3_mro)."""

    def test_simple_overloads(self):
        # An implementation registered for int is used for ints only;
        # every other type falls back to the base function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch walks the argument type's MRO: D(C, B) linearizes as
        # D, C, B, A, so B's implementation wins over A's.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # g.register(cls) also works as a decorator and returns the
        # decorated function unchanged.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: the base implementation returned by dispatch() above is not
        # g itself, because @singledispatch returns a wrapper around it.

    def test_wrapping_attributes(self):
        # The generic function keeps the wrapped function's name and
        # docstring (docstrings are stripped under -OO, hence the check).
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch works on C-implemented exception classes; a registration
        # for the more specific Subnormal overrides the DecimalException
        # one for that subclass only.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering,
        # so every loop passes the current permutation (haystack) to
        # _compose_mro and expects the same result.
        c = collections.abc
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                                 c.Collection, c.Sized, c.Iterable,
                                 c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping,
                 collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping,
                                 c.Mapping, c.Collection, c.Sized,
                                 c.Iterable, c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                                 c.Reversible, collections.defaultdict, dict,
                                 c.MutableMapping, c.Mapping, c.Collection,
                                 c.Sized, c.Iterable, c.Container, object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict,
                                 c.Mapping, c.Collection, c.Sized,
                                 c.Iterable, c.Container, object])

    def test_register_abc(self):
        # Registering progressively more specific ABCs, then concrete
        # builtin types, changes which implementation each container picks.
        c = collections.abc
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # _c3_mro places each ABC right after the class that introduces the
        # matching behaviour, regardless of the order the ABCs are given in.
        c = collections.abc
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # MetaA.__len__ makes A and AA themselves falsy; dispatch on an AA
        # instance must still find the implementation registered for A.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # Checks which implementation wins when several registered ABCs
        # apply, and that truly ambiguous cases raise RuntimeError.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # Swap functools.WeakKeyDictionary for a tracing dict so the test
        # can observe when singledispatch's cache is read, written and
        # cleared.
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        try:
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            g._clear_cache()
            self.assertEqual(len(td), 0)
        finally:
            # Restore the real class even when an assertion above fails, so
            # the monkeypatch cannot leak into later tests.
            functools.WeakKeyDictionary = _orig_wkd
2120
2121
if __name__ == '__main__':
    # Run every test in this module when it is executed as a script.
    unittest.main()