blob: 63a9ade54806c5251550b8d4900b58288e12573d [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08002import builtins
Raymond Hettinger003be522011-05-03 11:01:32 -07003import collections
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03004import collections.abc
Serhiy Storchaka45120f22015-10-24 09:49:56 +03005import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02006from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02009import sys
10from test import support
Antoine Pitroua6a4dc82017-09-07 18:56:24 +020011import threading
Serhiy Storchaka67796522017-01-12 18:34:33 +020012import time
Łukasz Langae5697532017-12-11 13:56:31 -080013import typing
Łukasz Langa6f692512013-06-05 12:20:24 +020014import unittest
Raymond Hettingerd191ef22017-01-07 20:44:48 -080015import unittest.mock
Łukasz Langa6f692512013-06-05 12:20:24 +020016from weakref import proxy
Nick Coghlan457fc9a2016-09-10 20:00:02 +100017import contextlib
Raymond Hettinger9c323f82005-02-28 19:39:44 +000018
Antoine Pitroub5b37142012-11-13 21:35:40 +010019import functools
20
# functools re-imported with the C accelerator ``_functools`` blocked, so
# (presumably) only the pure-Python implementations are exercised through it.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
# functools re-imported together with a fresh copy of the C ``_functools``.
# NOTE(review): presumably None when the C module cannot be imported -- the
# code below guards its uses with ``if c_functools`` / skipUnless.
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal re-imported with a fresh copy of the C ``_decimal`` module.
# NOTE(review): not referenced in this part of the file; presumably used by
# later tests.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
25
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily make ``sys.modules[name]`` refer to *replacement*.

    The module registered under *name* beforehand is put back when the
    ``with`` block exits, whether it finishes normally or raises.  The name
    must already be present in ``sys.modules`` (KeyError otherwise).
    """
    registry = sys.modules
    saved = registry[name]
    registry[name] = replacement
    try:
        yield
    finally:
        # Restore unconditionally so a failing test cannot leak the swap.
        registry[name] = saved
Łukasz Langa6f692512013-06-05 12:20:24 +020034
def capture(*args, **kw):
    """Echo back the call's arguments as an ``(args, kwargs)`` pair.

    Serves as a transparent target for partial objects, so tests can see
    exactly which positional and keyword arguments reached the callable.
    """
    received = (args, kw)
    return received
38
Łukasz Langa6f692512013-06-05 12:20:24 +020039
def signature(part):
    """Return the complete state of partial object *part*.

    The result is the 4-tuple ``(func, args, keywords, __dict__)``, which
    lets two partial objects be compared field by field.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000043
class MyTuple(tuple):
    """Trivial tuple subclass with no overrides.

    Used to check that partial's ``__setstate__`` stores subclassed
    ``args`` as a plain ``tuple``.
    """
46
class BadTuple(tuple):
    """Tuple subclass whose ``+`` deliberately returns a list.

    Used to check that a partial object built from such ``args`` still
    produces plain tuples of positional arguments.
    """

    def __add__(self, other):
        return [*self, *other]
50
class MyDict(dict):
    """Trivial dict subclass with no overrides.

    Used to check that partial's ``__setstate__`` stores subclassed
    ``keywords`` as a plain ``dict``.
    """
53
Łukasz Langa6f692512013-06-05 12:20:24 +020054
class TestPartial:
    """Mixin holding the tests shared by every ``partial`` implementation.

    Concrete subclasses also inherit from ``unittest.TestCase`` and provide
    two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context manager class entered around the tests
      that pickle partial objects.
    """

    def _expected_repr_name(self):
        """Return the type name that ``repr()`` of a partial starts with.

        Both stdlib implementations report themselves as
        ``functools.partial``; anything else (subclasses) uses its own
        ``__name__``.  ``c_functools`` may be falsy when the C accelerator
        is unavailable, so it is only consulted when present.
        """
        stdlib_types = [py_functools.partial]
        if c_functools:
            stdlib_types.append(c_functools.partial)
        if self.partial in stdlib_types:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertions report which half differed.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a':a,'x':None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # Each case builds a self-referencing partial and always breaks the
        # cycle again in ``finally``.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}(...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, ...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), f'{name}({capture!r}, a=...)')
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): disabled check; the reason is not recorded here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000378
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific ones for the C ``partial``."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager used around the pickling tests.

        Unlike ``TestPartialPy.AllowPickle``, nothing is swapped into
        ``sys.modules`` for this implementation.
        """
        def __enter__(self):
            return self
        def __exit__(self, exc_type, exc_value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Adding a non-string/unicode keyword to partial kwargs
        p.keywords[1234] = 'value'
        r = repr(p)
        self.assertIn('1234', r)
        self.assertIn("'value'", r)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict(object):
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Replacing the value during key formatting should keep the original
        # value alive (at least long enough).
        p.keywords[MutatesYourDict()] = ['sth']
        r = repr(p)
        self.assertIn('astr', r)
        self.assertIn("['sth']", r)
428 self.assertIn("['sth']", r)
429
430
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        """Make ``py_functools`` stand in for ``sys.modules['functools']``
        while the ``with`` block runs.

        NOTE(review): presumably required because pickle resolves classes
        by module name, and only the swapped-in module exposes the
        pure-Python ``partial`` as ``functools.partial``.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, exc_type, exc_value, tb):
            return self._swap.__exit__(exc_type, exc_value, tb)
440 return self._cm.__exit__(type, value, tb)
Łukasz Langa6f692512013-06-05 12:20:24 +0200441
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C ``partial`` type; adds nothing."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100445
class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python ``partial`` type; adds nothing."""
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run the C partial tests against a subclass of the C type."""
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls, so the
    # inherited flattening test is replaced by a non-callable placeholder
    # (presumably so the test loader no longer collects it).
    test_nested_optimization = None
456
class TestPartialPySubclass(TestPartialPy):
    """Re-run the pure-Python partial tests against a subclass of it."""
    partial = PyPartialSubclass
Łukasz Langa6f692512013-06-05 12:20:24 +0200459
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta as its metaclass (via abc.ABC).
        # Subclassing ABCMeta directly would make Abstract a metaclass
        # itself, not an abstract class.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # A class with outstanding abstract methods cannot be instantiated.
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
572
573
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both functions.  Every mapping named in *updated* must contain each
        entry of *wrapped*'s mapping, with identical values.  Finally,
        ``wrapper.__wrapped__`` must point back at *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))

        for attr_name in updated:
            merged = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                # __wrapped__ is overwritten by the update code, so wrapped's
                # own entry is not expected to survive in the wrapper.
                if attr_name == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], merged[key])

        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Wrap a fresh function with update_wrapper's default settings.

        Return ``(wrapper, wrapped)``.  The wrapped function carries a
        docstring, an annotation, a custom attribute and a bogus
        ``__wrapped__`` value.
        """
        def f(a: 'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        # Empty assigned/updated sequences: nothing should be copied.
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Attributes missing on the wrapped object are ignored.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must provide each mapping being updated:
        # both a missing attribute and a non-dict value (here ``1``) raise
        # AttributeError.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Regression test for bug #1576241: update_wrapper must accept a
        # builtin function as the wrapped object.
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps() decorator."""

    def _default_update(self):
        # Build a wrapper via wraps() with the default attribute sets and
        # return ``(wrapper, wrapped)``.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Decoy: presumably check_wrapper() (defined in the base class)
        # verifies the wrapper's __wrapped__ points at ``f``, not at this.
        f.__wrapped__ = "This is still a bald faced lie"

        def wrapper():
            pass
        wrapper = functools.wraps(f)(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # Empty assigned/updated tuples: the wrapper keeps its own identity.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        assigned = ('attr',)
        updated = ('dict_attr',)

        def wrapper():
            pass
        # The ``updated`` attribute has to exist on the wrapper beforehand.
        wrapper.dict_attr = {}
        wrapper = functools.wraps(f, assigned, updated)(wrapper)

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
748
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000749
class TestReduce:
    """reduce() tests shared by the C and pure-Python implementations.

    Concrete subclasses provide the implementation under test as the
    ``reduce`` attribute.
    """

    def test_reduce(self):
        class Squares:
            # Lazily grown sequence of squares, reachable only through the
            # legacy __getitem__ iteration protocol (there is no __iter__).
            def __init__(self, limit):
                self.max = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                start = len(self.sofar)
                self.sofar.extend(k * k for k in range(start, i + 1))
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        reduce = self.reduce
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # With a single item the function is never called, so even a
        # non-callable "function" is accepted.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        bad_calls = [
            (),                 # no arguments at all
            (42, 42),
            (42, 42, 42),
            (42, (42, 42)),     # non-callable with two items to combine
            (add, []),          # empty sequence and no initial value
            (add, ""),
            (add, ()),
            (add, object()),    # not iterable
        ]
        for args in bad_calls:
            self.assertRaises(TypeError, reduce, *args)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # An empty iterable simply returns the initial value.
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.reduce
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Iterating a dict yields its keys, which get concatenated.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
827
828
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the C accelerator."""

    # Builtin functions are not descriptors, so unlike the pure-Python
    # variant no staticmethod() wrapper is needed.  The guard avoids an
    # AttributeError at class creation when _functools is unavailable
    # (the whole class is skipped in that case).
    if c_functools:
        reduce = c_functools.reduce
833
834
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce() tests against the pure-Python implementation."""

    # A plain Python function stored on the class would be bound as a
    # method when looked up through ``self``; staticmethod() prevents that.
    reduce = staticmethod(py_functools.reduce)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000837
Łukasz Langa6f692512013-06-05 12:20:24 +0200838
class TestCmpToKey:
    """cmp_to_key() tests shared by the C and pure-Python implementations.

    Concrete subclasses provide the implementation under test as the
    ``cmp_to_key`` attribute.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1  # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)  # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()  # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)  # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()  # too few args
        with self.assertRaises(TypeError):
            key(None, None)  # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # So does an exception raised while comparing cmp's *result* with 0.
        # The key has to be rebuilt from the new function, otherwise the
        # old cmp1 is what raises and BadCmp is never exercised.  ``<`` is
        # used because BadCmp only overrides __lt__ (``>`` would end in a
        # TypeError instead).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000920
Łukasz Langa6f692512013-06-05 12:20:24 +0200921
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the C accelerator."""

    # Builtins are not descriptors, so no staticmethod() is needed.  The
    # guard avoids an AttributeError at class creation when _functools is
    # missing (the whole class is skipped in that case).
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100926
Łukasz Langa6f692512013-06-05 12:20:24 +0200927
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key() tests against the pure-Python version."""

    # A plain Python function stored on the class would be bound as a
    # method when looked up through ``self``; staticmethod() prevents that.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
930
Łukasz Langa6f692512013-06-05 12:20:24 +0200931
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering() class decorator."""

    def _assert_ordered(self, cls):
        # Checks shared by the single-root tests: given __eq__ plus one
        # rich comparison, all four ordering operators must agree.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_ordered(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # New methods should not overwrite existing ones: int already
        # provides every rich comparison, so the decorator must leave the
        # inherited methods in place rather than installing derived ones.
        @functools.total_ordering
        class A(int):
            pass
        self._assert_ordered(A)
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            with self.subTest(method=name):
                self.assertIs(getattr(A, name), getattr(int, name))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Comparisons between unrelated types must end in TypeError rather
        # than recursing between the derived methods.
        mixed_type_cases = [
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mixed_type_cases:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal values whose root comparison returns NotImplemented must not
        # fall back on __eq__ to answer >= or <=.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # The generated methods must pickle by reference for every protocol.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1134
@functools.total_ordering
class Orderable_LT:
    """Module-level class ordered only via __lt__ and __eq__.

    TestTotalOrdering.test_pickle pickles the methods that total_ordering()
    derives for this class, so it is defined at module scope.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        mine, theirs = self.value, other.value
        return mine < theirs

    def __eq__(self, other):
        mine, theirs = self.value, other.value
        return mine == theirs
1143
1144
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001145class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001146
1147 def test_lru(self):
1148 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001149 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001150 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001151 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001152 self.assertEqual(maxsize, 20)
1153 self.assertEqual(currsize, 0)
1154 self.assertEqual(hits, 0)
1155 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001156
1157 domain = range(5)
1158 for i in range(1000):
1159 x, y = choice(domain), choice(domain)
1160 actual = f(x, y)
1161 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001162 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001163 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001164 self.assertTrue(hits > misses)
1165 self.assertEqual(hits + misses, 1000)
1166 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001167
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001168 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001169 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001170 self.assertEqual(hits, 0)
1171 self.assertEqual(misses, 0)
1172 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001173 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001174 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001175 self.assertEqual(hits, 0)
1176 self.assertEqual(misses, 1)
1177 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001178
Nick Coghlan98876832010-08-17 06:17:18 +00001179 # Test bypassing the cache
1180 self.assertIs(f.__wrapped__, orig)
1181 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001182 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001183 self.assertEqual(hits, 0)
1184 self.assertEqual(misses, 1)
1185 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001186
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001187 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001188 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001189 def f():
1190 nonlocal f_cnt
1191 f_cnt += 1
1192 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001193 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001194 f_cnt = 0
1195 for i in range(5):
1196 self.assertEqual(f(), 20)
1197 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001198 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001199 self.assertEqual(hits, 0)
1200 self.assertEqual(misses, 5)
1201 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001202
1203 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001204 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001205 def f():
1206 nonlocal f_cnt
1207 f_cnt += 1
1208 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001209 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001210 f_cnt = 0
1211 for i in range(5):
1212 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001213 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001214 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001215 self.assertEqual(hits, 4)
1216 self.assertEqual(misses, 1)
1217 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001218
Raymond Hettingerf3098282010-08-15 03:30:45 +00001219 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001220 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001221 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001222 nonlocal f_cnt
1223 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001224 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001225 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001226 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001227 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1228 # * * * *
1229 self.assertEqual(f(x), x*10)
1230 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001231 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001232 self.assertEqual(hits, 12)
1233 self.assertEqual(misses, 4)
1234 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001235
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001236 def test_lru_bug_35780(self):
1237 # C version of the lru_cache was not checking to see if
1238 # the user function call has already modified the cache
1239 # (this arises in recursive calls and in multi-threading).
1240 # This cause the cache to have orphan links not referenced
1241 # by the cache dictionary.
1242
1243 once = True # Modified by f(x) below
1244
1245 @self.module.lru_cache(maxsize=10)
1246 def f(x):
1247 nonlocal once
1248 rv = f'.{x}.'
1249 if x == 20 and once:
1250 once = False
1251 rv = f(x)
1252 return rv
1253
1254 # Fill the cache
1255 for x in range(15):
1256 self.assertEqual(f(x), f'.{x}.')
1257 self.assertEqual(f.cache_info().currsize, 10)
1258
1259 # Make a recursive call and make sure the cache remains full
1260 self.assertEqual(f(20), '.20.')
1261 self.assertEqual(f.cache_info().currsize, 10)
1262
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001263 def test_lru_hash_only_once(self):
1264 # To protect against weird reentrancy bugs and to improve
1265 # efficiency when faced with slow __hash__ methods, the
1266 # LRU cache guarantees that it will only call __hash__
1267 # only once per use as an argument to the cached function.
1268
1269 @self.module.lru_cache(maxsize=1)
1270 def f(x, y):
1271 return x * 3 + y
1272
1273 # Simulate the integer 5
1274 mock_int = unittest.mock.Mock()
1275 mock_int.__mul__ = unittest.mock.Mock(return_value=15)
1276 mock_int.__hash__ = unittest.mock.Mock(return_value=999)
1277
1278 # Add to cache: One use as an argument gives one call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001279 self.assertEqual(f(mock_int, 1), 16)
1280 self.assertEqual(mock_int.__hash__.call_count, 1)
1281 self.assertEqual(f.cache_info(), (0, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001282
1283 # Cache hit: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001284 self.assertEqual(f(mock_int, 1), 16)
1285 self.assertEqual(mock_int.__hash__.call_count, 2)
1286 self.assertEqual(f.cache_info(), (1, 1, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001287
Ville Skyttä49b27342017-08-03 09:00:59 +03001288 # Cache eviction: No use as an argument gives no additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001289 self.assertEqual(f(6, 2), 20)
1290 self.assertEqual(mock_int.__hash__.call_count, 2)
1291 self.assertEqual(f.cache_info(), (1, 2, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001292
1293 # Cache miss: One use as an argument gives one additional call
Raymond Hettinger5eed36f2017-01-07 20:53:09 -08001294 self.assertEqual(f(mock_int, 1), 16)
1295 self.assertEqual(mock_int.__hash__.call_count, 3)
1296 self.assertEqual(f.cache_info(), (1, 3, 1, 1))
Raymond Hettingerd191ef22017-01-07 20:44:48 -08001297
Raymond Hettingeraf56e0e2016-12-16 13:57:40 -08001298 def test_lru_reentrancy_with_len(self):
1299 # Test to make sure the LRU cache code isn't thrown-off by
1300 # caching the built-in len() function. Since len() can be
1301 # cached, we shouldn't use it inside the lru code itself.
1302 old_len = builtins.len
1303 try:
1304 builtins.len = self.module.lru_cache(4)(len)
1305 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
1306 self.assertEqual(len('abcdefghijklmn'[:i]), i)
1307 finally:
1308 builtins.len = old_len
1309
Raymond Hettinger605a4472017-01-09 07:50:19 -08001310 def test_lru_star_arg_handling(self):
1311 # Test regression that arose in ea064ff3c10f
1312 @functools.lru_cache()
1313 def f(*args):
1314 return args
1315
1316 self.assertEqual(f(1, 2), (1, 2))
1317 self.assertEqual(f((1, 2)), ((1, 2),))
1318
Yury Selivanov46a02db2016-11-09 18:55:45 -05001319 def test_lru_type_error(self):
1320 # Regression test for issue #28653.
1321 # lru_cache was leaking when one of the arguments
1322 # wasn't cacheable.
1323
1324 @functools.lru_cache(maxsize=None)
1325 def infinite_cache(o):
1326 pass
1327
1328 @functools.lru_cache(maxsize=10)
1329 def limited_cache(o):
1330 pass
1331
1332 with self.assertRaises(TypeError):
1333 infinite_cache([])
1334
1335 with self.assertRaises(TypeError):
1336 limited_cache([])
1337
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001338 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001339 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001340 def fib(n):
1341 if n < 2:
1342 return n
1343 return fib(n-1) + fib(n-2)
1344 self.assertEqual([fib(n) for n in range(16)],
1345 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1346 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001347 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001348 fib.cache_clear()
1349 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001350 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1351
1352 def test_lru_with_maxsize_negative(self):
1353 @self.module.lru_cache(maxsize=-10)
1354 def eq(n):
1355 return n
1356 for i in (0, 1):
1357 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1358 self.assertEqual(eq.cache_info(),
Raymond Hettingerd8080c02019-01-26 03:02:00 -05001359 self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001360
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001361 def test_lru_with_exceptions(self):
1362 # Verify that user_function exceptions get passed through without
1363 # creating a hard-to-read chained exception.
1364 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001365 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001366 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001367 def func(i):
1368 return 'abc'[i]
1369 self.assertEqual(func(0), 'a')
1370 with self.assertRaises(IndexError) as cm:
1371 func(15)
1372 self.assertIsNone(cm.exception.__context__)
1373 # Verify that the previous exception did not result in a cached entry
1374 with self.assertRaises(IndexError):
1375 func(15)
1376
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001377 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001378 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001379 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001380 def square(x):
1381 return x * x
1382 self.assertEqual(square(3), 9)
1383 self.assertEqual(type(square(3)), type(9))
1384 self.assertEqual(square(3.0), 9.0)
1385 self.assertEqual(type(square(3.0)), type(9.0))
1386 self.assertEqual(square(x=3), 9)
1387 self.assertEqual(type(square(x=3)), type(9))
1388 self.assertEqual(square(x=3.0), 9.0)
1389 self.assertEqual(type(square(x=3.0)), type(9.0))
1390 self.assertEqual(square.cache_info().hits, 4)
1391 self.assertEqual(square.cache_info().misses, 4)
1392
Antoine Pitroub5b37142012-11-13 21:35:40 +01001393 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001394 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001395 def fib(n):
1396 if n < 2:
1397 return n
1398 return fib(n=n-1) + fib(n=n-2)
1399 self.assertEqual(
1400 [fib(n=number) for number in range(16)],
1401 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1402 )
1403 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001404 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001405 fib.cache_clear()
1406 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001407 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001408
1409 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001410 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001411 def fib(n):
1412 if n < 2:
1413 return n
1414 return fib(n=n-1) + fib(n=n-2)
1415 self.assertEqual([fib(n=number) for number in range(16)],
1416 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1417 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001418 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001419 fib.cache_clear()
1420 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001421 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1422
Raymond Hettinger4ee39142017-01-08 17:28:20 -08001423 def test_kwargs_order(self):
1424 # PEP 468: Preserving Keyword Argument Order
1425 @self.module.lru_cache(maxsize=10)
1426 def f(**kwargs):
1427 return list(kwargs.items())
1428 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
1429 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
1430 self.assertEqual(f.cache_info(),
1431 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
1432
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001433 def test_lru_cache_decoration(self):
1434 def f(zomg: 'zomg_annotation'):
1435 """f doc string"""
1436 return 42
1437 g = self.module.lru_cache()(f)
1438 for attr in self.module.WRAPPER_ASSIGNMENTS:
1439 self.assertEqual(getattr(g, attr), getattr(f, attr))
1440
    def test_lru_cache_threaded(self):
        # Exercise the cache from several threads at once.
        # Phase 1: n threads each call f(k, 0) m times for their own key k,
        # and then the hit/miss counters are checked.
        # Phase 2: the same workers run while one extra thread keeps clearing
        # the cache.  This phase makes no counter assertions.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # maxsize n*m exceeds the n distinct keys used in phase 1, so nothing
        # is evicted there.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Every worker blocks on this event so that they all start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                # NOTE(review): a failing assertion here is raised in a worker
                # thread. Confirm how support.start_threads reports it to the
                # test.
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # A tiny switch interval makes thread switches, and therefore
        # interleavings inside the cache, much more frequent.
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: the pure-Python implementation's counters can come out
                # lower than the exact values here. The cause has not been
                # pinned down (possibly unsynchronized counter updates, which
                # is unconfirmed).
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                # Exactly one miss per distinct key; every other call hits.
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            # Restore the interpreter-wide switch interval changed above.
            sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001488
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments.
        # In each of m rounds, all n worker threads call f(i) together.
        # f blocks on `pause` until all n workers and the main thread arrive,
        # so no call can finish (and store its result) before every worker
        # has already missed. As a result, misses grow by n per round, hits
        # stay 0, and currsize grows by one key per round.
        n, m = 5, 7
        # Each barrier synchronizes the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        # cache_info() compares equal to a plain
        # (hits, misses, maxsize, currsize) tuple.
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                # Each barrier is reset at a point where no worker can be
                # waiting on it. This is presumably to discard any broken
                # state before its next use.
                # NOTE(review): confirm the intent.
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1515
Serhiy Storchaka67796522017-01-12 18:34:33 +02001516 def test_lru_cache_threaded3(self):
1517 @self.module.lru_cache(maxsize=2)
1518 def f(x):
1519 time.sleep(.01)
1520 return 3 * x
1521 def test(i, x):
1522 with self.subTest(thread=i):
1523 self.assertEqual(f(x), 3 * x, i)
1524 threads = [threading.Thread(target=test, args=(i, v))
1525 for i, v in enumerate([1, 2, 2, 3, 2])]
1526 with support.start_threads(threads):
1527 pass
1528
Raymond Hettinger03923422013-03-04 02:52:50 -05001529 def test_need_for_rlock(self):
1530 # This will deadlock on an LRU cache that uses a regular lock
1531
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001532 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001533 def test_func(x):
1534 'Used to demonstrate a reentrant lru_cache call within a single thread'
1535 return x
1536
1537 class DoubleEq:
1538 'Demonstrate a reentrant lru_cache call within a single thread'
1539 def __init__(self, x):
1540 self.x = x
1541 def __hash__(self):
1542 return self.x
1543 def __eq__(self, other):
1544 if self.x == 2:
1545 test_func(DoubleEq(1))
1546 return self.x == other.x
1547
1548 test_func(DoubleEq(1)) # Load the cache
1549 test_func(DoubleEq(2)) # Load the cache
1550 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1551 DoubleEq(2)) # Verify the correct return value
1552
Raymond Hettinger4d588972014-08-12 12:44:52 -07001553 def test_early_detection_of_bad_call(self):
1554 # Issue #22184
1555 with self.assertRaises(TypeError):
1556 @functools.lru_cache
1557 def f():
1558 pass
1559
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001560 def test_lru_method(self):
1561 class X(int):
1562 f_cnt = 0
1563 @self.module.lru_cache(2)
1564 def f(self, x):
1565 self.f_cnt += 1
1566 return x*10+self
1567 a = X(5)
1568 b = X(5)
1569 c = X(7)
1570 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1571
1572 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1573 self.assertEqual(a.f(x), x*10 + 5)
1574 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1575 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1576
1577 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1578 self.assertEqual(b.f(x), x*10 + 5)
1579 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1580 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1581
1582 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1583 self.assertEqual(c.f(x), x*10 + 7)
1584 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1585 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1586
1587 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1588 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1589 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1590
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001591 def test_pickle(self):
1592 cls = self.__class__
1593 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1594 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1595 with self.subTest(proto=proto, func=f):
1596 f_copy = pickle.loads(pickle.dumps(f, proto))
1597 self.assertIs(f_copy, f)
1598
1599 def test_copy(self):
1600 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001601 def orig(x, y):
1602 return 3 * x + y
1603 part = self.module.partial(orig, 2)
1604 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1605 self.module.lru_cache(2)(part))
1606 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001607 with self.subTest(func=f):
1608 f_copy = copy.copy(f)
1609 self.assertIs(f_copy, f)
1610
1611 def test_deepcopy(self):
1612 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001613 def orig(x, y):
1614 return 3 * x + y
1615 part = self.module.partial(orig, 2)
1616 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1617 self.module.lru_cache(2)(part))
1618 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001619 with self.subTest(func=f):
1620 f_copy = copy.deepcopy(f)
1621 self.assertIs(f_copy, f)
1622
1623
# Module-level function cached with the pure-Python lru_cache. TestLRUPy
# exposes it as cached_func for the pickle/copy tests, which expect it to
# round-trip as the same object.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y
1627
# Module-level function cached with the C-accelerated lru_cache. TestLRUC
# exposes it as cached_func for the pickle/copy tests, which expect it to
# round-trip as the same object.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y
1631
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001632
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the pure-Python functools."""

    module = py_functools
    # Kept inside a 1-tuple; the tests read it back as cls.cached_func[0].
    cached_func = (py_cached_func,)

    # Class-level cached method and staticmethod fixtures. The pickle and
    # copy tests expect each to round-trip as the very same object.
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1645
1646
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU checks against the C-accelerated functools."""

    module = c_functools
    # Kept inside a 1-tuple; the tests read it back as cls.cached_func[0].
    cached_func = (c_cached_func,)

    # Class-level cached method and staticmethod fixtures. The pickle and
    # copy tests expect each to round-trip as the very same object.
    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001659
Raymond Hettinger03923422013-03-04 02:52:50 -05001660
Łukasz Langa6f692512013-06-05 12:20:24 +02001661class TestSingleDispatch(unittest.TestCase):
1662 def test_simple_overloads(self):
1663 @functools.singledispatch
1664 def g(obj):
1665 return "base"
1666 def g_int(i):
1667 return "integer"
1668 g.register(int, g_int)
1669 self.assertEqual(g("str"), "base")
1670 self.assertEqual(g(1), "integer")
1671 self.assertEqual(g([1,2,3]), "base")
1672
1673 def test_mro(self):
1674 @functools.singledispatch
1675 def g(obj):
1676 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001677 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001678 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001679 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001680 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001681 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001682 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001683 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001684 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001685 def g_A(a):
1686 return "A"
1687 def g_B(b):
1688 return "B"
1689 g.register(A, g_A)
1690 g.register(B, g_B)
1691 self.assertEqual(g(A()), "A")
1692 self.assertEqual(g(B()), "B")
1693 self.assertEqual(g(C()), "A")
1694 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001695
1696 def test_register_decorator(self):
1697 @functools.singledispatch
1698 def g(obj):
1699 return "base"
1700 @g.register(int)
1701 def g_int(i):
1702 return "int %s" % (i,)
1703 self.assertEqual(g(""), "base")
1704 self.assertEqual(g(12), "int 12")
1705 self.assertIs(g.dispatch(int), g_int)
1706 self.assertIs(g.dispatch(object), g.dispatch(str))
1707 # Note: in the assert above this is not g.
1708 # @singledispatch returns the wrapper.
1709
1710 def test_wrapping_attributes(self):
1711 @functools.singledispatch
1712 def g(obj):
1713 "Simple test"
1714 return "Test"
1715 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001716 if sys.flags.optimize < 2:
1717 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001718
1719 @unittest.skipUnless(decimal, 'requires _decimal')
1720 @support.cpython_only
1721 def test_c_classes(self):
1722 @functools.singledispatch
1723 def g(obj):
1724 return "base"
1725 @g.register(decimal.DecimalException)
1726 def _(obj):
1727 return obj.args
1728 subn = decimal.Subnormal("Exponent < Emin")
1729 rnd = decimal.Rounded("Number got rounded")
1730 self.assertEqual(g(subn), ("Exponent < Emin",))
1731 self.assertEqual(g(rnd), ("Number got rounded",))
1732 @g.register(decimal.Subnormal)
1733 def _(obj):
1734 return "Too small to care."
1735 self.assertEqual(g(subn), "Too small to care.")
1736 self.assertEqual(g(rnd), ("Number got rounded",))
1737
1738 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001739 # None of the examples in this test depend on haystack ordering.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001740 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001741 mro = functools._compose_mro
1742 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1743 for haystack in permutations(bases):
1744 m = mro(dict, haystack)
Guido van Rossumf0666942016-08-23 10:47:07 -07001745 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
1746 c.Collection, c.Sized, c.Iterable,
1747 c.Container, object])
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001748 bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
Łukasz Langa6f692512013-06-05 12:20:24 +02001749 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001750 m = mro(collections.ChainMap, haystack)
1751 self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001752 c.Collection, c.Sized, c.Iterable,
1753 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001754
1755 # If there's a generic function with implementations registered for
1756 # both Sized and Container, passing a defaultdict to it results in an
1757 # ambiguous dispatch which will cause a RuntimeError (see
1758 # test_mro_conflicts).
1759 bases = [c.Container, c.Sized, str]
1760 for haystack in permutations(bases):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001761 m = mro(collections.defaultdict, [c.Sized, c.Container, str])
1762 self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
1763 c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001764
1765 # MutableSequence below is registered directly on D. In other words, it
Martin Panter46f50722016-05-26 05:35:26 +00001766 # precedes MutableMapping which means single dispatch will always
Łukasz Langa3720c772013-07-01 16:00:38 +02001767 # choose MutableSequence here.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001768 class D(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001769 pass
1770 c.MutableSequence.register(D)
1771 bases = [c.MutableSequence, c.MutableMapping]
1772 for haystack in permutations(bases):
1773 m = mro(D, bases)
Guido van Rossumf0666942016-08-23 10:47:07 -07001774 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001775 collections.defaultdict, dict, c.MutableMapping, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001776 c.Collection, c.Sized, c.Iterable, c.Container,
Łukasz Langa3720c772013-07-01 16:00:38 +02001777 object])
1778
1779 # Container and Callable are registered on different base classes and
1780 # a generic function supporting both should always pick the Callable
1781 # implementation if a C instance is passed.
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001782 class C(collections.defaultdict):
Łukasz Langa3720c772013-07-01 16:00:38 +02001783 def __call__(self):
1784 pass
1785 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1786 for haystack in permutations(bases):
1787 m = mro(C, haystack)
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001788 self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
Guido van Rossumf0666942016-08-23 10:47:07 -07001789 c.Collection, c.Sized, c.Iterable,
1790 c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001791
1792 def test_register_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001793 c = collections.abc
Łukasz Langa6f692512013-06-05 12:20:24 +02001794 d = {"a": "b"}
1795 l = [1, 2, 3]
1796 s = {object(), None}
1797 f = frozenset(s)
1798 t = (1, 2, 3)
1799 @functools.singledispatch
1800 def g(obj):
1801 return "base"
1802 self.assertEqual(g(d), "base")
1803 self.assertEqual(g(l), "base")
1804 self.assertEqual(g(s), "base")
1805 self.assertEqual(g(f), "base")
1806 self.assertEqual(g(t), "base")
1807 g.register(c.Sized, lambda obj: "sized")
1808 self.assertEqual(g(d), "sized")
1809 self.assertEqual(g(l), "sized")
1810 self.assertEqual(g(s), "sized")
1811 self.assertEqual(g(f), "sized")
1812 self.assertEqual(g(t), "sized")
1813 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1814 self.assertEqual(g(d), "mutablemapping")
1815 self.assertEqual(g(l), "sized")
1816 self.assertEqual(g(s), "sized")
1817 self.assertEqual(g(f), "sized")
1818 self.assertEqual(g(t), "sized")
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001819 g.register(collections.ChainMap, lambda obj: "chainmap")
Łukasz Langa6f692512013-06-05 12:20:24 +02001820 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1821 self.assertEqual(g(l), "sized")
1822 self.assertEqual(g(s), "sized")
1823 self.assertEqual(g(f), "sized")
1824 self.assertEqual(g(t), "sized")
1825 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1826 self.assertEqual(g(d), "mutablemapping")
1827 self.assertEqual(g(l), "mutablesequence")
1828 self.assertEqual(g(s), "sized")
1829 self.assertEqual(g(f), "sized")
1830 self.assertEqual(g(t), "sized")
1831 g.register(c.MutableSet, lambda obj: "mutableset")
1832 self.assertEqual(g(d), "mutablemapping")
1833 self.assertEqual(g(l), "mutablesequence")
1834 self.assertEqual(g(s), "mutableset")
1835 self.assertEqual(g(f), "sized")
1836 self.assertEqual(g(t), "sized")
1837 g.register(c.Mapping, lambda obj: "mapping")
1838 self.assertEqual(g(d), "mutablemapping") # not specific enough
1839 self.assertEqual(g(l), "mutablesequence")
1840 self.assertEqual(g(s), "mutableset")
1841 self.assertEqual(g(f), "sized")
1842 self.assertEqual(g(t), "sized")
1843 g.register(c.Sequence, lambda obj: "sequence")
1844 self.assertEqual(g(d), "mutablemapping")
1845 self.assertEqual(g(l), "mutablesequence")
1846 self.assertEqual(g(s), "mutableset")
1847 self.assertEqual(g(f), "sized")
1848 self.assertEqual(g(t), "sequence")
1849 g.register(c.Set, lambda obj: "set")
1850 self.assertEqual(g(d), "mutablemapping")
1851 self.assertEqual(g(l), "mutablesequence")
1852 self.assertEqual(g(s), "mutableset")
1853 self.assertEqual(g(f), "set")
1854 self.assertEqual(g(t), "sequence")
1855 g.register(dict, lambda obj: "dict")
1856 self.assertEqual(g(d), "dict")
1857 self.assertEqual(g(l), "mutablesequence")
1858 self.assertEqual(g(s), "mutableset")
1859 self.assertEqual(g(f), "set")
1860 self.assertEqual(g(t), "sequence")
1861 g.register(list, lambda obj: "list")
1862 self.assertEqual(g(d), "dict")
1863 self.assertEqual(g(l), "list")
1864 self.assertEqual(g(s), "mutableset")
1865 self.assertEqual(g(f), "set")
1866 self.assertEqual(g(t), "sequence")
1867 g.register(set, lambda obj: "concrete-set")
1868 self.assertEqual(g(d), "dict")
1869 self.assertEqual(g(l), "list")
1870 self.assertEqual(g(s), "concrete-set")
1871 self.assertEqual(g(f), "set")
1872 self.assertEqual(g(t), "sequence")
1873 g.register(frozenset, lambda obj: "frozen-set")
1874 self.assertEqual(g(d), "dict")
1875 self.assertEqual(g(l), "list")
1876 self.assertEqual(g(s), "concrete-set")
1877 self.assertEqual(g(f), "frozen-set")
1878 self.assertEqual(g(t), "sequence")
1879 g.register(tuple, lambda obj: "tuple")
1880 self.assertEqual(g(d), "dict")
1881 self.assertEqual(g(l), "list")
1882 self.assertEqual(g(s), "concrete-set")
1883 self.assertEqual(g(f), "frozen-set")
1884 self.assertEqual(g(t), "tuple")
1885
Łukasz Langa3720c772013-07-01 16:00:38 +02001886 def test_c3_abc(self):
Serhiy Storchaka2e576f52017-04-24 09:05:00 +03001887 c = collections.abc
Łukasz Langa3720c772013-07-01 16:00:38 +02001888 mro = functools._c3_mro
1889 class A(object):
1890 pass
1891 class B(A):
1892 def __len__(self):
1893 return 0 # implies Sized
1894 @c.Container.register
1895 class C(object):
1896 pass
1897 class D(object):
1898 pass # unrelated
1899 class X(D, C, B):
1900 def __call__(self):
1901 pass # implies Callable
1902 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1903 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1904 self.assertEqual(mro(X, abcs=abcs), expected)
1905 # unrelated ABCs don't appear in the resulting MRO
1906 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1907 self.assertEqual(mro(X, abcs=many_abcs), expected)
1908
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001909 def test_false_meta(self):
1910 # see issue23572
1911 class MetaA(type):
1912 def __len__(self):
1913 return 0
1914 class A(metaclass=MetaA):
1915 pass
1916 class AA(A):
1917 pass
1918 @functools.singledispatch
1919 def fun(a):
1920 return 'base A'
1921 @fun.register(A)
1922 def _(a):
1923 return 'fun A'
1924 aa = AA()
1925 self.assertEqual(fun(aa), 'fun A')
1926
    def test_mro_conflicts(self):
        # Dispatch when several registered ABCs apply to the same class:
        # - bases listed explicitly in __mro__ take precedence;
        # - an ABC that subclasses several candidates wins over them;
        # - candidates with no ordering between them raise
        #   RuntimeError("Ambiguous dispatch: ...").
        # Note: the register() calls below add these local classes to the
        # shared collections.abc ABCs, so statement order matters.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # P is now registered with both Iterable and Container and neither is
        # more specific, so dispatch is ambiguous. The message may name the
        # two ABCs in either order.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # MutableSequence is registered directly on R, so it takes precedence
        # over the MutableMapping that R only gets through defaultdict/dict.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists Sized before S. Once Container is registered on V, it lands
        # right after Sized in V's MRO, i.e. ahead of S.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container") # because it ends up right after
                                            # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02002054
    def test_cache_invalidation(self):
        # Check when singledispatch's dispatch cache is filled, read and
        # invalidated. weakref.WeakKeyDictionary is temporarily replaced with
        # a factory returning the tracing dict `td`. The cache created when
        # @functools.singledispatch runs inside the `with` block is therefore
        # `td` itself, and every cache read and write is recorded.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Records, in order, every key written (set_ops) and every
            # successful lookup (get_ops).
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                # A missing key raises before the append, so only successful
                # lookups are recorded.
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Clear the storage directly. This is presumably so that
                # clearing does not go through __getitem__ and add to get_ops.
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # The first call for each type is a cache miss. The resolved
            # implementation is stored (set_ops) and nothing is read (get_ops).
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated calls are served from the cache: get_ops grows and
            # set_ops does not.
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering a new implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # Registering an implementation for an ABC also empties the cache.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # dispatch() reads through the same cache.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # Now that an ABC implementation exists, registering any class
            # with an ABC (even the unrelated X) invalidates the cache. The
            # stale entries are dropped on the next call, which then caches
            # only that call's type.
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # _clear_cache() empties the cache explicitly.
            g._clear_cache()
            self.assertEqual(len(td), 0)
Łukasz Langa6f692512013-06-05 12:20:24 +02002156
Łukasz Langae5697532017-12-11 13:56:31 -08002157 def test_annotations(self):
2158 @functools.singledispatch
2159 def i(arg):
2160 return "base"
2161 @i.register
2162 def _(arg: collections.abc.Mapping):
2163 return "mapping"
2164 @i.register
2165 def _(arg: "collections.abc.Sequence"):
2166 return "sequence"
2167 self.assertEqual(i(None), "base")
2168 self.assertEqual(i({"a": 1}), "mapping")
2169 self.assertEqual(i([1, 2, 3]), "sequence")
2170 self.assertEqual(i((1, 2, 3)), "sequence")
2171 self.assertEqual(i("str"), "sequence")
2172
2173 # Registering classes as callables doesn't work with annotations,
2174 # you need to pass the type explicitly.
2175 @i.register(str)
2176 class _:
2177 def __init__(self, arg):
2178 self.arg = arg
2179
2180 def __eq__(self, other):
2181 return self.arg == other
2182 self.assertEqual(i("str"), "str")
2183
Ethan Smithc6512752018-05-26 16:38:33 -04002184 def test_method_register(self):
2185 class A:
2186 @functools.singledispatchmethod
2187 def t(self, arg):
2188 self.arg = "base"
2189 @t.register(int)
2190 def _(self, arg):
2191 self.arg = "int"
2192 @t.register(str)
2193 def _(self, arg):
2194 self.arg = "str"
2195 a = A()
2196
2197 a.t(0)
2198 self.assertEqual(a.arg, "int")
2199 aa = A()
2200 self.assertFalse(hasattr(aa, 'arg'))
2201 a.t('')
2202 self.assertEqual(a.arg, "str")
2203 aa = A()
2204 self.assertFalse(hasattr(aa, 'arg'))
2205 a.t(0.0)
2206 self.assertEqual(a.arg, "base")
2207 aa = A()
2208 self.assertFalse(hasattr(aa, 'arg'))
2209
2210 def test_staticmethod_register(self):
2211 class A:
2212 @functools.singledispatchmethod
2213 @staticmethod
2214 def t(arg):
2215 return arg
2216 @t.register(int)
2217 @staticmethod
2218 def _(arg):
2219 return isinstance(arg, int)
2220 @t.register(str)
2221 @staticmethod
2222 def _(arg):
2223 return isinstance(arg, str)
2224 a = A()
2225
2226 self.assertTrue(A.t(0))
2227 self.assertTrue(A.t(''))
2228 self.assertEqual(A.t(0.0), 0.0)
2229
2230 def test_classmethod_register(self):
2231 class A:
2232 def __init__(self, arg):
2233 self.arg = arg
2234
2235 @functools.singledispatchmethod
2236 @classmethod
2237 def t(cls, arg):
2238 return cls("base")
2239 @t.register(int)
2240 @classmethod
2241 def _(cls, arg):
2242 return cls("int")
2243 @t.register(str)
2244 @classmethod
2245 def _(cls, arg):
2246 return cls("str")
2247
2248 self.assertEqual(A.t(0).arg, "int")
2249 self.assertEqual(A.t('').arg, "str")
2250 self.assertEqual(A.t(0.0).arg, "base")
2251
2252 def test_callable_register(self):
2253 class A:
2254 def __init__(self, arg):
2255 self.arg = arg
2256
2257 @functools.singledispatchmethod
2258 @classmethod
2259 def t(cls, arg):
2260 return cls("base")
2261
2262 @A.t.register(int)
2263 @classmethod
2264 def _(cls, arg):
2265 return cls("int")
2266 @A.t.register(str)
2267 @classmethod
2268 def _(cls, arg):
2269 return cls("str")
2270
2271 self.assertEqual(A.t(0).arg, "int")
2272 self.assertEqual(A.t('').arg, "str")
2273 self.assertEqual(A.t(0.0).arg, "base")
2274
2275 def test_abstractmethod_register(self):
2276 class Abstract(abc.ABCMeta):
2277
2278 @functools.singledispatchmethod
2279 @abc.abstractmethod
2280 def add(self, x, y):
2281 pass
2282
2283 self.assertTrue(Abstract.add.__isabstractmethod__)
2284
2285 def test_type_ann_register(self):
2286 class A:
2287 @functools.singledispatchmethod
2288 def t(self, arg):
2289 return "base"
2290 @t.register
2291 def _(self, arg: int):
2292 return "int"
2293 @t.register
2294 def _(self, arg: str):
2295 return "str"
2296 a = A()
2297
2298 self.assertEqual(a.t(0), "int")
2299 self.assertEqual(a.t(''), "str")
2300 self.assertEqual(a.t(0.0), "base")
2301
Łukasz Langae5697532017-12-11 13:56:31 -08002302 def test_invalid_registrations(self):
2303 msg_prefix = "Invalid first argument to `register()`: "
2304 msg_suffix = (
2305 ". Use either `@register(some_class)` or plain `@register` on an "
2306 "annotated function."
2307 )
2308 @functools.singledispatch
2309 def i(arg):
2310 return "base"
2311 with self.assertRaises(TypeError) as exc:
2312 @i.register(42)
2313 def _(arg):
2314 return "I annotated with a non-type"
2315 self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
2316 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2317 with self.assertRaises(TypeError) as exc:
2318 @i.register
2319 def _(arg):
2320 return "I forgot to annotate"
2321 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2322 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2323 ))
2324 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2325
2326 # FIXME: The following will only work after PEP 560 is implemented.
2327 return
2328
2329 with self.assertRaises(TypeError) as exc:
2330 @i.register
2331 def _(arg: typing.Iterable[str]):
2332 # At runtime, dispatching on generics is impossible.
2333 # When registering implementations with singledispatch, avoid
2334 # types from `typing`. Instead, annotate with regular types
2335 # or ABCs.
2336 return "I annotated with a generic collection"
2337 self.assertTrue(str(exc.exception).startswith(msg_prefix +
2338 "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
2339 ))
2340 self.assertTrue(str(exc.exception).endswith(msg_suffix))
2341
Dong-hee Na445f1b32018-07-10 16:26:36 +09002342 def test_invalid_positional_argument(self):
2343 @functools.singledispatch
2344 def f(*args):
2345 pass
2346 msg = 'f requires at least 1 positional argument'
INADA Naoki56d8f572018-07-17 13:44:47 +09002347 with self.assertRaisesRegex(TypeError, msg):
Dong-hee Na445f1b32018-07-10 16:26:36 +09002348 f()
Łukasz Langa6f692512013-06-05 12:20:24 +02002349
Carl Meyerd658dea2018-08-28 01:11:56 -06002350
class CachedCostItem:
    """Item whose ``cost`` is computed on first access, then cached."""

    # Class-level starting value; each computation increments it by one.
    _cost = 1

    def __init__(self):
        # Use the public threading API. The RLock name reachable through
        # functools is only an internal import, not part of its public API.
        self.lock = threading.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            self._cost += 1
            return self._cost
2363
2364
class OptionallyCachedCostItem:
    """Exposes the same computation both uncached and cached."""

    # Starting value; every call to get_cost() bumps it by one.
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        self._cost = self._cost + 1
        return self._cost

    # The cached variant lives under a different attribute name than the
    # function it wraps, so get_cost() stays callable and uncached.
    cached_cost = py_functools.cached_property(get_cost)
2374
2375
class CachedCostItemWait:
    """Item whose ``cost`` computation waits on *event* before computing.

    Lets several threads be inside the property getter at the same time.
    """

    def __init__(self, event):
        self._cost = 1
        # Public threading API instead of the RLock name functools imports
        # internally.
        self.lock = threading.RLock()
        self.event = event

    @py_functools.cached_property
    def cost(self):
        # Block for up to one second so concurrent readers pile up here.
        self.event.wait(1)
        with self.lock:
            self._cost += 1
            return self._cost
2389
2390
class CachedCostItemWithSlots:
    """Item without an instance ``__dict__``; cached_property cannot cache."""

    # The trailing comma makes this a one-element tuple. Without it the
    # parentheses are mere grouping and __slots__ would be a bare string.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
2400
2401
class TestCachedProperty(unittest.TestCase):
    """Tests for ``cached_property`` from the pure-Python functools module."""

    def test_cached(self):
        item = CachedCostItem()
        self.assertEqual(item.cost, 2)
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        # Caching under `cached_cost` must not affect the wrapped get_cost().
        item = OptionallyCachedCostItem()
        self.assertEqual(item.get_cost(), 2)
        self.assertEqual(item.cached_cost, 3)
        self.assertEqual(item.get_cost(), 4)
        self.assertEqual(item.cached_cost, 3)

    def test_threaded(self):
        # Several threads race to read the property. The getter increments
        # the cost on each run, so the final value 2 shows it ran only once.
        go = threading.Event()
        item = CachedCostItemWait(go)

        num_threads = 3

        orig_si = sys.getswitchinterval()
        # A tiny switch interval makes thread interleaving more likely.
        sys.setswitchinterval(1e-6)
        try:
            threads = [
                threading.Thread(target=lambda: item.cost)
                for _ in range(num_threads)
            ]
            with support.start_threads(threads):
                go.set()
        finally:
            sys.setswitchinterval(orig_si)

        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        with self.assertRaisesRegex(
            TypeError,
            "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property.",
        ):
            item.cost

    def test_immutable_dict(self):
        # A class's __dict__ is a read-only mappingproxy, so caching a
        # property defined on a metaclass must fail.
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        with self.assertRaisesRegex(
            TypeError,
            "The '__dict__' attribute on 'MyMeta' instance does not support item assignment for caching 'prop' property.",
        ):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        self.assertEqual(
            str(ctx.exception.__context__),
            str(TypeError("Cannot assign the same cached_property to two different names ('a' and 'b')."))
        )

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        counter = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal counter
            counter += 1
            return counter

        class A:
            cp = _cp

        class B:
            cp = _cp

        a = A()
        b = B()

        # Each instance caches its own value; a is not recomputed.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        # Attaching after class creation skips __set_name__.
        cp = py_functools.cached_property(lambda s: None)
        class Foo:
            pass

        Foo.cp = cp

        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        # Class-level access returns the descriptor itself.
        self.assertIsInstance(CachedCostItem.cost, py_functools.cached_property)

    def test_doc(self):
        self.assertEqual(CachedCostItem.cost.__doc__, "The cost of the item.")
2514
2515
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()