blob: 31930fc763a672b4a52a5563bd657534941286bd [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030011try:
12 import threading
13except ImportError:
14 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000015
Antoine Pitroub5b37142012-11-13 21:35:40 +010016import functools
17
# Two separately imported copies of functools are used throughout this file:
# one imported with the '_functools' accelerator blocked, and one imported
# with '_functools' freshly re-imported alongside it.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal is imported together with a fresh '_decimal'.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
22
23
def capture(*args, **kw):
    """Return the positional and keyword arguments received, as a pair."""
    received = (args, kw)
    return received
27
Łukasz Langa6f692512013-06-05 12:20:24 +020028
def signature(part):
    """Return (func, args, keywords, __dict__) read from a partial object."""
    field_names = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in field_names)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000032
class MyTuple(tuple):
    """Plain tuple subclass that adds no behaviour of its own."""
35
class BadTuple(tuple):
    """Tuple subclass whose ``+`` produces a list instead of a tuple."""

    def __add__(self, other):
        combined = list(self)
        combined += list(other)
        return combined
39
class MyDict(dict):
    """Plain dict subclass that adds no behaviour of its own."""
42
Łukasz Langa6f692512013-06-05 12:20:24 +020043
class TestPartial:
    """Checks shared by every ``partial`` implementation under test.

    The class deliberately has no TestCase base: the concrete subclasses in
    this file mix it with ``unittest.TestCase`` and supply the implementation
    to exercise through the ``partial`` class attribute.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, whether the
        # TypeError surfaces on construction or on the call itself.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                expected = args + ('x',)
                got, empty = p('x')
                # Separate assertions so a failure reports the actual values.
                self.assertEqual(got, expected)
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                expected = {'a': a, 'x': None}
                empty, got = p(x=None)
                self.assertEqual(got, expected)
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Do not depend on reference counting to reclaim the partial at once;
        # force a collection so the proxy is dead on any garbage collector.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        # A partial of a partial must look the same as the flattened form.
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')
169
Łukasz Langa6f692512013-06-05 12:20:24 +0200170
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial checks, plus extra ones, on ``c_functools.partial``.

    The whole class is skipped when ``c_functools`` is falsy; the ``partial``
    attribute is only bound in that case to keep the class body importable.
    """

    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError,
                msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Either keyword order is accepted in the repr.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # subTest makes a failure report which protocol broke.
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the attribute, args and keywords objects.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates the containers and their mutable contents.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))
        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the signature check below was already disabled in the
        # original; the reason is not visible here -- confirm before enabling.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        # wrong number of items, wrong container type, then wrong item types
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances must be stored as exact tuple / dict objects.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000316
Łukasz Langa6f692512013-06-05 12:20:24 +0200317
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial checks against ``py_functools.partial``."""

    # staticmethod() makes attribute access through an instance hand back the
    # callable itself rather than anything bound to the test case.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000320
Łukasz Langa6f692512013-06-05 12:20:24 +0200321
if c_functools:
    class PartialSubclass(c_functools.partial):
        """Subclass of ``c_functools.partial`` that adds nothing of its own."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100325
Łukasz Langa6f692512013-06-05 12:20:24 +0200326
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Repeat the ``TestPartialC`` checks using ``PartialSubclass``."""

    if c_functools:
        partial = PartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None
334
Łukasz Langa6f692512013-06-05 12:20:24 +0200335
class TestPartialMethod(unittest.TestCase):
    """Checks for ``functools.partialmethod`` via the fixture class ``A``."""

    class A(object):
        # One partialmethod per argument shape, all wrapping capture().
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod layered over another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod layered over a functools.partial
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod layered over other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        inst = self.a

        self.assertEqual(inst.nothing(), ((inst,), {}))
        self.assertEqual(inst.nothing(5), ((inst, 5), {}))
        self.assertEqual(inst.nothing(c=6), ((inst,), {'c': 6}))
        self.assertEqual(inst.nothing(5, c=6), ((inst, 5), {'c': 6}))

        self.assertEqual(inst.positional(), ((inst, 1), {}))
        self.assertEqual(inst.positional(5), ((inst, 1, 5), {}))
        self.assertEqual(inst.positional(c=6), ((inst, 1), {'c': 6}))
        self.assertEqual(inst.positional(5, c=6), ((inst, 1, 5), {'c': 6}))

        self.assertEqual(inst.keywords(), ((inst,), {'a': 2}))
        self.assertEqual(inst.keywords(5), ((inst, 5), {'a': 2}))
        self.assertEqual(inst.keywords(c=6), ((inst,), {'a': 2, 'c': 6}))
        self.assertEqual(inst.keywords(5, c=6), ((inst, 5), {'a': 2, 'c': 6}))

        self.assertEqual(inst.both(), ((inst, 3), {'b': 4}))
        self.assertEqual(inst.both(5), ((inst, 3, 5), {'b': 4}))
        self.assertEqual(inst.both(c=6), ((inst, 3), {'b': 4, 'c': 6}))
        self.assertEqual(inst.both(5, c=6), ((inst, 3, 5), {'b': 4, 'c': 6}))

        # Same call made through the class with an explicit instance.
        self.assertEqual(self.A.both(inst, 5, c=6),
                         ((inst, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        inst = self.a
        self.assertEqual(inst.nested(), ((inst, 1, 5), {}))
        self.assertEqual(inst.nested(6), ((inst, 1, 5, 6), {}))
        self.assertEqual(inst.nested(d=7), ((inst, 1, 5), {'d': 7}))
        self.assertEqual(inst.nested(6, d=7), ((inst, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(inst, 6, d=7),
                         ((inst, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        inst = self.a
        self.assertEqual(inst.over_partial(), ((inst, 7), {'c': 6}))
        self.assertEqual(inst.over_partial(5), ((inst, 7, 5), {'c': 6}))
        self.assertEqual(inst.over_partial(d=8), ((inst, 7), {'c': 6, 'd': 8}))
        self.assertEqual(inst.over_partial(5, d=8),
                         ((inst, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(inst, 5, d=8),
                         ((inst, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        inst = self.a
        self.assertIs(inst.both.__self__, inst)
        self.assertIs(inst.nested.__self__, inst)
        self.assertIs(inst.over_partial.__self__, inst)
        # The classmethod-based entry is bound to the class either way.
        self.assertIs(inst.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        for attr_name in ('both', 'nested', 'over_partial', 'static'):
            unbound = getattr(self.A, attr_name)
            self.assertFalse(hasattr(unbound, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for target in (self.A, self.a):
            with self.subTest(obj=target):
                self.assertEqual(target.static(), ((8,), {}))
                self.assertEqual(target.static(5), ((8, 5), {}))
                self.assertEqual(target.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(target.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(target.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(target.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(target.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(target.cls(5, c=8),
                                 ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        inst = self.a
        self.assertEqual(inst.keywords(a=3), ((inst,), {'a': 3}))
        self.assertEqual(self.A.keywords(inst, a=3), ((inst,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        # vars() is used to reach the raw partialmethod, bypassing __get__.
        raw = vars(self.A)['both']
        self.assertEqual(repr(raw),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        class Abstract(abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)

        concrete = (self.A.static, self.A.cls, self.A.over_partial,
                    self.A.nested, self.A.both)
        for func in concrete:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
448
449
class TestUpdateWrapper(unittest.TestCase):
    """Checks for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Shared verification helper.
        #   wrapper  -- the function that was updated
        #   wrapped  -- the function it was updated from
        #   assigned -- attribute names expected to be the very same objects
        #   updated  -- names of dict attributes whose entries were merged in
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a (wrapper, wrapped) pair updated with the default arguments.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty 'assigned' and 'updated' tuples: nothing may be copied over.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000562
Łukasz Langa6f692512013-06-05 12:20:24 +0200563
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the ``functools.wraps`` decorator."""

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with default arguments.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Empty 'assigned' and 'updated' tuples: nothing may be copied over.
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # Give the wrapper a dict_attr before wraps() tries to update it.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
624
Łukasz Langa6f692512013-06-05 12:20:24 +0200625
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000626class TestReduce(unittest.TestCase):
627 func = functools.reduce
628
629 def test_reduce(self):
630 class Squares:
631 def __init__(self, max):
632 self.max = max
633 self.sofar = []
634
635 def __len__(self):
636 return len(self.sofar)
637
638 def __getitem__(self, i):
639 if not 0 <= i < self.max: raise IndexError
640 n = len(self.sofar)
641 while n <= i:
642 self.sofar.append(n*n)
643 n += 1
644 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000645 def add(x, y):
646 return x + y
647 self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000648 self.assertEqual(
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000649 self.func(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000650 ['a','c','d','w']
651 )
652 self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
653 self.assertEqual(
Guido van Rossume2a383d2007-01-15 16:59:06 +0000654 self.func(lambda x, y: x*y, range(2,21), 1),
655 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000656 )
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000657 self.assertEqual(self.func(add, Squares(10)), 285)
658 self.assertEqual(self.func(add, Squares(10), 0), 285)
659 self.assertEqual(self.func(add, Squares(0), 0), 0)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000660 self.assertRaises(TypeError, self.func)
661 self.assertRaises(TypeError, self.func, 42, 42)
662 self.assertRaises(TypeError, self.func, 42, 42, 42)
663 self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
664 self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
665 self.assertRaises(TypeError, self.func, 42, (42, 42))
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000666 self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
667 self.assertRaises(TypeError, self.func, add, "")
668 self.assertRaises(TypeError, self.func, add, ())
669 self.assertRaises(TypeError, self.func, add, object())
670
671 class TestFailingIter:
672 def __iter__(self):
673 raise RuntimeError
674 self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
675
676 self.assertEqual(self.func(add, [], None), None)
677 self.assertEqual(self.func(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000678
679 class BadSeq:
680 def __getitem__(self, index):
681 raise ValueError
682 self.assertRaises(ValueError, self.func, 42, BadSeq())
683
    # Test reduce()'s use of iterators.
685 def test_iterator_usage(self):
686 class SequenceClass:
687 def __init__(self, n):
688 self.n = n
689 def __getitem__(self, i):
690 if 0 <= i < self.n:
691 return i
692 else:
693 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000694
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000695 from operator import add
696 self.assertEqual(self.func(add, SequenceClass(5)), 10)
697 self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
698 self.assertRaises(TypeError, self.func, add, SequenceClass(0))
699 self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
700 self.assertEqual(self.func(add, SequenceClass(1)), 0)
701 self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
702
703 d = {"one": 1, "two": 2, "three": 3}
704 self.assertEqual(self.func(add, d), "".join(d.keys()))
705
Łukasz Langa6f692512013-06-05 12:20:24 +0200706
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200707class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700708
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000709 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700710 def cmp1(x, y):
711 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100712 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700713 self.assertEqual(key(3), key(3))
714 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100715 self.assertGreaterEqual(key(3), key(3))
716
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700717 def cmp2(x, y):
718 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100719 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700720 self.assertEqual(key(4.0), key('4'))
721 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100722 self.assertLessEqual(key(2), key('35'))
723 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700724
725 def test_cmp_to_key_arguments(self):
726 def cmp1(x, y):
727 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100728 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700729 self.assertEqual(key(obj=3), key(obj=3))
730 self.assertGreater(key(obj=3), key(obj=1))
731 with self.assertRaises((TypeError, AttributeError)):
732 key(3) > 1 # rhs is not a K object
733 with self.assertRaises((TypeError, AttributeError)):
734 1 < key(3) # lhs is not a K object
735 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100736 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700737 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200738 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100739 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700740 with self.assertRaises(TypeError):
741 key() # too few args
742 with self.assertRaises(TypeError):
743 key(None, None) # too many args
744
745 def test_bad_cmp(self):
746 def cmp1(x, y):
747 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100748 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700749 with self.assertRaises(ZeroDivisionError):
750 key(3) > key(1)
751
752 class BadCmp:
753 def __lt__(self, other):
754 raise ZeroDivisionError
755 def cmp1(x, y):
756 return BadCmp()
757 with self.assertRaises(ZeroDivisionError):
758 key(3) > key(1)
759
760 def test_obj_field(self):
761 def cmp1(x, y):
762 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100763 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700764 self.assertEqual(key(50).obj, 50)
765
766 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000767 def mycmp(x, y):
768 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100769 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000770 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000771
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700772 def test_sort_int_str(self):
773 def mycmp(x, y):
774 x, y = int(x), int(y)
775 return (x > y) - (x < y)
776 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100777 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700778 self.assertEqual([int(value) for value in values],
779 [0, 1, 1, 2, 3, 4, 5, 7, 10])
780
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000781 def test_hash(self):
782 def mycmp(x, y):
783 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100784 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000785 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700786 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700787 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000788
Łukasz Langa6f692512013-06-05 12:20:24 +0200789
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the cmp_to_key() tests against the C accelerated implementation."""
    # c_functools is falsy when the _functools extension could not be
    # imported; the guard keeps the class body from failing in that case
    # (the class itself is then skipped by the decorator above).
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100794
Łukasz Langa6f692512013-06-05 12:20:24 +0200795
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the cmp_to_key() tests against the pure Python implementation."""
    # staticmethod() stops the plain Python function from being bound to the
    # test instance when accessed as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
798
Łukasz Langa6f692512013-06-05 12:20:24 +0200799
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def test_total_ordering_lt(self):
        # All comparisons are derived from __lt__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # All comparisons are derived from __le__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # All comparisons are derived from __gt__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # All comparisons are derived from __ge__ and __eq__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))

    def test_no_operations_defined(self):
        # At least one ordering operation must be defined by the class.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        # The root comparison always declines, even for its own type.
        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        # Equal operands must not short-circuit to True: the derived
        # operator still has to report that the root comparison declined.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # The comparison methods (user-defined and generated alike) must
        # round-trip through pickle to the very same function object.
        # Only protocols 4 and up are exercised here.
        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1002
@functools.total_ordering
class Orderable_LT:
    """Module-level totally ordered class built from __lt__ and __eq__.

    It is defined at module level (rather than inside a test) so that its
    methods can be looked up by qualified name when pickled.
    """
    def __init__(self, value):
        self.value = value
    def __lt__(self, other):
        return self.value < other.value
    def __eq__(self, other):
        return self.value == other.value
1011
1012
class TestLRU:
    """Implementation-independent tests for lru_cache().

    This is a mixin: concrete subclasses combine it with unittest.TestCase
    and provide ``module`` (the functools implementation under test) plus
    the ``cached_func``, ``cached_meth`` and ``cached_staticmeth`` fixtures
    used by the pickle/copy tests.
    """

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 distinct argument pairs against a cache of 20 entries.
        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # The asterisks mark the calls that miss the two-entry cache.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        # Two identical passes: every one of the 300 calls must be a miss.
        for i in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # With typed=True, 3 and 3.0 are cached as separate entries.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_cache_decoration(self):
        # The wrapper must expose the wrapped function's metadata.
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Every worker blocks on this event so they all start together.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval forces frequent thread switches.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        # n workers plus the main thread meet at each barrier.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n workers were inside f(i) at once, so each counts as
                # a miss, yet only a single entry is stored per argument.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_early_detection_of_bad_call(self):
        # Issue #22184
        # Use the implementation under test, like every other test in this
        # mixin, rather than whatever the global functools happens to be.
        with self.assertRaises(TypeError):
            @self.module.lru_cache
            def f():
                pass

    def test_lru_method(self):
        # A cached method shares one cache across all instances; ``self``
        # is part of the key, so equal instances (a == b) share entries.
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        # Cached functions must unpickle to the very same object.
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        # copy.copy() of a cached function returns the function itself.
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        # copy.deepcopy() of a cached function returns the function itself.
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)
1367
1368
# Module-level fixture for the pickle/copy tests: a function cached with the
# pure Python lru_cache() that can be found again by its qualified name.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y
1372
# Module-level fixture for the pickle/copy tests: a function cached with the
# lru_cache() of the C accelerated functools module.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y
1376
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001377
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the lru_cache() tests against the pure Python implementation."""
    module = py_functools
    # Wrapped in a 1-tuple so the function is stored as plain data and is
    # not turned into a bound method on attribute access.
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
1390
1391
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the lru_cache() tests against the C accelerated implementation."""
    module = c_functools
    # Wrapped in a 1-tuple so the function is stored as plain data and is
    # not turned into a bound method on attribute access.
    cached_func = (c_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001404
Raymond Hettinger03923422013-03-04 02:52:50 -05001405
Łukasz Langa6f692512013-06-05 12:20:24 +02001406class TestSingleDispatch(unittest.TestCase):
1407 def test_simple_overloads(self):
1408 @functools.singledispatch
1409 def g(obj):
1410 return "base"
1411 def g_int(i):
1412 return "integer"
1413 g.register(int, g_int)
1414 self.assertEqual(g("str"), "base")
1415 self.assertEqual(g(1), "integer")
1416 self.assertEqual(g([1,2,3]), "base")
1417
1418 def test_mro(self):
1419 @functools.singledispatch
1420 def g(obj):
1421 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001422 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001423 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001424 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001425 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001426 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001427 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001428 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001429 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001430 def g_A(a):
1431 return "A"
1432 def g_B(b):
1433 return "B"
1434 g.register(A, g_A)
1435 g.register(B, g_B)
1436 self.assertEqual(g(A()), "A")
1437 self.assertEqual(g(B()), "B")
1438 self.assertEqual(g(C()), "A")
1439 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001440
1441 def test_register_decorator(self):
1442 @functools.singledispatch
1443 def g(obj):
1444 return "base"
1445 @g.register(int)
1446 def g_int(i):
1447 return "int %s" % (i,)
1448 self.assertEqual(g(""), "base")
1449 self.assertEqual(g(12), "int 12")
1450 self.assertIs(g.dispatch(int), g_int)
1451 self.assertIs(g.dispatch(object), g.dispatch(str))
1452 # Note: in the assert above this is not g.
1453 # @singledispatch returns the wrapper.
1454
1455 def test_wrapping_attributes(self):
1456 @functools.singledispatch
1457 def g(obj):
1458 "Simple test"
1459 return "Test"
1460 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001461 if sys.flags.optimize < 2:
1462 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001463
1464 @unittest.skipUnless(decimal, 'requires _decimal')
1465 @support.cpython_only
1466 def test_c_classes(self):
1467 @functools.singledispatch
1468 def g(obj):
1469 return "base"
1470 @g.register(decimal.DecimalException)
1471 def _(obj):
1472 return obj.args
1473 subn = decimal.Subnormal("Exponent < Emin")
1474 rnd = decimal.Rounded("Number got rounded")
1475 self.assertEqual(g(subn), ("Exponent < Emin",))
1476 self.assertEqual(g(rnd), ("Number got rounded",))
1477 @g.register(decimal.Subnormal)
1478 def _(obj):
1479 return "Too small to care."
1480 self.assertEqual(g(subn), "Too small to care.")
1481 self.assertEqual(g(rnd), ("Number got rounded",))
1482
1483 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001484 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001485 c = collections
1486 mro = functools._compose_mro
1487 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1488 for haystack in permutations(bases):
1489 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001490 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1491 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001492 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1493 for haystack in permutations(bases):
1494 m = mro(c.ChainMap, haystack)
1495 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1496 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001497
1498 # If there's a generic function with implementations registered for
1499 # both Sized and Container, passing a defaultdict to it results in an
1500 # ambiguous dispatch which will cause a RuntimeError (see
1501 # test_mro_conflicts).
1502 bases = [c.Container, c.Sized, str]
1503 for haystack in permutations(bases):
1504 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1505 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1506 object])
1507
1508 # MutableSequence below is registered directly on D. In other words, it
1509 # preceeds MutableMapping which means single dispatch will always
1510 # choose MutableSequence here.
1511 class D(c.defaultdict):
1512 pass
1513 c.MutableSequence.register(D)
1514 bases = [c.MutableSequence, c.MutableMapping]
1515 for haystack in permutations(bases):
1516 m = mro(D, bases)
1517 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1518 c.defaultdict, dict, c.MutableMapping,
1519 c.Mapping, c.Sized, c.Iterable, c.Container,
1520 object])
1521
1522 # Container and Callable are registered on different base classes and
1523 # a generic function supporting both should always pick the Callable
1524 # implementation if a C instance is passed.
1525 class C(c.defaultdict):
1526 def __call__(self):
1527 pass
1528 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1529 for haystack in permutations(bases):
1530 m = mro(C, haystack)
1531 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1532 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001533
1534 def test_register_abc(self):
1535 c = collections
1536 d = {"a": "b"}
1537 l = [1, 2, 3]
1538 s = {object(), None}
1539 f = frozenset(s)
1540 t = (1, 2, 3)
1541 @functools.singledispatch
1542 def g(obj):
1543 return "base"
1544 self.assertEqual(g(d), "base")
1545 self.assertEqual(g(l), "base")
1546 self.assertEqual(g(s), "base")
1547 self.assertEqual(g(f), "base")
1548 self.assertEqual(g(t), "base")
1549 g.register(c.Sized, lambda obj: "sized")
1550 self.assertEqual(g(d), "sized")
1551 self.assertEqual(g(l), "sized")
1552 self.assertEqual(g(s), "sized")
1553 self.assertEqual(g(f), "sized")
1554 self.assertEqual(g(t), "sized")
1555 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1556 self.assertEqual(g(d), "mutablemapping")
1557 self.assertEqual(g(l), "sized")
1558 self.assertEqual(g(s), "sized")
1559 self.assertEqual(g(f), "sized")
1560 self.assertEqual(g(t), "sized")
1561 g.register(c.ChainMap, lambda obj: "chainmap")
1562 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1563 self.assertEqual(g(l), "sized")
1564 self.assertEqual(g(s), "sized")
1565 self.assertEqual(g(f), "sized")
1566 self.assertEqual(g(t), "sized")
1567 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1568 self.assertEqual(g(d), "mutablemapping")
1569 self.assertEqual(g(l), "mutablesequence")
1570 self.assertEqual(g(s), "sized")
1571 self.assertEqual(g(f), "sized")
1572 self.assertEqual(g(t), "sized")
1573 g.register(c.MutableSet, lambda obj: "mutableset")
1574 self.assertEqual(g(d), "mutablemapping")
1575 self.assertEqual(g(l), "mutablesequence")
1576 self.assertEqual(g(s), "mutableset")
1577 self.assertEqual(g(f), "sized")
1578 self.assertEqual(g(t), "sized")
1579 g.register(c.Mapping, lambda obj: "mapping")
1580 self.assertEqual(g(d), "mutablemapping") # not specific enough
1581 self.assertEqual(g(l), "mutablesequence")
1582 self.assertEqual(g(s), "mutableset")
1583 self.assertEqual(g(f), "sized")
1584 self.assertEqual(g(t), "sized")
1585 g.register(c.Sequence, lambda obj: "sequence")
1586 self.assertEqual(g(d), "mutablemapping")
1587 self.assertEqual(g(l), "mutablesequence")
1588 self.assertEqual(g(s), "mutableset")
1589 self.assertEqual(g(f), "sized")
1590 self.assertEqual(g(t), "sequence")
1591 g.register(c.Set, lambda obj: "set")
1592 self.assertEqual(g(d), "mutablemapping")
1593 self.assertEqual(g(l), "mutablesequence")
1594 self.assertEqual(g(s), "mutableset")
1595 self.assertEqual(g(f), "set")
1596 self.assertEqual(g(t), "sequence")
1597 g.register(dict, lambda obj: "dict")
1598 self.assertEqual(g(d), "dict")
1599 self.assertEqual(g(l), "mutablesequence")
1600 self.assertEqual(g(s), "mutableset")
1601 self.assertEqual(g(f), "set")
1602 self.assertEqual(g(t), "sequence")
1603 g.register(list, lambda obj: "list")
1604 self.assertEqual(g(d), "dict")
1605 self.assertEqual(g(l), "list")
1606 self.assertEqual(g(s), "mutableset")
1607 self.assertEqual(g(f), "set")
1608 self.assertEqual(g(t), "sequence")
1609 g.register(set, lambda obj: "concrete-set")
1610 self.assertEqual(g(d), "dict")
1611 self.assertEqual(g(l), "list")
1612 self.assertEqual(g(s), "concrete-set")
1613 self.assertEqual(g(f), "set")
1614 self.assertEqual(g(t), "sequence")
1615 g.register(frozenset, lambda obj: "frozen-set")
1616 self.assertEqual(g(d), "dict")
1617 self.assertEqual(g(l), "list")
1618 self.assertEqual(g(s), "concrete-set")
1619 self.assertEqual(g(f), "frozen-set")
1620 self.assertEqual(g(t), "sequence")
1621 g.register(tuple, lambda obj: "tuple")
1622 self.assertEqual(g(d), "dict")
1623 self.assertEqual(g(l), "list")
1624 self.assertEqual(g(s), "concrete-set")
1625 self.assertEqual(g(f), "frozen-set")
1626 self.assertEqual(g(t), "tuple")
1627
Łukasz Langa3720c772013-07-01 16:00:38 +02001628 def test_c3_abc(self):
1629 c = collections
1630 mro = functools._c3_mro
1631 class A(object):
1632 pass
1633 class B(A):
1634 def __len__(self):
1635 return 0 # implies Sized
1636 @c.Container.register
1637 class C(object):
1638 pass
1639 class D(object):
1640 pass # unrelated
1641 class X(D, C, B):
1642 def __call__(self):
1643 pass # implies Callable
1644 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1645 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1646 self.assertEqual(mro(X, abcs=abcs), expected)
1647 # unrelated ABCs don't appear in the resulting MRO
1648 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1649 self.assertEqual(mro(X, abcs=many_abcs), expected)
1650
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001651 def test_false_meta(self):
1652 # see issue23572
1653 class MetaA(type):
1654 def __len__(self):
1655 return 0
1656 class A(metaclass=MetaA):
1657 pass
1658 class AA(A):
1659 pass
1660 @functools.singledispatch
1661 def fun(a):
1662 return 'base A'
1663 @fun.register(A)
1664 def _(a):
1665 return 'fun A'
1666 aa = AA()
1667 self.assertEqual(fun(aa), 'fun A')
1668
Łukasz Langa6f692512013-06-05 12:20:24 +02001669 def test_mro_conflicts(self):
1670 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001671 @functools.singledispatch
1672 def g(arg):
1673 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001674 class O(c.Sized):
1675 def __len__(self):
1676 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001677 o = O()
1678 self.assertEqual(g(o), "base")
1679 g.register(c.Iterable, lambda arg: "iterable")
1680 g.register(c.Container, lambda arg: "container")
1681 g.register(c.Sized, lambda arg: "sized")
1682 g.register(c.Set, lambda arg: "set")
1683 self.assertEqual(g(o), "sized")
1684 c.Iterable.register(O)
1685 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1686 c.Container.register(O)
1687 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001688 c.Set.register(O)
1689 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1690 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001691 class P:
1692 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001693 p = P()
1694 self.assertEqual(g(p), "base")
1695 c.Iterable.register(P)
1696 self.assertEqual(g(p), "iterable")
1697 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001698 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001699 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001700 self.assertIn(
1701 str(re_one.exception),
1702 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1703 "or <class 'collections.abc.Iterable'>"),
1704 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1705 "or <class 'collections.abc.Container'>")),
1706 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001707 class Q(c.Sized):
1708 def __len__(self):
1709 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001710 q = Q()
1711 self.assertEqual(g(q), "sized")
1712 c.Iterable.register(Q)
1713 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1714 c.Set.register(Q)
1715 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001716 # c.Sized and c.Iterable
1717 @functools.singledispatch
1718 def h(arg):
1719 return "base"
1720 @h.register(c.Sized)
1721 def _(arg):
1722 return "sized"
1723 @h.register(c.Container)
1724 def _(arg):
1725 return "container"
1726 # Even though Sized and Container are explicit bases of MutableMapping,
1727 # this ABC is implicitly registered on defaultdict which makes all of
1728 # MutableMapping's bases implicit as well from defaultdict's
1729 # perspective.
1730 with self.assertRaises(RuntimeError) as re_two:
1731 h(c.defaultdict(lambda: 0))
1732 self.assertIn(
1733 str(re_two.exception),
1734 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1735 "or <class 'collections.abc.Sized'>"),
1736 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1737 "or <class 'collections.abc.Container'>")),
1738 )
1739 class R(c.defaultdict):
1740 pass
1741 c.MutableSequence.register(R)
1742 @functools.singledispatch
1743 def i(arg):
1744 return "base"
1745 @i.register(c.MutableMapping)
1746 def _(arg):
1747 return "mapping"
1748 @i.register(c.MutableSequence)
1749 def _(arg):
1750 return "sequence"
1751 r = R()
1752 self.assertEqual(i(r), "sequence")
1753 class S:
1754 pass
1755 class T(S, c.Sized):
1756 def __len__(self):
1757 return 0
1758 t = T()
1759 self.assertEqual(h(t), "sized")
1760 c.Container.register(T)
1761 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1762 class U:
1763 def __len__(self):
1764 return 0
1765 u = U()
1766 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1767 # from the existence of __len__()
1768 c.Container.register(U)
1769 # There is no preference for registered versus inferred ABCs.
1770 with self.assertRaises(RuntimeError) as re_three:
1771 h(u)
1772 self.assertIn(
1773 str(re_three.exception),
1774 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1775 "or <class 'collections.abc.Sized'>"),
1776 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1777 "or <class 'collections.abc.Container'>")),
1778 )
1779 class V(c.Sized, S):
1780 def __len__(self):
1781 return 0
1782 @functools.singledispatch
1783 def j(arg):
1784 return "base"
1785 @j.register(S)
1786 def _(arg):
1787 return "s"
1788 @j.register(c.Container)
1789 def _(arg):
1790 return "container"
1791 v = V()
1792 self.assertEqual(j(v), "s")
1793 c.Container.register(V)
1794 self.assertEqual(j(v), "container") # because it ends up right after
1795 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001796
1797 def test_cache_invalidation(self):
1798 from collections import UserDict
1799 class TracingDict(UserDict):
1800 def __init__(self, *args, **kwargs):
1801 super(TracingDict, self).__init__(*args, **kwargs)
1802 self.set_ops = []
1803 self.get_ops = []
1804 def __getitem__(self, key):
1805 result = self.data[key]
1806 self.get_ops.append(key)
1807 return result
1808 def __setitem__(self, key, value):
1809 self.set_ops.append(key)
1810 self.data[key] = value
1811 def clear(self):
1812 self.data.clear()
1813 _orig_wkd = functools.WeakKeyDictionary
1814 td = TracingDict()
1815 functools.WeakKeyDictionary = lambda: td
1816 c = collections
1817 @functools.singledispatch
1818 def g(arg):
1819 return "base"
1820 d = {}
1821 l = []
1822 self.assertEqual(len(td), 0)
1823 self.assertEqual(g(d), "base")
1824 self.assertEqual(len(td), 1)
1825 self.assertEqual(td.get_ops, [])
1826 self.assertEqual(td.set_ops, [dict])
1827 self.assertEqual(td.data[dict], g.registry[object])
1828 self.assertEqual(g(l), "base")
1829 self.assertEqual(len(td), 2)
1830 self.assertEqual(td.get_ops, [])
1831 self.assertEqual(td.set_ops, [dict, list])
1832 self.assertEqual(td.data[dict], g.registry[object])
1833 self.assertEqual(td.data[list], g.registry[object])
1834 self.assertEqual(td.data[dict], td.data[list])
1835 self.assertEqual(g(l), "base")
1836 self.assertEqual(g(d), "base")
1837 self.assertEqual(td.get_ops, [list, dict])
1838 self.assertEqual(td.set_ops, [dict, list])
1839 g.register(list, lambda arg: "list")
1840 self.assertEqual(td.get_ops, [list, dict])
1841 self.assertEqual(len(td), 0)
1842 self.assertEqual(g(d), "base")
1843 self.assertEqual(len(td), 1)
1844 self.assertEqual(td.get_ops, [list, dict])
1845 self.assertEqual(td.set_ops, [dict, list, dict])
1846 self.assertEqual(td.data[dict],
1847 functools._find_impl(dict, g.registry))
1848 self.assertEqual(g(l), "list")
1849 self.assertEqual(len(td), 2)
1850 self.assertEqual(td.get_ops, [list, dict])
1851 self.assertEqual(td.set_ops, [dict, list, dict, list])
1852 self.assertEqual(td.data[list],
1853 functools._find_impl(list, g.registry))
1854 class X:
1855 pass
1856 c.MutableMapping.register(X) # Will not invalidate the cache,
1857 # not using ABCs yet.
1858 self.assertEqual(g(d), "base")
1859 self.assertEqual(g(l), "list")
1860 self.assertEqual(td.get_ops, [list, dict, dict, list])
1861 self.assertEqual(td.set_ops, [dict, list, dict, list])
1862 g.register(c.Sized, lambda arg: "sized")
1863 self.assertEqual(len(td), 0)
1864 self.assertEqual(g(d), "sized")
1865 self.assertEqual(len(td), 1)
1866 self.assertEqual(td.get_ops, [list, dict, dict, list])
1867 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1868 self.assertEqual(g(l), "list")
1869 self.assertEqual(len(td), 2)
1870 self.assertEqual(td.get_ops, [list, dict, dict, list])
1871 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1872 self.assertEqual(g(l), "list")
1873 self.assertEqual(g(d), "sized")
1874 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1875 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1876 g.dispatch(list)
1877 g.dispatch(dict)
1878 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1879 list, dict])
1880 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1881 c.MutableSet.register(X) # Will invalidate the cache.
1882 self.assertEqual(len(td), 2) # Stale cache.
1883 self.assertEqual(g(l), "list")
1884 self.assertEqual(len(td), 1)
1885 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1886 self.assertEqual(len(td), 0)
1887 self.assertEqual(g(d), "mutablemapping")
1888 self.assertEqual(len(td), 1)
1889 self.assertEqual(g(l), "list")
1890 self.assertEqual(len(td), 2)
1891 g.register(dict, lambda arg: "dict")
1892 self.assertEqual(g(d), "dict")
1893 self.assertEqual(g(l), "list")
1894 g._clear_cache()
1895 self.assertEqual(len(td), 0)
1896 functools.WeakKeyDictionary = _orig_wkd
1897
1898
# Run the whole module through unittest's runner when executed as a script.
if __name__ == "__main__":
    unittest.main()