blob: 6a3bf649633f8a7ef3fb3c519ab9c70695ad2d4d [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030011try:
12 import threading
13except ImportError:
14 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000015
Antoine Pitroub5b37142012-11-13 21:35:40 +010016import functools
17
# Load two independent copies of functools so both implementations can be
# exercised: one with the C accelerator module blocked, and one with it
# freshly imported.  c_functools is tested for truthiness before every use
# further down, so it may be a false value when _functools is unavailable.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
22
23
def capture(*args, **kw):
    """Return the positional and keyword arguments as the pair ``(args, kw)``."""
    return args, kw
27
Łukasz Langa6f692512013-06-05 12:20:24 +020028
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for the partial object *part*."""
    return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000032
class MyTuple(tuple):
    """Plain ``tuple`` subclass with no added behaviour."""
    pass
35
class BadTuple(tuple):
    """``tuple`` subclass whose ``+`` operator returns a ``list``."""

    def __add__(self, other):
        return list(self) + list(other)
39
class MyDict(dict):
    """Plain ``dict`` subclass with no added behaviour."""
    pass
42
Łukasz Langa6f692512013-06-05 12:20:24 +020043
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin: it does not derive from ``unittest.TestCase`` itself.
    Concrete subclasses add ``unittest.TestCase`` as a base and expose the
    implementation under test as ``self.partial``.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # Either constructing the partial or calling it must reject a
        # non-callable first argument.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertions so a failure reports the differing values
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # drop the only strong reference; the proxy must then be dead
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')
169
Łukasz Langa6f692512013-06-05 12:20:24 +0200170
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests, plus C-specific ones, on the C partial."""

    # Guarded so the class body still executes when c_functools is false;
    # the skipUnless decorator then skips the whole class.
    if c_functools:
        partial = c_functools.partial

    def _repr_name(self):
        # Type name that repr() is expected to show for self.partial.
        if self.partial is c_functools.partial:
            return 'functools.partial'
        return self.partial.__name__

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # either keyword order is accepted in the repr
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._repr_name()

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._repr_name()

        # Each case makes the partial refer to itself, checks the repr, and
        # then resets the state in ``finally`` to break the reference cycle.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # a shallow copy shares the contained objects
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # a deep copy duplicates the contained objects
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))
        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the signature check below was already disabled in
        # the original; the reason is not visible here -- confirm before
        # re-enabling.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # subclass instances must come back as exact tuple / dict
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        # partial whose func is itself: pickling must raise RecursionError
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.assertRaises(RecursionError):
                    pickle.dumps(f, proto)
        finally:
            f.__setstate__((capture, (), {}, {}))

        # partial that appears in its own args: the cycle must round-trip
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.args[0], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

        # partial that appears in its own keywords: same expectation
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.keywords['a'], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000377
Łukasz Langa6f692512013-06-05 12:20:24 +0200378
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against ``py_functools.partial``."""

    # staticmethod() makes self.partial return the object unchanged,
    # without descriptor binding to the test instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000381
Łukasz Langa6f692512013-06-05 12:20:24 +0200382
if c_functools:
    # Trivial subclass of the C partial; only defined when c_functools is
    # available.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100386
Łukasz Langa6f692512013-06-05 12:20:24 +0200387
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the C partial tests using ``PartialSubclass``."""

    # PartialSubclass only exists when c_functools is available.
    if c_functools:
        partial = PartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
395
Łukasz Langa6f692512013-06-05 12:20:24 +0200396
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod``."""

    class A(object):
        # Fixture class: one partialmethod per argument combination, each
        # wrapping ``capture`` so a call reports exactly what it received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a functools.partial object
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # accessed through the class, the instance is passed explicitly
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        # static and cls must behave the same via the class and an instance
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        # the TypeError is expected while the class body executes
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # vars() fetches the raw partialmethod object, bypassing __get__
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
509
510
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Shared verification helper: every name in *assigned* must be the
        # very same object on both functions, every mapping named in
        # *updated* must carry the wrapped function's entries, and
        # wrapper.__wrapped__ must be the wrapped function.

        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using update_wrapper's defaults.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus value: update_wrapper must replace it.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # empty assigned/updated tuples: nothing should be copied
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000623
Łukasz Langa6f692512013-06-05 12:20:24 +0200624
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the ``functools.wraps`` decorator."""

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Deliberately bogus value: wraps() must replace it.
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # empty assigned/updated tuples: nothing should be copied
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # gives the wrapper a dict_attr before wraps() updates it
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
685
Łukasz Langa6f692512013-06-05 12:20:24 +0200686
class TestReduce(unittest.TestCase):
    """Tests for ``functools.reduce``, reached through the ``func`` attribute."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of i*i for 0 <= i < limit, computed on demand and
            # memoized; it only offers __len__ and __getitem__.
            def __init__(self, limit):
                self.limit = limit
                self.memo = []

            def __len__(self):
                return len(self.memo)

            def __getitem__(self, index):
                if index < 0 or index >= self.limit:
                    raise IndexError
                # Extend the memo up to and including the requested index.
                for k in range(len(self.memo), index + 1):
                    self.memo.append(k * k)
                return self.memo[index]

        def add(x, y):
            return x + y

        def multiply(x, y):
            return x * y

        reduce = self.func

        # Ordinary folds with an explicit initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w'],
        )
        self.assertEqual(reduce(multiply, range(2, 8), 1), 5040)
        self.assertEqual(reduce(multiply, range(2, 21), 1),
                         2432902008176640000)

        # Folds over an object that is iterable only through __getitem__.
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong argument counts or a non-iterable second argument.
        for bad_args in ((), (42, 42), (42, 42, 42)):
            with self.assertRaises(TypeError):
                reduce(*bad_args)

        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))

        # arg 2 must not be empty sequence with no initial value
        for empty_or_invalid in ([], "", (), object()):
            with self.assertRaises(TypeError):
                reduce(add, empty_or_invalid)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        # With an empty iterable the initial value is returned untouched.
        self.assertIsNone(reduce(add, [], None))
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Yields 0..n-1 through __getitem__ only.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if i < 0 or i >= self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # Reducing a dict folds over its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
766
Łukasz Langa6f692512013-06-05 12:20:24 +0200767
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200768class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700769
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000770 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700771 def cmp1(x, y):
772 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100773 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700774 self.assertEqual(key(3), key(3))
775 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100776 self.assertGreaterEqual(key(3), key(3))
777
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700778 def cmp2(x, y):
779 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100780 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700781 self.assertEqual(key(4.0), key('4'))
782 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100783 self.assertLessEqual(key(2), key('35'))
784 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700785
786 def test_cmp_to_key_arguments(self):
787 def cmp1(x, y):
788 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100789 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700790 self.assertEqual(key(obj=3), key(obj=3))
791 self.assertGreater(key(obj=3), key(obj=1))
792 with self.assertRaises((TypeError, AttributeError)):
793 key(3) > 1 # rhs is not a K object
794 with self.assertRaises((TypeError, AttributeError)):
795 1 < key(3) # lhs is not a K object
796 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100797 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700798 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200799 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100800 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700801 with self.assertRaises(TypeError):
802 key() # too few args
803 with self.assertRaises(TypeError):
804 key(None, None) # too many args
805
806 def test_bad_cmp(self):
807 def cmp1(x, y):
808 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100809 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700810 with self.assertRaises(ZeroDivisionError):
811 key(3) > key(1)
812
813 class BadCmp:
814 def __lt__(self, other):
815 raise ZeroDivisionError
816 def cmp1(x, y):
817 return BadCmp()
818 with self.assertRaises(ZeroDivisionError):
819 key(3) > key(1)
820
821 def test_obj_field(self):
822 def cmp1(x, y):
823 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100824 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700825 self.assertEqual(key(50).obj, 50)
826
827 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000828 def mycmp(x, y):
829 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100830 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000831 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000832
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700833 def test_sort_int_str(self):
834 def mycmp(x, y):
835 x, y = int(x), int(y)
836 return (x > y) - (x < y)
837 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100838 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700839 self.assertEqual([int(value) for value in values],
840 [0, 1, 1, 2, 3, 4, 5, 7, 10])
841
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000842 def test_hash(self):
843 def mycmp(x, y):
844 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100845 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000846 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700847 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700848 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000849
Łukasz Langa6f692512013-06-05 12:20:24 +0200850
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against ``c_functools``."""

    # The class body executes even when the decorator will skip the class,
    # so the attribute lookup is guarded for a falsy c_functools.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100855
Łukasz Langa6f692512013-06-05 12:20:24 +0200856
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against ``py_functools``."""

    # staticmethod() keeps the function from being bound to the test
    # instance when it is looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
859
Łukasz Langa6f692512013-06-05 12:20:24 +0200860
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000861class TestTotalOrdering(unittest.TestCase):
862
863 def test_total_ordering_lt(self):
864 @functools.total_ordering
865 class A:
866 def __init__(self, value):
867 self.value = value
868 def __lt__(self, other):
869 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000870 def __eq__(self, other):
871 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000872 self.assertTrue(A(1) < A(2))
873 self.assertTrue(A(2) > A(1))
874 self.assertTrue(A(1) <= A(2))
875 self.assertTrue(A(2) >= A(1))
876 self.assertTrue(A(2) <= A(2))
877 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000878 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000879
880 def test_total_ordering_le(self):
881 @functools.total_ordering
882 class A:
883 def __init__(self, value):
884 self.value = value
885 def __le__(self, other):
886 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000887 def __eq__(self, other):
888 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000889 self.assertTrue(A(1) < A(2))
890 self.assertTrue(A(2) > A(1))
891 self.assertTrue(A(1) <= A(2))
892 self.assertTrue(A(2) >= A(1))
893 self.assertTrue(A(2) <= A(2))
894 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000895 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000896
897 def test_total_ordering_gt(self):
898 @functools.total_ordering
899 class A:
900 def __init__(self, value):
901 self.value = value
902 def __gt__(self, other):
903 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000904 def __eq__(self, other):
905 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000906 self.assertTrue(A(1) < A(2))
907 self.assertTrue(A(2) > A(1))
908 self.assertTrue(A(1) <= A(2))
909 self.assertTrue(A(2) >= A(1))
910 self.assertTrue(A(2) <= A(2))
911 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000912 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000913
914 def test_total_ordering_ge(self):
915 @functools.total_ordering
916 class A:
917 def __init__(self, value):
918 self.value = value
919 def __ge__(self, other):
920 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000921 def __eq__(self, other):
922 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000923 self.assertTrue(A(1) < A(2))
924 self.assertTrue(A(2) > A(1))
925 self.assertTrue(A(1) <= A(2))
926 self.assertTrue(A(2) >= A(1))
927 self.assertTrue(A(2) <= A(2))
928 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000929 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000930
931 def test_total_ordering_no_overwrite(self):
932 # new methods should not overwrite existing
933 @functools.total_ordering
934 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000935 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000936 self.assertTrue(A(1) < A(2))
937 self.assertTrue(A(2) > A(1))
938 self.assertTrue(A(1) <= A(2))
939 self.assertTrue(A(2) >= A(1))
940 self.assertTrue(A(2) <= A(2))
941 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000942
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000943 def test_no_operations_defined(self):
944 with self.assertRaises(ValueError):
945 @functools.total_ordering
946 class A:
947 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000948
Nick Coghlanf05d9812013-10-02 00:02:03 +1000949 def test_type_error_when_not_implemented(self):
950 # bug 10042; ensure stack overflow does not occur
951 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000952 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000953 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000954 def __init__(self, value):
955 self.value = value
956 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000957 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000958 return self.value == other.value
959 return False
960 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000961 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000962 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000963 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000964
Nick Coghlanf05d9812013-10-02 00:02:03 +1000965 @functools.total_ordering
966 class ImplementsGreaterThan:
967 def __init__(self, value):
968 self.value = value
969 def __eq__(self, other):
970 if isinstance(other, ImplementsGreaterThan):
971 return self.value == other.value
972 return False
973 def __gt__(self, other):
974 if isinstance(other, ImplementsGreaterThan):
975 return self.value > other.value
976 return NotImplemented
977
978 @functools.total_ordering
979 class ImplementsLessThanEqualTo:
980 def __init__(self, value):
981 self.value = value
982 def __eq__(self, other):
983 if isinstance(other, ImplementsLessThanEqualTo):
984 return self.value == other.value
985 return False
986 def __le__(self, other):
987 if isinstance(other, ImplementsLessThanEqualTo):
988 return self.value <= other.value
989 return NotImplemented
990
991 @functools.total_ordering
992 class ImplementsGreaterThanEqualTo:
993 def __init__(self, value):
994 self.value = value
995 def __eq__(self, other):
996 if isinstance(other, ImplementsGreaterThanEqualTo):
997 return self.value == other.value
998 return False
999 def __ge__(self, other):
1000 if isinstance(other, ImplementsGreaterThanEqualTo):
1001 return self.value >= other.value
1002 return NotImplemented
1003
1004 @functools.total_ordering
1005 class ComparatorNotImplemented:
1006 def __init__(self, value):
1007 self.value = value
1008 def __eq__(self, other):
1009 if isinstance(other, ComparatorNotImplemented):
1010 return self.value == other.value
1011 return False
1012 def __lt__(self, other):
1013 return NotImplemented
1014
1015 with self.subTest("LT < 1"), self.assertRaises(TypeError):
1016 ImplementsLessThan(-1) < 1
1017
1018 with self.subTest("LT < LE"), self.assertRaises(TypeError):
1019 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
1020
1021 with self.subTest("LT < GT"), self.assertRaises(TypeError):
1022 ImplementsLessThan(1) < ImplementsGreaterThan(1)
1023
1024 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
1025 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
1026
1027 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
1028 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
1029
1030 with self.subTest("GT > GE"), self.assertRaises(TypeError):
1031 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
1032
1033 with self.subTest("GT > LT"), self.assertRaises(TypeError):
1034 ImplementsGreaterThan(5) > ImplementsLessThan(5)
1035
1036 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
1037 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
1038
1039 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
1040 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
1041
1042 with self.subTest("GE when equal"):
1043 a = ComparatorNotImplemented(8)
1044 b = ComparatorNotImplemented(8)
1045 self.assertEqual(a, b)
1046 with self.assertRaises(TypeError):
1047 a >= b
1048
1049 with self.subTest("LE when equal"):
1050 a = ComparatorNotImplemented(9)
1051 b = ComparatorNotImplemented(9)
1052 self.assertEqual(a, b)
1053 with self.assertRaises(TypeError):
1054 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +02001055
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001056 def test_pickle(self):
Serhiy Storchaka92bb90a2016-09-22 11:39:25 +03001057 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
Serhiy Storchaka697a5262015-01-01 15:23:12 +02001058 for name in '__lt__', '__gt__', '__le__', '__ge__':
1059 with self.subTest(method=name, proto=proto):
1060 method = getattr(Orderable_LT, name)
1061 method_copy = pickle.loads(pickle.dumps(method, proto))
1062 self.assertIs(method_copy, method)
1063
@functools.total_ordering
class Orderable_LT:
    """Value wrapper defining only ``__eq__`` and ``__lt__``.

    The remaining ordering methods are filled in by ``total_ordering``.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
1072
1073
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001074class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001075
1076 def test_lru(self):
1077 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001078 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001079 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001080 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001081 self.assertEqual(maxsize, 20)
1082 self.assertEqual(currsize, 0)
1083 self.assertEqual(hits, 0)
1084 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001085
1086 domain = range(5)
1087 for i in range(1000):
1088 x, y = choice(domain), choice(domain)
1089 actual = f(x, y)
1090 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001091 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001092 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001093 self.assertTrue(hits > misses)
1094 self.assertEqual(hits + misses, 1000)
1095 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001096
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001097 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001098 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001099 self.assertEqual(hits, 0)
1100 self.assertEqual(misses, 0)
1101 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001102 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001103 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001104 self.assertEqual(hits, 0)
1105 self.assertEqual(misses, 1)
1106 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001107
Nick Coghlan98876832010-08-17 06:17:18 +00001108 # Test bypassing the cache
1109 self.assertIs(f.__wrapped__, orig)
1110 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001111 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001112 self.assertEqual(hits, 0)
1113 self.assertEqual(misses, 1)
1114 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001115
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001116 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001117 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001118 def f():
1119 nonlocal f_cnt
1120 f_cnt += 1
1121 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001122 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001123 f_cnt = 0
1124 for i in range(5):
1125 self.assertEqual(f(), 20)
1126 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001127 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001128 self.assertEqual(hits, 0)
1129 self.assertEqual(misses, 5)
1130 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001131
1132 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001133 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001134 def f():
1135 nonlocal f_cnt
1136 f_cnt += 1
1137 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001138 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001139 f_cnt = 0
1140 for i in range(5):
1141 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001142 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001143 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001144 self.assertEqual(hits, 4)
1145 self.assertEqual(misses, 1)
1146 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001147
Raymond Hettingerf3098282010-08-15 03:30:45 +00001148 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001149 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001150 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001151 nonlocal f_cnt
1152 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001153 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001154 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001155 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001156 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1157 # * * * *
1158 self.assertEqual(f(x), x*10)
1159 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001160 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001161 self.assertEqual(hits, 12)
1162 self.assertEqual(misses, 4)
1163 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001164
Yury Selivanov46a02db2016-11-09 18:55:45 -05001165 def test_lru_type_error(self):
1166 # Regression test for issue #28653.
1167 # lru_cache was leaking when one of the arguments
1168 # wasn't cacheable.
1169
1170 @functools.lru_cache(maxsize=None)
1171 def infinite_cache(o):
1172 pass
1173
1174 @functools.lru_cache(maxsize=10)
1175 def limited_cache(o):
1176 pass
1177
1178 with self.assertRaises(TypeError):
1179 infinite_cache([])
1180
1181 with self.assertRaises(TypeError):
1182 limited_cache([])
1183
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001184 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001185 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001186 def fib(n):
1187 if n < 2:
1188 return n
1189 return fib(n-1) + fib(n-2)
1190 self.assertEqual([fib(n) for n in range(16)],
1191 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1192 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001193 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001194 fib.cache_clear()
1195 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001196 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1197
1198 def test_lru_with_maxsize_negative(self):
1199 @self.module.lru_cache(maxsize=-10)
1200 def eq(n):
1201 return n
1202 for i in (0, 1):
1203 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1204 self.assertEqual(eq.cache_info(),
1205 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001206
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001207 def test_lru_with_exceptions(self):
1208 # Verify that user_function exceptions get passed through without
1209 # creating a hard-to-read chained exception.
1210 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001211 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001212 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001213 def func(i):
1214 return 'abc'[i]
1215 self.assertEqual(func(0), 'a')
1216 with self.assertRaises(IndexError) as cm:
1217 func(15)
1218 self.assertIsNone(cm.exception.__context__)
1219 # Verify that the previous exception did not result in a cached entry
1220 with self.assertRaises(IndexError):
1221 func(15)
1222
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001223 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001224 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001225 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001226 def square(x):
1227 return x * x
1228 self.assertEqual(square(3), 9)
1229 self.assertEqual(type(square(3)), type(9))
1230 self.assertEqual(square(3.0), 9.0)
1231 self.assertEqual(type(square(3.0)), type(9.0))
1232 self.assertEqual(square(x=3), 9)
1233 self.assertEqual(type(square(x=3)), type(9))
1234 self.assertEqual(square(x=3.0), 9.0)
1235 self.assertEqual(type(square(x=3.0)), type(9.0))
1236 self.assertEqual(square.cache_info().hits, 4)
1237 self.assertEqual(square.cache_info().misses, 4)
1238
Antoine Pitroub5b37142012-11-13 21:35:40 +01001239 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001240 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001241 def fib(n):
1242 if n < 2:
1243 return n
1244 return fib(n=n-1) + fib(n=n-2)
1245 self.assertEqual(
1246 [fib(n=number) for number in range(16)],
1247 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1248 )
1249 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001250 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001251 fib.cache_clear()
1252 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001253 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001254
1255 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001256 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001257 def fib(n):
1258 if n < 2:
1259 return n
1260 return fib(n=n-1) + fib(n=n-2)
1261 self.assertEqual([fib(n=number) for number in range(16)],
1262 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1263 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001264 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001265 fib.cache_clear()
1266 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001267 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1268
1269 def test_lru_cache_decoration(self):
1270 def f(zomg: 'zomg_annotation'):
1271 """f doc string"""
1272 return 42
1273 g = self.module.lru_cache()(f)
1274 for attr in self.module.WRAPPER_ASSIGNMENTS:
1275 self.assertEqual(getattr(g, attr), getattr(f, attr))
1276
1277 @unittest.skipUnless(threading, 'This test requires threading.')
1278 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001279 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001280 def orig(x, y):
1281 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001282 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001283 hits, misses, maxsize, currsize = f.cache_info()
1284 self.assertEqual(currsize, 0)
1285
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001286 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001287 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001288 start.wait(10)
1289 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001290 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001291
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001292 def clear():
1293 start.wait(10)
1294 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001295 f.cache_clear()
1296
1297 orig_si = sys.getswitchinterval()
1298 sys.setswitchinterval(1e-6)
1299 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001300 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001301 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001302 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001303 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001304 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001305
1306 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001307 if self.module is py_functools:
1308 # XXX: Why can be not equal?
1309 self.assertLessEqual(misses, n)
1310 self.assertLessEqual(hits, m*n - misses)
1311 else:
1312 self.assertEqual(misses, n)
1313 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001314 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001315
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001316 # create n threads in order to fill cache and 1 to clear it
1317 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001318 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001319 for k in range(n)]
1320 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001321 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001322 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001323 finally:
1324 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001325
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001326 @unittest.skipUnless(threading, 'This test requires threading.')
1327 def test_lru_cache_threaded2(self):
1328 # Simultaneous call with the same arguments
1329 n, m = 5, 7
1330 start = threading.Barrier(n+1)
1331 pause = threading.Barrier(n+1)
1332 stop = threading.Barrier(n+1)
1333 @self.module.lru_cache(maxsize=m*n)
1334 def f(x):
1335 pause.wait(10)
1336 return 3 * x
1337 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1338 def test():
1339 for i in range(m):
1340 start.wait(10)
1341 self.assertEqual(f(i), 3 * i)
1342 stop.wait(10)
1343 threads = [threading.Thread(target=test) for k in range(n)]
1344 with support.start_threads(threads):
1345 for i in range(m):
1346 start.wait(10)
1347 stop.reset()
1348 pause.wait(10)
1349 start.reset()
1350 stop.wait(10)
1351 pause.reset()
1352 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1353
Raymond Hettinger03923422013-03-04 02:52:50 -05001354 def test_need_for_rlock(self):
1355 # This will deadlock on an LRU cache that uses a regular lock
1356
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001357 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001358 def test_func(x):
1359 'Used to demonstrate a reentrant lru_cache call within a single thread'
1360 return x
1361
1362 class DoubleEq:
1363 'Demonstrate a reentrant lru_cache call within a single thread'
1364 def __init__(self, x):
1365 self.x = x
1366 def __hash__(self):
1367 return self.x
1368 def __eq__(self, other):
1369 if self.x == 2:
1370 test_func(DoubleEq(1))
1371 return self.x == other.x
1372
1373 test_func(DoubleEq(1)) # Load the cache
1374 test_func(DoubleEq(2)) # Load the cache
1375 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1376 DoubleEq(2)) # Verify the correct return value
1377
Raymond Hettinger4d588972014-08-12 12:44:52 -07001378 def test_early_detection_of_bad_call(self):
1379 # Issue #22184
1380 with self.assertRaises(TypeError):
1381 @functools.lru_cache
1382 def f():
1383 pass
1384
Serhiy Storchakae7070f02015-06-08 11:19:24 +03001385 def test_lru_method(self):
1386 class X(int):
1387 f_cnt = 0
1388 @self.module.lru_cache(2)
1389 def f(self, x):
1390 self.f_cnt += 1
1391 return x*10+self
1392 a = X(5)
1393 b = X(5)
1394 c = X(7)
1395 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
1396
1397 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
1398 self.assertEqual(a.f(x), x*10 + 5)
1399 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
1400 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
1401
1402 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
1403 self.assertEqual(b.f(x), x*10 + 5)
1404 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
1405 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
1406
1407 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
1408 self.assertEqual(c.f(x), x*10 + 7)
1409 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
1410 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
1411
1412 self.assertEqual(a.f.cache_info(), X.f.cache_info())
1413 self.assertEqual(b.f.cache_info(), X.f.cache_info())
1414 self.assertEqual(c.f.cache_info(), X.f.cache_info())
1415
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001416 def test_pickle(self):
1417 cls = self.__class__
1418 for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
1419 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1420 with self.subTest(proto=proto, func=f):
1421 f_copy = pickle.loads(pickle.dumps(f, proto))
1422 self.assertIs(f_copy, f)
1423
1424 def test_copy(self):
1425 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001426 def orig(x, y):
1427 return 3 * x + y
1428 part = self.module.partial(orig, 2)
1429 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1430 self.module.lru_cache(2)(part))
1431 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001432 with self.subTest(func=f):
1433 f_copy = copy.copy(f)
1434 self.assertIs(f_copy, f)
1435
1436 def test_deepcopy(self):
1437 cls = self.__class__
Serhiy Storchakae4d65e32015-12-28 23:58:07 +02001438 def orig(x, y):
1439 return 3 * x + y
1440 part = self.module.partial(orig, 2)
1441 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
1442 self.module.lru_cache(2)(part))
1443 for f in funcs:
Serhiy Storchaka45120f22015-10-24 09:49:56 +03001444 with self.subTest(func=f):
1445 f_copy = copy.deepcopy(f)
1446 self.assertIs(f_copy, f)
1447
1448
# Module-level function cached with the lru_cache of py_functools (the
# functools copy imported with _functools blocked).
@py_functools.lru_cache()
def py_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1452
# Module-level function cached with the lru_cache of c_functools (the
# functools copy imported together with a fresh _functools).
@c_functools.lru_cache()
def c_cached_func(x, y):
    tripled = 3 * x
    return tripled + y
1456
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001457
class TestLRUPy(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against py_functools.
    module = py_functools
    # Stored in a one-element tuple and read back as cached_func[0];
    # presumably this keeps attribute access from binding it as a method.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
1470
1471
class TestLRUC(TestLRU, unittest.TestCase):
    # Runs the shared TestLRU tests against c_functools.
    module = c_functools
    # Stored in a one-element tuple and read back as cached_func[0];
    # presumably this keeps attribute access from binding it as a method.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        tripled = 3 * x
        return tripled + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        tripled = 3 * x
        return tripled + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001484
Raymond Hettinger03923422013-03-04 02:52:50 -05001485
class TestSingleDispatch(unittest.TestCase):
    # Tests for functools.singledispatch and for the helpers it relies on
    # (functools._compose_mro, functools._c3_mro, functools._find_impl).

    def test_simple_overloads(self):
        # register(cls, func) called as a plain function; types without a
        # registered implementation fall back to the decorated function.
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        # Dispatch follows the MRO of the argument's class: D(C, B) picks
        # the implementation for B, which precedes A in D's MRO.
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        # register(cls) used as a decorator hands back the very function
        # it registered.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        self.assertIs(g.dispatch(object), g.dispatch(str))
        # Note: in the assert above this is not g.
        # @singledispatch returns the wrapper.

    def test_wrapping_attributes(self):
        # The generic function keeps the name and docstring of the
        # function it wraps.
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on the exception classes of the decimal module imported
        # with a fresh _decimal.
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        # A more specific registration only affects Subnormal.
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        # None of the examples in this test depend on haystack ordering.
        c = collections
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                                 c.Iterable, c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
        for haystack in permutations(bases):
            m = mro(c.ChainMap, haystack)
            self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            # Pass the permutation itself so that every ordering is
            # actually exercised.
            m = mro(c.defaultdict, haystack)
            self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                                 object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(c.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            # Pass the permutation itself so that every ordering is
            # actually exercised.
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                                 c.defaultdict, dict, c.MutableMapping,
                                 c.Mapping, c.Sized, c.Iterable, c.Container,
                                 object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(c.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                                 c.Sized, c.Iterable, c.Container, object])

    def test_register_abc(self):
        # Implementations are registered from the most general ABC down to
        # the concrete types; after each registration every sample object
        # must dispatch to the most specific implementation known so far.
        c = collections
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        # functools._c3_mro() inserts the given ABCs into the MRO of a
        # class mixing explicit, registered and inferred ABCs; the order
        # in which the ABCs are supplied must not matter.
        c = collections
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572
        # MetaA.__len__ returns 0, which makes the class objects A and AA
        # falsy; dispatch must still find the implementation for A.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        # Dispatch on classes related to several ABCs at once: explicit
        # bases win over registered ABCs, and two unrelated candidate ABCs
        # make the call fail with RuntimeError.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # Either order of the two candidates is accepted in the message.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        # The dispatch cache is swapped for a dict that records every read
        # and write, which makes it possible to check exactly when the
        # cache is filled, hit and cleared.
        from collections import UserDict
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        _orig_wkd = functools.WeakKeyDictionary
        td = TracingDict()
        # Restore the real class even if an assertion below fails, so the
        # patch cannot leak into other tests.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', _orig_wkd)
        functools.WeakKeyDictionary = lambda: td
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
1977
1978
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()