blob: ab51a3534d53aa9d47538ea3b537cf2afc1fed7b [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030011try:
12 import threading
13except ImportError:
14 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000015
Antoine Pitroub5b37142012-11-13 21:35:40 +010016import functools
17
Antoine Pitroub5b37142012-11-13 21:35:40 +010018py_functools = support.import_fresh_module('functools', blocked=['_functools'])
19c_functools = support.import_fresh_module('functools', fresh=['_functools'])
20
Łukasz Langa6f692512013-06-05 12:20:24 +020021decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
22
23
def capture(*args, **kw):
    """Echo the call back: return ``(args, kw)`` exactly as received."""
    received = (args, kw)
    return received
27
Łukasz Langa6f692512013-06-05 12:20:24 +020028
def signature(part):
    """Summarise a partial object as a comparable 4-tuple.

    The items are, in order: the wrapped callable, the frozen positional
    arguments, the frozen keyword arguments and the instance ``__dict__``.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000032
class MyTuple(tuple):
    # Plain tuple subclass with no overrides; the __setstate__ tests use it
    # to check that partial stores an exact ``tuple`` for its args.
    ...
35
class BadTuple(tuple):
    # Tuple subclass whose ``+`` returns a list instead of a tuple; the
    # __setstate__ tests use it to check partial's args stay a real tuple.
    def __add__(self, other):
        return [*self, *other]
39
class MyDict(dict):
    # Plain dict subclass with no overrides; the __setstate__ tests use it
    # to check that partial stores an exact ``dict`` for its keywords.
    ...
42
Łukasz Langa6f692512013-06-05 12:20:24 +020043
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase: concrete subclasses also inherit from
    ``unittest.TestCase`` and set ``partial`` to the implementation under
    test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable func must raise TypeError, whether the check is made
        # when the partial is built or only when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly; separate
        # assertions so a failure reports which half was wrong
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a, 'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1, 'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last strong reference need not free the object at
        # once on every interpreter/GC; force a collection before checking
        # that the proxy is dead.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # a partial of a partial should compare equal to the flat partial
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')
169
Łukasz Langa6f692512013-06-05 12:20:24 +0200170
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C implementation.

    Also adds checks specific to it: read-only attributes, repr,
    pickling, copying and ``__setstate__`` handling.
    """

    if c_functools:
        partial = c_functools.partial

    def _partial_name(self):
        """Return the type name expected at the start of the repr."""
        if self.partial is c_functools.partial:
            return 'functools.partial'
        return self.partial.__name__

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # the order of keywords in the repr is not pinned; accept either
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._partial_name()

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        # Each case makes f refer to itself (via func, args, then keywords).
        # The finally clause restores a non-self-referencing state.
        name = self._partial_name()

        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%s(...))' % (name, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, %s(...))' % (name, capture, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=%s(...))' % (name, capture, name))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        # a shallow copy shares args, keywords and instance attributes
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        # a deep copy duplicates args, keywords and their mutable contents
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))
        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the signature check for keywords=None is disabled
        # here, so what p.keywords holds in that case is not pinned.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        # state must be a 4-tuple of (callable, tuple, dict, dict-or-None)
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        # tuple/dict subclasses in the state are stored as exact tuple/dict
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        # A partial that is its own func cannot be pickled.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.assertRaises(RecursionError):
                    pickle.dumps(f, proto)
        finally:
            f.__setstate__((capture, (), {}, {}))

        # Self-references through args or keywords round-trip.
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.args[0], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                try:
                    self.assertIs(f_copy.keywords['a'], f_copy)
                finally:
                    f_copy.__setstate__((capture, (), {}, {}))
        finally:
            f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A sequence that is not a tuple must be rejected as state.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000377
Łukasz Langa6f692512013-06-05 12:20:24 +0200378
class TestPartialPy(TestPartial, unittest.TestCase):
    # Run the shared TestPartial suite against the pure-Python partial
    # (py_functools is imported with the _functools accelerator blocked).
    # staticmethod keeps ``self.partial`` from being bound to the instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000381
Łukasz Langa6f692512013-06-05 12:20:24 +0200382
if c_functools is not None:
    # Trivial Python-level subclass of the C partial type, only defined when
    # the C accelerator imported; TestPartialCSubclass runs against it.
    class PartialSubclass(c_functools.partial):
        ...
Antoine Pitrou33543272012-11-13 21:36:21 +0100386
Łukasz Langa6f692512013-06-05 12:20:24 +0200387
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-run the whole C partial suite against a Python subclass of it.
    if c_functools is not None:
        partial = PartialSubclass

    # Partial subclasses are not flattened when nested, so the inherited
    # optimisation test does not apply; a non-callable attribute is not
    # collected as a test by the unittest loader.
    test_nested_optimization = None
394 test_nested_optimization = None
395
Łukasz Langa6f692512013-06-05 12:20:24 +0200396
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a plain partial object
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping other descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use ABCMeta as the *metaclass* so Abstract is a genuine abstract
        # class (subclassing ABCMeta would define a new metaclass instead).
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must also register the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
509
510
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each name in *assigned* must be the identical object on both. Each
        key of every mapping named in *updated* must be copied into the
        wrapper's mapping; ``__dict__['__wrapped__']`` is the one exception,
        because update_wrapper overwrites it. Finally,
        ``wrapper.__wrapped__`` must be *wrapped*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Return ``(wrapper, f)`` after update_wrapper with defaults."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples leave the wrapper's own metadata
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000623
Łukasz Langa6f692512013-06-05 12:20:24 +0200624
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks via the functools.wraps decorator."""

    def _default_update(self):
        """Return ``(wrapper, f)`` after decorating with ``wraps(f)``."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples leave the wrapper's own metadata
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # give the wrapper the mapping that 'dict_attr' updates into
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
685
Łukasz Langa6f692512013-06-05 12:20:24 +0200686
class TestReduce(unittest.TestCase):
    # The reduce implementation under test; every test calls it through
    # ``self.func``.
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Perfect squares 0, 1, 4, ... of length ``limit``, computed
            # lazily and iterated through the __getitem__ sequence protocol.
            def __init__(self, limit):
                self.limit = limit
                self.computed = []

            def __len__(self):
                return len(self.computed)

            def __getitem__(self, index):
                if not 0 <= index < self.limit:
                    raise IndexError
                while len(self.computed) <= index:
                    root = len(self.computed)
                    self.computed.append(root * root)
                return self.computed[index]

        def plus(left, right):
            return left + right

        def times(left, right):
            return left * right

        # Ordinary reductions, with and without an initial value.
        self.assertEqual(self.func(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(self.func(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(self.func(times, range(2, 8), 1), 5040)
        self.assertEqual(self.func(times, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(self.func(plus, Squares(10)), 285)
        self.assertEqual(self.func(plus, Squares(10), 0), 285)
        self.assertEqual(self.func(plus, Squares(0), 0), 0)

        # Wrong argument counts / non-iterable sequence argument.
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        # With only one item in play the function is never called, so a
        # non-callable first argument goes unnoticed.
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        self.assertRaises(TypeError, self.func, 42, (42, 42))

        # An empty sequence with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, self.func, plus, empty)
        self.assertRaises(TypeError, self.func, plus, object())

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, plus, FailingIter())

        # An empty sequence with an initial value returns that value.
        self.assertIsNone(self.func(plus, [], None))
        self.assertEqual(self.func(plus, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            # Old-style sequence yielding 0 .. n-1 via __getitem__.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Dicts iterate over their keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d))
765 self.assertEqual(self.func(add, d), "".join(d.keys()))
766
Łukasz Langa6f692512013-06-05 12:20:24 +0200767
class TestCmpToKey:
    # Tests shared by the C and pure-Python cmp_to_key implementations.
    # Concrete subclasses must provide ``cmp_to_key``.

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            self.cmp_to_key()                # too few args
        with self.assertRaises(TypeError):
            self.cmp_to_key(cmp1, None)      # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* against 0 must propagate too.  A new key function is
        # required: reusing the one above would just call cmp1 again and
        # never reach BadCmp.  '<' is used so that ``result < 0`` invokes
        # BadCmp.__lt__ (with '>' the reflected int.__lt__ would be tried
        # instead, giving TypeError rather than the exception under test).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # collections.Hashable is a deprecated alias (removed in 3.10);
        # the ABC lives in collections.abc.
        from collections.abc import Hashable
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700848 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000849
Łukasz Langa6f692512013-06-05 12:20:24 +0200850
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Run the shared cmp_to_key tests against the C accelerator.  The
    # attribute is only defined when c_functools is available; otherwise
    # the skip decorator above disables the whole class.  The C function
    # is a builtin, which is not bound to the instance on lookup, so no
    # staticmethod() wrapper is needed here.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100855
Łukasz Langa6f692512013-06-05 12:20:24 +0200856
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Run the shared cmp_to_key tests against the pure-Python version.
    # staticmethod() keeps the plain Python function from being bound to
    # the test instance when looked up as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
859
Łukasz Langa6f692512013-06-05 12:20:24 +0200860
class TestTotalOrdering(unittest.TestCase):
    # Tests for functools.total_ordering: each class defines __eq__ plus a
    # single ordering method and checks the ones derived from it.

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        # int already supplies all four ordering methods, so the decorator
        # must not have injected replacements into the subclass: each one
        # should still resolve to int's own implementation.
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            with self.subTest(method=name):
                self.assertNotIn(name, vars(A))
                self.assertIs(getattr(A, name), getattr(int, name))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # issue #10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        # Both the hand-written (__lt__) and the derived ordering methods
        # must pickle by reference and round-trip to the same object.
        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)
1062 self.assertIs(method_copy, method)
1063
1064@functools.total_ordering
1065class Orderable_LT:
1066 def __init__(self, value):
1067 self.value = value
1068 def __lt__(self, other):
1069 return self.value < other.value
1070 def __eq__(self, other):
1071 return self.value == other.value
1072
1073
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001074class TestLRU:
Georg Brandl2e7346a2010-07-31 18:09:23 +00001075
1076 def test_lru(self):
1077 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001078 return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001079 f = self.module.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001080 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001081 self.assertEqual(maxsize, 20)
1082 self.assertEqual(currsize, 0)
1083 self.assertEqual(hits, 0)
1084 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001085
1086 domain = range(5)
1087 for i in range(1000):
1088 x, y = choice(domain), choice(domain)
1089 actual = f(x, y)
1090 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +00001091 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001092 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001093 self.assertTrue(hits > misses)
1094 self.assertEqual(hits + misses, 1000)
1095 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001096
Raymond Hettinger02566ec2010-09-04 22:46:06 +00001097 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +00001098 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001099 self.assertEqual(hits, 0)
1100 self.assertEqual(misses, 0)
1101 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001102 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001103 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001104 self.assertEqual(hits, 0)
1105 self.assertEqual(misses, 1)
1106 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +00001107
Nick Coghlan98876832010-08-17 06:17:18 +00001108 # Test bypassing the cache
1109 self.assertIs(f.__wrapped__, orig)
1110 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001111 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001112 self.assertEqual(hits, 0)
1113 self.assertEqual(misses, 1)
1114 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +00001115
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001116 # test size zero (which means "never-cache")
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001117 @self.module.lru_cache(0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001118 def f():
1119 nonlocal f_cnt
1120 f_cnt += 1
1121 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001122 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001123 f_cnt = 0
1124 for i in range(5):
1125 self.assertEqual(f(), 20)
1126 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001127 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001128 self.assertEqual(hits, 0)
1129 self.assertEqual(misses, 5)
1130 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001131
1132 # test size one
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001133 @self.module.lru_cache(1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001134 def f():
1135 nonlocal f_cnt
1136 f_cnt += 1
1137 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +00001138 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001139 f_cnt = 0
1140 for i in range(5):
1141 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001142 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001143 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001144 self.assertEqual(hits, 4)
1145 self.assertEqual(misses, 1)
1146 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001147
Raymond Hettingerf3098282010-08-15 03:30:45 +00001148 # test size two
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001149 @self.module.lru_cache(2)
Raymond Hettingerf3098282010-08-15 03:30:45 +00001150 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001151 nonlocal f_cnt
1152 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +00001153 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +00001154 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001155 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +00001156 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
1157 # * * * *
1158 self.assertEqual(f(x), x*10)
1159 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001160 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001161 self.assertEqual(hits, 12)
1162 self.assertEqual(misses, 4)
1163 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001164
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001165 def test_lru_with_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001166 @self.module.lru_cache(maxsize=None)
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001167 def fib(n):
1168 if n < 2:
1169 return n
1170 return fib(n-1) + fib(n-2)
1171 self.assertEqual([fib(n) for n in range(16)],
1172 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1173 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001174 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001175 fib.cache_clear()
1176 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001177 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1178
1179 def test_lru_with_maxsize_negative(self):
1180 @self.module.lru_cache(maxsize=-10)
1181 def eq(n):
1182 return n
1183 for i in (0, 1):
1184 self.assertEqual([eq(n) for n in range(150)], list(range(150)))
1185 self.assertEqual(eq.cache_info(),
1186 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001187
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001188 def test_lru_with_exceptions(self):
1189 # Verify that user_function exceptions get passed through without
1190 # creating a hard-to-read chained exception.
1191 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001192 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001193 @self.module.lru_cache(maxsize)
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001194 def func(i):
1195 return 'abc'[i]
1196 self.assertEqual(func(0), 'a')
1197 with self.assertRaises(IndexError) as cm:
1198 func(15)
1199 self.assertIsNone(cm.exception.__context__)
1200 # Verify that the previous exception did not result in a cached entry
1201 with self.assertRaises(IndexError):
1202 func(15)
1203
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001204 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001205 for maxsize in (None, 128):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001206 @self.module.lru_cache(maxsize=maxsize, typed=True)
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001207 def square(x):
1208 return x * x
1209 self.assertEqual(square(3), 9)
1210 self.assertEqual(type(square(3)), type(9))
1211 self.assertEqual(square(3.0), 9.0)
1212 self.assertEqual(type(square(3.0)), type(9.0))
1213 self.assertEqual(square(x=3), 9)
1214 self.assertEqual(type(square(x=3)), type(9))
1215 self.assertEqual(square(x=3.0), 9.0)
1216 self.assertEqual(type(square(x=3.0)), type(9.0))
1217 self.assertEqual(square.cache_info().hits, 4)
1218 self.assertEqual(square.cache_info().misses, 4)
1219
Antoine Pitroub5b37142012-11-13 21:35:40 +01001220 def test_lru_with_keyword_args(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001221 @self.module.lru_cache()
Antoine Pitroub5b37142012-11-13 21:35:40 +01001222 def fib(n):
1223 if n < 2:
1224 return n
1225 return fib(n=n-1) + fib(n=n-2)
1226 self.assertEqual(
1227 [fib(n=number) for number in range(16)],
1228 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1229 )
1230 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001231 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001232 fib.cache_clear()
1233 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001234 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001235
1236 def test_lru_with_keyword_args_maxsize_none(self):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001237 @self.module.lru_cache(maxsize=None)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001238 def fib(n):
1239 if n < 2:
1240 return n
1241 return fib(n=n-1) + fib(n=n-2)
1242 self.assertEqual([fib(n=number) for number in range(16)],
1243 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1244 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001245 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
Antoine Pitroub5b37142012-11-13 21:35:40 +01001246 fib.cache_clear()
1247 self.assertEqual(fib.cache_info(),
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001248 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1249
1250 def test_lru_cache_decoration(self):
1251 def f(zomg: 'zomg_annotation'):
1252 """f doc string"""
1253 return 42
1254 g = self.module.lru_cache()(f)
1255 for attr in self.module.WRAPPER_ASSIGNMENTS:
1256 self.assertEqual(getattr(g, attr), getattr(f, attr))
1257
1258 @unittest.skipUnless(threading, 'This test requires threading.')
1259 def test_lru_cache_threaded(self):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001260 n, m = 5, 11
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001261 def orig(x, y):
1262 return 3 * x + y
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001263 f = self.module.lru_cache(maxsize=n*m)(orig)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001264 hits, misses, maxsize, currsize = f.cache_info()
1265 self.assertEqual(currsize, 0)
1266
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001267 start = threading.Event()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001268 def full(k):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001269 start.wait(10)
1270 for _ in range(m):
Serhiy Storchaka391af752015-06-08 12:44:18 +03001271 self.assertEqual(f(k, 0), orig(k, 0))
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001272
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001273 def clear():
1274 start.wait(10)
1275 for _ in range(2*m):
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001276 f.cache_clear()
1277
1278 orig_si = sys.getswitchinterval()
1279 sys.setswitchinterval(1e-6)
1280 try:
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001281 # create n threads in order to fill cache
Serhiy Storchaka391af752015-06-08 12:44:18 +03001282 threads = [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001283 for k in range(n)]
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001284 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001285 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001286
1287 hits, misses, maxsize, currsize = f.cache_info()
Serhiy Storchaka391af752015-06-08 12:44:18 +03001288 if self.module is py_functools:
1289 # XXX: Why can be not equal?
1290 self.assertLessEqual(misses, n)
1291 self.assertLessEqual(hits, m*n - misses)
1292 else:
1293 self.assertEqual(misses, n)
1294 self.assertEqual(hits, m*n - misses)
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001295 self.assertEqual(currsize, n)
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001296
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001297 # create n threads in order to fill cache and 1 to clear it
1298 threads = [threading.Thread(target=clear)]
Serhiy Storchaka391af752015-06-08 12:44:18 +03001299 threads += [threading.Thread(target=full, args=[k])
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001300 for k in range(n)]
1301 start.clear()
Serhiy Storchakabf2b3b72015-05-30 15:49:17 +03001302 with support.start_threads(threads):
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001303 start.set()
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001304 finally:
1305 sys.setswitchinterval(orig_si)
Antoine Pitroub5b37142012-11-13 21:35:40 +01001306
Serhiy Storchaka77cb1972015-06-08 11:14:31 +03001307 @unittest.skipUnless(threading, 'This test requires threading.')
1308 def test_lru_cache_threaded2(self):
1309 # Simultaneous call with the same arguments
1310 n, m = 5, 7
1311 start = threading.Barrier(n+1)
1312 pause = threading.Barrier(n+1)
1313 stop = threading.Barrier(n+1)
1314 @self.module.lru_cache(maxsize=m*n)
1315 def f(x):
1316 pause.wait(10)
1317 return 3 * x
1318 self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
1319 def test():
1320 for i in range(m):
1321 start.wait(10)
1322 self.assertEqual(f(i), 3 * i)
1323 stop.wait(10)
1324 threads = [threading.Thread(target=test) for k in range(n)]
1325 with support.start_threads(threads):
1326 for i in range(m):
1327 start.wait(10)
1328 stop.reset()
1329 pause.wait(10)
1330 start.reset()
1331 stop.wait(10)
1332 pause.reset()
1333 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
1334
Raymond Hettinger03923422013-03-04 02:52:50 -05001335 def test_need_for_rlock(self):
1336 # This will deadlock on an LRU cache that uses a regular lock
1337
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001338 @self.module.lru_cache(maxsize=10)
Raymond Hettinger03923422013-03-04 02:52:50 -05001339 def test_func(x):
1340 'Used to demonstrate a reentrant lru_cache call within a single thread'
1341 return x
1342
1343 class DoubleEq:
1344 'Demonstrate a reentrant lru_cache call within a single thread'
1345 def __init__(self, x):
1346 self.x = x
1347 def __hash__(self):
1348 return self.x
1349 def __eq__(self, other):
1350 if self.x == 2:
1351 test_func(DoubleEq(1))
1352 return self.x == other.x
1353
1354 test_func(DoubleEq(1)) # Load the cache
1355 test_func(DoubleEq(2)) # Load the cache
1356 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1357 DoubleEq(2)) # Verify the correct return value
1358
Raymond Hettinger4d588972014-08-12 12:44:52 -07001359 def test_early_detection_of_bad_call(self):
1360 # Issue #22184
1361 with self.assertRaises(TypeError):
1362 @functools.lru_cache
1363 def f():
1364 pass
1365
def test_lru_method(self):
    # lru_cache applied to a method: the cache belongs to the class attribute
    # X.f, so its statistics accumulate across every X instance.
    class X(int):
        f_cnt = 0
        @self.module.lru_cache(2)
        def f(self, x):
            self.f_cnt += 1
            return x*10+self

    a, b, c = X(5), X(5), X(7)
    self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

    # Each round: (instance, its int value, arguments fed to f,
    # expected (a.f_cnt, b.f_cnt, c.f_cnt) afterwards,
    # expected (hits, misses, maxsize, currsize) afterwards).
    rounds = (
        (a, 5, (1, 2, 2, 3, 1, 1, 1, 2, 3, 3), (6, 0, 0), (4, 6, 2, 2)),
        (b, 5, (1, 2, 1, 1, 1, 1, 3, 2, 2, 2), (6, 4, 0), (10, 10, 2, 2)),
        (c, 7, (2, 1, 1, 1, 1, 2, 1, 3, 2, 1), (6, 4, 5), (15, 15, 2, 2)),
    )
    for inst, base, args, counts, info in rounds:
        for arg in args:
            self.assertEqual(inst.f(arg), arg*10 + base)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), counts)
        self.assertEqual(X.f.cache_info(), info)

    # Bound-method access reports the same shared cache statistics.
    for inst in (a, b, c):
        self.assertEqual(inst.f.cache_info(), X.f.cache_info())
1396
def test_pickle(self):
    # Pickling an lru_cache wrapper must round-trip to the very same object
    # for a module-level function, a method and a staticmethod, under every
    # available pickle protocol.
    klass = type(self)
    candidates = (klass.cached_func[0], klass.cached_meth,
                  klass.cached_staticmeth)
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for func in candidates:
        for proto in protocols:
            with self.subTest(proto=proto, func=func):
                restored = pickle.loads(pickle.dumps(func, proto))
                self.assertIs(restored, func)
1404
def test_copy(self):
    # copy.copy() must return the identical lru_cache wrapper for each kind
    # of cached callable, including a cache wrapped around a partial object.
    klass = type(self)

    def orig(x, y):
        return 3 * x + y

    wrapped_partial = self.module.lru_cache(2)(self.module.partial(orig, 2))
    candidates = (
        klass.cached_func[0],
        klass.cached_meth,
        klass.cached_staticmeth,
        wrapped_partial,
    )
    for f in candidates:
        with self.subTest(func=f):
            self.assertIs(copy.copy(f), f)
1416
def test_deepcopy(self):
    # copy.deepcopy() must likewise hand back the identical lru_cache
    # wrapper for each kind of cached callable, including a cached partial.
    def orig(x, y):
        return 3 * x + y

    owner = type(self)
    targets = [owner.cached_func[0], owner.cached_meth,
               owner.cached_staticmeth]
    targets.append(self.module.lru_cache(2)(self.module.partial(orig, 2)))
    for f in targets:
        with self.subTest(func=f):
            duplicate = copy.deepcopy(f)
            self.assertIs(duplicate, f)
1428
1429
@py_functools.lru_cache()
def py_cached_func(x, y):
    """Return 3*x + y, memoized with the pure-Python lru_cache.

    Defined at module level so the pickle/copy tests can resolve it by name.
    """
    scaled = 3 * x
    return scaled + y
1433
@c_functools.lru_cache()
def c_cached_func(x, y):
    """Return 3*x + y, memoized with the C-accelerated lru_cache.

    Defined at module level so the pickle/copy tests can resolve it by name.
    """
    scaled = 3 * x
    return scaled + y
1437
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001438
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared TestLRU suite against the pure-Python functools."""

    module = py_functools
    # Held in a one-element tuple, presumably so that attribute lookup never
    # binds the cached function as a method; tests read cached_func[0].
    cached_func = (py_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
1451
1452
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared TestLRU suite against the C-accelerated functools."""

    module = c_functools
    # Held in a one-element tuple, presumably so that attribute lookup never
    # binds the cached function as a method; tests read cached_func[0].
    cached_func = (c_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001465
Raymond Hettinger03923422013-03-04 02:52:50 -05001466
class TestSingleDispatch(unittest.TestCase):
    """Tests for functools.singledispatch and its MRO helpers.

    ABCs come from collections.abc (the aliases in the top-level collections
    namespace are deprecated); concrete containers such as defaultdict,
    ChainMap, OrderedDict and UserDict still come from collections.
    """

    def test_simple_overloads(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        def g_int(i):
            return "integer"
        g.register(int, g_int)
        self.assertEqual(g("str"), "base")
        self.assertEqual(g(1), "integer")
        self.assertEqual(g([1,2,3]), "base")

    def test_mro(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        class A:
            pass
        class C(A):
            pass
        class B(A):
            pass
        class D(C, B):
            pass
        def g_A(a):
            return "A"
        def g_B(b):
            return "B"
        g.register(A, g_A)
        g.register(B, g_B)
        self.assertEqual(g(A()), "A")
        self.assertEqual(g(B()), "B")
        self.assertEqual(g(C()), "A")
        # D's MRO is D, C, B, A, object, so B's implementation wins over A's.
        self.assertEqual(g(D()), "B")

    def test_register_decorator(self):
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(int)
        def g_int(i):
            return "int %s" % (i,)
        self.assertEqual(g(""), "base")
        self.assertEqual(g(12), "int 12")
        self.assertIs(g.dispatch(int), g_int)
        # The object/str lookups return the registered base implementation,
        # which is not g itself: @singledispatch returns a wrapper.
        self.assertIs(g.dispatch(object), g.dispatch(str))

    def test_wrapping_attributes(self):
        @functools.singledispatch
        def g(obj):
            "Simple test"
            return "Test"
        self.assertEqual(g.__name__, "g")
        # Docstrings are stripped when running with -OO.
        if sys.flags.optimize < 2:
            self.assertEqual(g.__doc__, "Simple test")

    @unittest.skipUnless(decimal, 'requires _decimal')
    @support.cpython_only
    def test_c_classes(self):
        # Dispatch on exception classes implemented in C (_decimal).
        @functools.singledispatch
        def g(obj):
            return "base"
        @g.register(decimal.DecimalException)
        def _(obj):
            return obj.args
        subn = decimal.Subnormal("Exponent < Emin")
        rnd = decimal.Rounded("Number got rounded")
        self.assertEqual(g(subn), ("Exponent < Emin",))
        self.assertEqual(g(rnd), ("Number got rounded",))
        @g.register(decimal.Subnormal)
        def _(obj):
            return "Too small to care."
        self.assertEqual(g(subn), "Too small to care.")
        self.assertEqual(g(rnd), ("Number got rounded",))

    def test_compose_mro(self):
        from collections import abc as c
        # None of the examples in this test depend on haystack ordering, so
        # every example passes each permutation of its haystack to
        # _compose_mro and expects the same result.
        mro = functools._compose_mro
        bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
        for haystack in permutations(bases):
            m = mro(dict, haystack)
            self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                                 c.Iterable, c.Container, object])
        bases = [c.Container, c.Mapping, c.MutableMapping,
                 collections.OrderedDict]
        for haystack in permutations(bases):
            m = mro(collections.ChainMap, haystack)
            self.assertEqual(m, [collections.ChainMap, c.MutableMapping,
                                 c.Mapping, c.Sized, c.Iterable,
                                 c.Container, object])

        # If there's a generic function with implementations registered for
        # both Sized and Container, passing a defaultdict to it results in an
        # ambiguous dispatch which will cause a RuntimeError (see
        # test_mro_conflicts).
        bases = [c.Container, c.Sized, str]
        for haystack in permutations(bases):
            m = mro(collections.defaultdict, haystack)
            self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                                 c.Container, object])

        # MutableSequence below is registered directly on D. In other words, it
        # precedes MutableMapping which means single dispatch will always
        # choose MutableSequence here.
        class D(collections.defaultdict):
            pass
        c.MutableSequence.register(D)
        bases = [c.MutableSequence, c.MutableMapping]
        for haystack in permutations(bases):
            m = mro(D, haystack)
            self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                                 collections.defaultdict, dict,
                                 c.MutableMapping, c.Mapping, c.Sized,
                                 c.Iterable, c.Container, object])

        # Container and Callable are registered on different base classes and
        # a generic function supporting both should always pick the Callable
        # implementation if a C instance is passed.
        class C(collections.defaultdict):
            def __call__(self):
                pass
        bases = [c.Sized, c.Callable, c.Container, c.Mapping]
        for haystack in permutations(bases):
            m = mro(C, haystack)
            self.assertEqual(m, [C, c.Callable, collections.defaultdict,
                                 dict, c.Mapping, c.Sized, c.Iterable,
                                 c.Container, object])

    def test_register_abc(self):
        from collections import abc as c
        d = {"a": "b"}
        l = [1, 2, 3]
        s = {object(), None}
        f = frozenset(s)
        t = (1, 2, 3)
        @functools.singledispatch
        def g(obj):
            return "base"
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "base")
        self.assertEqual(g(s), "base")
        self.assertEqual(g(f), "base")
        self.assertEqual(g(t), "base")
        g.register(c.Sized, lambda obj: "sized")
        self.assertEqual(g(d), "sized")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableMapping, lambda obj: "mutablemapping")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(collections.ChainMap, lambda obj: "chainmap")
        self.assertEqual(g(d), "mutablemapping")  # irrelevant class registered
        self.assertEqual(g(l), "sized")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSequence, lambda obj: "mutablesequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "sized")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.MutableSet, lambda obj: "mutableset")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Mapping, lambda obj: "mapping")
        self.assertEqual(g(d), "mutablemapping")  # not specific enough
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sized")
        g.register(c.Sequence, lambda obj: "sequence")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "sized")
        self.assertEqual(g(t), "sequence")
        g.register(c.Set, lambda obj: "set")
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(dict, lambda obj: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "mutablesequence")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(list, lambda obj: "list")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "mutableset")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(set, lambda obj: "concrete-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "set")
        self.assertEqual(g(t), "sequence")
        g.register(frozenset, lambda obj: "frozen-set")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "sequence")
        g.register(tuple, lambda obj: "tuple")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        self.assertEqual(g(s), "concrete-set")
        self.assertEqual(g(f), "frozen-set")
        self.assertEqual(g(t), "tuple")

    def test_c3_abc(self):
        from collections import abc as c
        mro = functools._c3_mro
        class A(object):
            pass
        class B(A):
            def __len__(self):
                return 0   # implies Sized
        @c.Container.register
        class C(object):
            pass
        class D(object):
            pass   # unrelated
        class X(D, C, B):
            def __call__(self):
                pass   # implies Callable
        expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
        for abcs in permutations([c.Sized, c.Callable, c.Container]):
            self.assertEqual(mro(X, abcs=abcs), expected)
        # unrelated ABCs don't appear in the resulting MRO
        many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
        self.assertEqual(mro(X, abcs=many_abcs), expected)

    def test_false_meta(self):
        # see issue23572: a metaclass defining __len__ must not confuse
        # dispatch on instances of its classes.
        class MetaA(type):
            def __len__(self):
                return 0
        class A(metaclass=MetaA):
            pass
        class AA(A):
            pass
        @functools.singledispatch
        def fun(a):
            return 'base A'
        @fun.register(A)
        def _(a):
            return 'fun A'
        aa = AA()
        self.assertEqual(fun(aa), 'fun A')

    def test_mro_conflicts(self):
        from collections import abc as c
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO

    def test_cache_invalidation(self):
        from collections import UserDict
        from collections import abc as c
        # A dict that records every read and write, substituted for the
        # WeakKeyDictionary singledispatch uses as its dispatch cache.
        class TracingDict(UserDict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                self.data.clear()
        orig_wkd = functools.WeakKeyDictionary
        # Restore the real WeakKeyDictionary even if an assertion below fails,
        # so the monkeypatch cannot leak into other tests.
        self.addCleanup(setattr, functools, 'WeakKeyDictionary', orig_wkd)
        td = TracingDict()
        functools.WeakKeyDictionary = lambda: td
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
1958
1959
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()