blob: cf0b95d73c755ba969c43fec1a438b3b3b6afce9 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Serhiy Storchaka45120f22015-10-24 09:49:56 +03003import copy
Łukasz Langa6f692512013-06-05 12:20:24 +02004from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00006from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02007import sys
8from test import support
9import unittest
10from weakref import proxy
Serhiy Storchaka46c56112015-05-24 21:53:49 +030011try:
12 import threading
13except ImportError:
14 threading = None
Raymond Hettinger9c323f82005-02-28 19:39:44 +000015
Antoine Pitroub5b37142012-11-13 21:35:40 +010016import functools
17
Antoine Pitroub5b37142012-11-13 21:35:40 +010018py_functools = support.import_fresh_module('functools', blocked=['_functools'])
19c_functools = support.import_fresh_module('functools', fresh=['_functools'])
20
Łukasz Langa6f692512013-06-05 12:20:24 +020021decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
22
23
def capture(*args, **kw):
    """Echo a call's arguments back as an ``(args, kwargs)`` pair.

    Serves as the target callable for partial objects so that tests can see
    exactly which positional and keyword arguments arrived.
    """
    captured = (args, kw)
    return captured
27
Łukasz Langa6f692512013-06-05 12:20:24 +020028
def signature(part):
    """Snapshot a partial object's state as a comparable tuple.

    Returns ``(func, args, keywords, __dict__)`` so that two partial objects
    (e.g. an original and a pickled copy) can be compared with ``==``.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000032
Łukasz Langa6f692512013-06-05 12:20:24 +020033
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase: concrete subclasses combine it with
    ``unittest.TestCase`` and set ``partial`` to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must end in TypeError.  Construction
        # and call are both inside the block, so either step may raise it.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertEqual calls (rather than one assertTrue on a
                # combined condition) report the mismatching values on failure.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a':a,'x':None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')
159
Łukasz Langa6f692512013-06-05 12:20:24 +0200160
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the C accelerator, plus C-only checks."""
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        # Deleting the instance __dict__ must be refused with TypeError.
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The keyword order in the repr is not assumed, so both orderings
        # are accepted below.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000235
Łukasz Langa6f692512013-06-05 12:20:24 +0200236
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python implementation."""

    # Wrapped in staticmethod so that self.partial is never bound as a method,
    # whatever kind of callable py_functools.partial happens to be.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000239
Łukasz Langa6f692512013-06-05 12:20:24 +0200240
if c_functools:
    class PartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial type; it adds no behaviour."""
Antoine Pitrou33543272012-11-13 21:36:21 +0100244
Łukasz Langa6f692512013-06-05 12:20:24 +0200245
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run every C partial test against a subclass of the C partial type."""

    if c_functools:
        partial = PartialSubclass

    # Partial subclasses are not flattened when nested, so the inherited
    # optimization test is replaced by a non-callable and is not collected.
    test_nested_optimization = None
253
Łukasz Langa6f692512013-06-05 12:20:24 +0200254
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta as its metaclass.  Subclassing ABCMeta
        # (as this test previously did) turns Abstract into a metaclass
        # itself, so no abstract class was ever actually built.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must register both the plain and the partial method as
        # abstract, which makes the class non-instantiable.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
367
368
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both functions. Every key of every mapping named in *updated* must
        have been copied into the wrapper's mapping. ``__wrapped__`` must point
        back at *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            merged = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # update_wrapper overwrites any __wrapped__ entry that came
                # from the wrapped function's __dict__, so skip it here.
                if (attr, key) == ("__dict__", "__wrapped__"):
                    continue
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # The wrapped function carries an annotation, a docstring, a custom
        # attribute and a bogus __wrapped__; update_wrapper uses the defaults.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.__wrapped__ = "This is a bald faced lie"
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: first the
        # attribute is absent, then it is present but not a mapping.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000481
Łukasz Langa6f692512013-06-05 12:20:24 +0200482
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
        f.__wrapped__ = "This is still a bald faced lie"
        f.attr = 'This is also a test'

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expectations = [
            ('__name__', 'f'),
            ('__qualname__', f.__qualname__),
            ('attr', 'This is also a test'),
        ]
        for attr, expected in expectations:
            self.assertEqual(getattr(wrapper, attr), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # Give the wrapper an empty mapping for wraps() to update.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
543
Łukasz Langa6f692512013-06-05 12:20:24 +0200544
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # On-demand sequence of the first ``max`` squares.  len() only
            # counts the squares computed so far, presumably so that reduce
            # has to find the end via IndexError instead of trusting len().
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                self.sofar.extend(n * n for n in range(len(self.sofar), i + 1))
                return self.sofar[i]

        def plus(x, y):
            return x + y

        reduce = self.func

        # Successful reductions, with and without an initial value.
        self.assertEqual(reduce(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a','c','d','w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2,21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(plus, Squares(10)), 285)
        self.assertEqual(reduce(plus, Squares(10), 0), 285)
        self.assertEqual(reduce(plus, Squares(0), 0), 0)

        # With a single item the function is never called, so even a
        # non-callable "function" is accepted.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertEqual(reduce(plus, [], None), None)
        self.assertEqual(reduce(plus, [], 42), 42)

        # Each of these argument tuples must make reduce raise TypeError
        # (wrong arity, non-callable function, empty sequence without an
        # initial value, or a non-iterable).
        bad_calls = [
            (),
            (42, 42),
            (42, 42, 42),
            (42, (42, 42)),
            (plus, []),
            (plus, ""),
            (plus, ()),
            (plus, object()),
        ]
        for call_args in bad_calls:
            self.assertRaises(TypeError, reduce, *call_args)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, plus, TestFailingIter())

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: yields 0 .. n-1, then IndexError.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
624
Łukasz Langa6f692512013-06-05 12:20:24 +0200625
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200626class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700627
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000628 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700629 def cmp1(x, y):
630 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100631 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700632 self.assertEqual(key(3), key(3))
633 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100634 self.assertGreaterEqual(key(3), key(3))
635
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700636 def cmp2(x, y):
637 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100638 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700639 self.assertEqual(key(4.0), key('4'))
640 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100641 self.assertLessEqual(key(2), key('35'))
642 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700643
644 def test_cmp_to_key_arguments(self):
645 def cmp1(x, y):
646 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100647 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700648 self.assertEqual(key(obj=3), key(obj=3))
649 self.assertGreater(key(obj=3), key(obj=1))
650 with self.assertRaises((TypeError, AttributeError)):
651 key(3) > 1 # rhs is not a K object
652 with self.assertRaises((TypeError, AttributeError)):
653 1 < key(3) # lhs is not a K object
654 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100655 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700656 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200657 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100658 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700659 with self.assertRaises(TypeError):
660 key() # too few args
661 with self.assertRaises(TypeError):
662 key(None, None) # too many args
663
664 def test_bad_cmp(self):
665 def cmp1(x, y):
666 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100667 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700668 with self.assertRaises(ZeroDivisionError):
669 key(3) > key(1)
670
671 class BadCmp:
672 def __lt__(self, other):
673 raise ZeroDivisionError
674 def cmp1(x, y):
675 return BadCmp()
676 with self.assertRaises(ZeroDivisionError):
677 key(3) > key(1)
678
679 def test_obj_field(self):
680 def cmp1(x, y):
681 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100682 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700683 self.assertEqual(key(50).obj, 50)
684
685 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000686 def mycmp(x, y):
687 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100688 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000689 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000690
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700691 def test_sort_int_str(self):
692 def mycmp(x, y):
693 x, y = int(x), int(y)
694 return (x > y) - (x < y)
695 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100696 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700697 self.assertEqual([int(value) for value in values],
698 [0, 1, 1, 2, 3, 4, 5, 7, 10])
699
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000700 def test_hash(self):
701 def mycmp(x, y):
702 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100703 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000704 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700705 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700706 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000707
Łukasz Langa6f692512013-06-05 12:20:24 +0200708
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    # The class body executes before the skip decorator is applied, so the
    # attribute lookup must be guarded when the C module is unavailable
    # (c_functools is falsy in that case, per the skip condition above).
    # staticmethod mirrors the pure-Python variant; for a builtin function
    # it does not change what ``self.cmp_to_key`` returns.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100713
Łukasz Langa6f692512013-06-05 12:20:24 +0200714
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python functools."""

    # Wrapped in staticmethod so the plain Python function is not bound to
    # the test instance when accessed as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
717
Łukasz Langa6f692512013-06-05 12:20:24 +0200718
class TestTotalOrdering(unittest.TestCase):
    """Tests for the ``functools.total_ordering`` class decorator."""

    def _assert_consistent_ordering(self, cls):
        # The six comparisons every single-root class below must satisfy.
        # Fresh instances are built per comparison, as equal-value checks
        # must compare two distinct objects.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_consistent_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_consistent_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_consistent_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_consistent_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # Methods already provided (here by int) must not be replaced.
        @functools.total_ordering
        class A(int):
            pass
        self._assert_consistent_ordering(A)

    def test_no_operations_defined(self):
        # With no ordering method to derive from, the decorator refuses.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Each comparison mixes incompatible types and must raise TypeError.
        mixed_cases = [
            ("LT < 1", lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3)
                      <= ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: (ImplementsGreaterThan(4)
                      > ImplementsGreaterThanEqualTo(4))),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: (ImplementsGreaterThanEqualTo(6)
                      >= ImplementsGreaterThan(6))),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7)
                      >= ImplementsLessThanEqualTo(7))),
        ]
        for label, compare in mixed_cases:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Even equal operands must not mask the NotImplemented from __lt__.
        equal_cases = [
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        ]
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)

    def test_pickle(self):
        # Both the user-defined and the generated comparison methods must
        # round-trip through pickle as the very same objects.
        method_names = ('__lt__', '__gt__', '__le__', '__ge__')
        for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
            for name in method_names:
                with self.subTest(method=name, proto=proto):
                    original = getattr(Orderable_LT, name)
                    restored = pickle.loads(pickle.dumps(original, proto))
                    self.assertIs(restored, original)
921
@functools.total_ordering
class Orderable_LT:
    """total_ordering fixture whose only ordering root is ``__lt__``.

    It is defined at module level so that TestTotalOrdering.test_pickle can
    pickle its comparison methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        # total_ordering derives the remaining rich comparisons from this.
        return self.value < other.value
930
931
class TestLRU:
    """Shared tests for ``lru_cache``.

    Mixin: concrete subclasses combine it with ``unittest.TestCase`` and set
    ``module`` (the functools implementation under test) as well as the
    ``cached_func``, ``cached_meth`` and ``cached_staticmeth`` fixtures used
    by the pickle and copy tests.
    """

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 distinct argument pairs compete for 20 slots.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache: calling __wrapped__ leaves stats alone.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @self.module.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @self.module.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @self.module.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Misses: the first 7, the first 9, the first 8 (evicting 7), and
        # the final 7.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_maxsize_negative(self):
        # A negative maxsize is accepted; every call below misses while
        # currsize reports 1, as pinned by the final assertion.
        @self.module.lru_cache(maxsize=-10)
        def eq(n):
            return n
        for _ in (0, 1):
            self.assertEqual([eq(n) for n in range(150)], list(range(150)))
        self.assertEqual(eq.cache_info(),
            self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # typed=True caches 3 and 3.0 separately, so result types survive.
        for maxsize in (None, 128):
            @self.module.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @self.module.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @self.module.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_cache_decoration(self):
        # The wrapper must carry over every WRAPPER_ASSIGNMENTS attribute.
        def f(zomg: 'zomg_annotation'):
            """f doc string"""
            return 42
        g = self.module.lru_cache()(f)
        for attr in self.module.WRAPPER_ASSIGNMENTS:
            self.assertEqual(getattr(g, attr), getattr(f, attr))

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # A tiny switch interval forces frequent thread switches.
        orig_si = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # NOTE(review): the pure-Python counters are only checked
                # loosely; the original author left open why they may not
                # match exactly.  TODO confirm and tighten.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)

    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        n, m = 5, 7
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # All n concurrent callers miss; only one entry is stored.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @self.module.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value

    def test_early_detection_of_bad_call(self):
        # Issue #22184: applying lru_cache without calling it first must
        # fail at decoration time.  Use the implementation under test so
        # both the pure-Python and the C variants are exercised.
        with self.assertRaises(TypeError):
            @self.module.lru_cache
            def f():
                pass

    def test_lru_method(self):
        # One cache is shared by all instances: X.f and every a.f/b.f/c.f
        # report the same cache_info.
        class X(int):
            f_cnt = 0
            @self.module.lru_cache(2)
            def f(self, x):
                self.f_cnt += 1
                return x*10+self
        a = X(5)
        b = X(5)
        c = X(7)
        self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))

        for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
            self.assertEqual(a.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
        self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))

        for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
            self.assertEqual(b.f(x), x*10 + 5)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
        self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))

        for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
            self.assertEqual(c.f(x), x*10 + 7)
        self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
        self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))

        self.assertEqual(a.f.cache_info(), X.f.cache_info())
        self.assertEqual(b.f.cache_info(), X.f.cache_info())
        self.assertEqual(c.f.cache_info(), X.f.cache_info())

    def test_pickle(self):
        # Cached callables pickle by reference and come back identical.
        cls = self.__class__
        for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(proto=proto, func=f):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    self.assertIs(f_copy, f)

    def test_copy(self):
        # copy.copy of a cached callable returns the same object.
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.copy(f)
                self.assertIs(f_copy, f)

    def test_deepcopy(self):
        # copy.deepcopy of a cached callable returns the same object.
        cls = self.__class__
        def orig(x, y):
            return 3 * x + y
        part = self.module.partial(orig, 2)
        funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
                 self.module.lru_cache(2)(part))
        for f in funcs:
            with self.subTest(func=f):
                f_copy = copy.deepcopy(f)
                self.assertIs(f_copy, f)
1286
1287
# Module-level cached functions, one per implementation.  They must live at
# module level so the pickle and copy tests can resolve them by name.

# Cached with the pure-Python lru_cache.
@py_functools.lru_cache()
def py_cached_func(x, y):
    return 3 * x + y

# Cached with the C-accelerated lru_cache.
@c_functools.lru_cache()
def c_cached_func(x, y):
    return 3 * x + y
1295
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001296
class TestLRUPy(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the pure-Python functools."""

    module = py_functools
    # Held in a one-element tuple (the tests read cached_func[0]), presumably
    # so attribute access never turns the cached function into a method.
    cached_func = (py_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
1309
1310
class TestLRUC(TestLRU, unittest.TestCase):
    """Run the shared lru_cache tests against the C-accelerated functools."""

    module = c_functools
    # Held in a one-element tuple (the tests read cached_func[0]), presumably
    # so attribute access never turns the cached function into a method.
    cached_func = (c_cached_func,)

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y
Serhiy Storchaka46c56112015-05-24 21:53:49 +03001323
Raymond Hettinger03923422013-03-04 02:52:50 -05001324
Łukasz Langa6f692512013-06-05 12:20:24 +02001325class TestSingleDispatch(unittest.TestCase):
1326 def test_simple_overloads(self):
1327 @functools.singledispatch
1328 def g(obj):
1329 return "base"
1330 def g_int(i):
1331 return "integer"
1332 g.register(int, g_int)
1333 self.assertEqual(g("str"), "base")
1334 self.assertEqual(g(1), "integer")
1335 self.assertEqual(g([1,2,3]), "base")
1336
1337 def test_mro(self):
1338 @functools.singledispatch
1339 def g(obj):
1340 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001341 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001342 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001343 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001344 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001345 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001346 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001347 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001348 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001349 def g_A(a):
1350 return "A"
1351 def g_B(b):
1352 return "B"
1353 g.register(A, g_A)
1354 g.register(B, g_B)
1355 self.assertEqual(g(A()), "A")
1356 self.assertEqual(g(B()), "B")
1357 self.assertEqual(g(C()), "A")
1358 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001359
1360 def test_register_decorator(self):
1361 @functools.singledispatch
1362 def g(obj):
1363 return "base"
1364 @g.register(int)
1365 def g_int(i):
1366 return "int %s" % (i,)
1367 self.assertEqual(g(""), "base")
1368 self.assertEqual(g(12), "int 12")
1369 self.assertIs(g.dispatch(int), g_int)
1370 self.assertIs(g.dispatch(object), g.dispatch(str))
1371 # Note: in the assert above this is not g.
1372 # @singledispatch returns the wrapper.
1373
1374 def test_wrapping_attributes(self):
1375 @functools.singledispatch
1376 def g(obj):
1377 "Simple test"
1378 return "Test"
1379 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001380 if sys.flags.optimize < 2:
1381 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001382
1383 @unittest.skipUnless(decimal, 'requires _decimal')
1384 @support.cpython_only
1385 def test_c_classes(self):
1386 @functools.singledispatch
1387 def g(obj):
1388 return "base"
1389 @g.register(decimal.DecimalException)
1390 def _(obj):
1391 return obj.args
1392 subn = decimal.Subnormal("Exponent < Emin")
1393 rnd = decimal.Rounded("Number got rounded")
1394 self.assertEqual(g(subn), ("Exponent < Emin",))
1395 self.assertEqual(g(rnd), ("Number got rounded",))
1396 @g.register(decimal.Subnormal)
1397 def _(obj):
1398 return "Too small to care."
1399 self.assertEqual(g(subn), "Too small to care.")
1400 self.assertEqual(g(rnd), ("Number got rounded",))
1401
def test_compose_mro(self):
    # None of the examples in this test depend on haystack ordering, so
    # every scenario is checked against each permutation of its candidate
    # classes.  The ABCs are taken from collections.abc: their aliases on
    # the collections package itself are deprecated (removed in 3.10).
    from collections import abc as c
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                             c.Iterable, c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping,
             collections.OrderedDict]
    for haystack in permutations(bases):
        m = mro(collections.ChainMap, haystack)
        self.assertEqual(m, [collections.ChainMap, c.MutableMapping,
                             c.Mapping, c.Sized, c.Iterable, c.Container,
                             object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        m = mro(collections.defaultdict, haystack)
        self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                             c.Container, object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(collections.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                             collections.defaultdict, dict, c.MutableMapping,
                             c.Mapping, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(collections.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict,
                             c.Mapping, c.Sized, c.Iterable, c.Container,
                             object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001452
1453 def test_register_abc(self):
1454 c = collections
1455 d = {"a": "b"}
1456 l = [1, 2, 3]
1457 s = {object(), None}
1458 f = frozenset(s)
1459 t = (1, 2, 3)
1460 @functools.singledispatch
1461 def g(obj):
1462 return "base"
1463 self.assertEqual(g(d), "base")
1464 self.assertEqual(g(l), "base")
1465 self.assertEqual(g(s), "base")
1466 self.assertEqual(g(f), "base")
1467 self.assertEqual(g(t), "base")
1468 g.register(c.Sized, lambda obj: "sized")
1469 self.assertEqual(g(d), "sized")
1470 self.assertEqual(g(l), "sized")
1471 self.assertEqual(g(s), "sized")
1472 self.assertEqual(g(f), "sized")
1473 self.assertEqual(g(t), "sized")
1474 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1475 self.assertEqual(g(d), "mutablemapping")
1476 self.assertEqual(g(l), "sized")
1477 self.assertEqual(g(s), "sized")
1478 self.assertEqual(g(f), "sized")
1479 self.assertEqual(g(t), "sized")
1480 g.register(c.ChainMap, lambda obj: "chainmap")
1481 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1482 self.assertEqual(g(l), "sized")
1483 self.assertEqual(g(s), "sized")
1484 self.assertEqual(g(f), "sized")
1485 self.assertEqual(g(t), "sized")
1486 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1487 self.assertEqual(g(d), "mutablemapping")
1488 self.assertEqual(g(l), "mutablesequence")
1489 self.assertEqual(g(s), "sized")
1490 self.assertEqual(g(f), "sized")
1491 self.assertEqual(g(t), "sized")
1492 g.register(c.MutableSet, lambda obj: "mutableset")
1493 self.assertEqual(g(d), "mutablemapping")
1494 self.assertEqual(g(l), "mutablesequence")
1495 self.assertEqual(g(s), "mutableset")
1496 self.assertEqual(g(f), "sized")
1497 self.assertEqual(g(t), "sized")
1498 g.register(c.Mapping, lambda obj: "mapping")
1499 self.assertEqual(g(d), "mutablemapping") # not specific enough
1500 self.assertEqual(g(l), "mutablesequence")
1501 self.assertEqual(g(s), "mutableset")
1502 self.assertEqual(g(f), "sized")
1503 self.assertEqual(g(t), "sized")
1504 g.register(c.Sequence, lambda obj: "sequence")
1505 self.assertEqual(g(d), "mutablemapping")
1506 self.assertEqual(g(l), "mutablesequence")
1507 self.assertEqual(g(s), "mutableset")
1508 self.assertEqual(g(f), "sized")
1509 self.assertEqual(g(t), "sequence")
1510 g.register(c.Set, lambda obj: "set")
1511 self.assertEqual(g(d), "mutablemapping")
1512 self.assertEqual(g(l), "mutablesequence")
1513 self.assertEqual(g(s), "mutableset")
1514 self.assertEqual(g(f), "set")
1515 self.assertEqual(g(t), "sequence")
1516 g.register(dict, lambda obj: "dict")
1517 self.assertEqual(g(d), "dict")
1518 self.assertEqual(g(l), "mutablesequence")
1519 self.assertEqual(g(s), "mutableset")
1520 self.assertEqual(g(f), "set")
1521 self.assertEqual(g(t), "sequence")
1522 g.register(list, lambda obj: "list")
1523 self.assertEqual(g(d), "dict")
1524 self.assertEqual(g(l), "list")
1525 self.assertEqual(g(s), "mutableset")
1526 self.assertEqual(g(f), "set")
1527 self.assertEqual(g(t), "sequence")
1528 g.register(set, lambda obj: "concrete-set")
1529 self.assertEqual(g(d), "dict")
1530 self.assertEqual(g(l), "list")
1531 self.assertEqual(g(s), "concrete-set")
1532 self.assertEqual(g(f), "set")
1533 self.assertEqual(g(t), "sequence")
1534 g.register(frozenset, lambda obj: "frozen-set")
1535 self.assertEqual(g(d), "dict")
1536 self.assertEqual(g(l), "list")
1537 self.assertEqual(g(s), "concrete-set")
1538 self.assertEqual(g(f), "frozen-set")
1539 self.assertEqual(g(t), "sequence")
1540 g.register(tuple, lambda obj: "tuple")
1541 self.assertEqual(g(d), "dict")
1542 self.assertEqual(g(l), "list")
1543 self.assertEqual(g(s), "concrete-set")
1544 self.assertEqual(g(f), "frozen-set")
1545 self.assertEqual(g(t), "tuple")
1546
Łukasz Langa3720c772013-07-01 16:00:38 +02001547 def test_c3_abc(self):
1548 c = collections
1549 mro = functools._c3_mro
1550 class A(object):
1551 pass
1552 class B(A):
1553 def __len__(self):
1554 return 0 # implies Sized
1555 @c.Container.register
1556 class C(object):
1557 pass
1558 class D(object):
1559 pass # unrelated
1560 class X(D, C, B):
1561 def __call__(self):
1562 pass # implies Callable
1563 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1564 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1565 self.assertEqual(mro(X, abcs=abcs), expected)
1566 # unrelated ABCs don't appear in the resulting MRO
1567 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1568 self.assertEqual(mro(X, abcs=many_abcs), expected)
1569
Yury Selivanov77a8cd62015-08-18 14:20:00 -04001570 def test_false_meta(self):
1571 # see issue23572
1572 class MetaA(type):
1573 def __len__(self):
1574 return 0
1575 class A(metaclass=MetaA):
1576 pass
1577 class AA(A):
1578 pass
1579 @functools.singledispatch
1580 def fun(a):
1581 return 'base A'
1582 @fun.register(A)
1583 def _(a):
1584 return 'fun A'
1585 aa = AA()
1586 self.assertEqual(fun(aa), 'fun A')
1587
Łukasz Langa6f692512013-06-05 12:20:24 +02001588 def test_mro_conflicts(self):
1589 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001590 @functools.singledispatch
1591 def g(arg):
1592 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001593 class O(c.Sized):
1594 def __len__(self):
1595 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001596 o = O()
1597 self.assertEqual(g(o), "base")
1598 g.register(c.Iterable, lambda arg: "iterable")
1599 g.register(c.Container, lambda arg: "container")
1600 g.register(c.Sized, lambda arg: "sized")
1601 g.register(c.Set, lambda arg: "set")
1602 self.assertEqual(g(o), "sized")
1603 c.Iterable.register(O)
1604 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1605 c.Container.register(O)
1606 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001607 c.Set.register(O)
1608 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1609 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001610 class P:
1611 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001612 p = P()
1613 self.assertEqual(g(p), "base")
1614 c.Iterable.register(P)
1615 self.assertEqual(g(p), "iterable")
1616 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001617 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001618 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001619 self.assertIn(
1620 str(re_one.exception),
1621 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1622 "or <class 'collections.abc.Iterable'>"),
1623 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1624 "or <class 'collections.abc.Container'>")),
1625 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001626 class Q(c.Sized):
1627 def __len__(self):
1628 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001629 q = Q()
1630 self.assertEqual(g(q), "sized")
1631 c.Iterable.register(Q)
1632 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1633 c.Set.register(Q)
1634 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001635 # c.Sized and c.Iterable
1636 @functools.singledispatch
1637 def h(arg):
1638 return "base"
1639 @h.register(c.Sized)
1640 def _(arg):
1641 return "sized"
1642 @h.register(c.Container)
1643 def _(arg):
1644 return "container"
1645 # Even though Sized and Container are explicit bases of MutableMapping,
1646 # this ABC is implicitly registered on defaultdict which makes all of
1647 # MutableMapping's bases implicit as well from defaultdict's
1648 # perspective.
1649 with self.assertRaises(RuntimeError) as re_two:
1650 h(c.defaultdict(lambda: 0))
1651 self.assertIn(
1652 str(re_two.exception),
1653 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1654 "or <class 'collections.abc.Sized'>"),
1655 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1656 "or <class 'collections.abc.Container'>")),
1657 )
1658 class R(c.defaultdict):
1659 pass
1660 c.MutableSequence.register(R)
1661 @functools.singledispatch
1662 def i(arg):
1663 return "base"
1664 @i.register(c.MutableMapping)
1665 def _(arg):
1666 return "mapping"
1667 @i.register(c.MutableSequence)
1668 def _(arg):
1669 return "sequence"
1670 r = R()
1671 self.assertEqual(i(r), "sequence")
1672 class S:
1673 pass
1674 class T(S, c.Sized):
1675 def __len__(self):
1676 return 0
1677 t = T()
1678 self.assertEqual(h(t), "sized")
1679 c.Container.register(T)
1680 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1681 class U:
1682 def __len__(self):
1683 return 0
1684 u = U()
1685 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1686 # from the existence of __len__()
1687 c.Container.register(U)
1688 # There is no preference for registered versus inferred ABCs.
1689 with self.assertRaises(RuntimeError) as re_three:
1690 h(u)
1691 self.assertIn(
1692 str(re_three.exception),
1693 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1694 "or <class 'collections.abc.Sized'>"),
1695 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1696 "or <class 'collections.abc.Container'>")),
1697 )
1698 class V(c.Sized, S):
1699 def __len__(self):
1700 return 0
1701 @functools.singledispatch
1702 def j(arg):
1703 return "base"
1704 @j.register(S)
1705 def _(arg):
1706 return "s"
1707 @j.register(c.Container)
1708 def _(arg):
1709 return "container"
1710 v = V()
1711 self.assertEqual(j(v), "s")
1712 c.Container.register(V)
1713 self.assertEqual(j(v), "container") # because it ends up right after
1714 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001715
def test_cache_invalidation(self):
    from collections import UserDict
    # ABCs come from collections.abc; their aliases on the collections
    # package itself are deprecated and were removed in Python 3.10.
    from collections import abc as c

    class TracingDict(UserDict):
        # Mapping that logs every key read (get_ops) and written (set_ops).
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            self.data.clear()

    td = TracingDict()
    # Point functools.WeakKeyDictionary at the tracing dict so the dispatch
    # cache created by @singledispatch below can be inspected through td.
    _orig_wkd = functools.WeakKeyDictionary
    functools.WeakKeyDictionary = lambda: td
    try:
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        lst = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(lst), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        self.assertEqual(g(lst), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(lst), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(lst), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)   # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(lst), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
    finally:
        # Restore the real factory even when an assertion above fails;
        # otherwise every later WeakKeyDictionary() call through functools
        # would keep returning this test's TracingDict.
        functools.WeakKeyDictionary = _orig_wkd
1816
1817
if __name__ == "__main__":
    # Run this module's test cases when it is executed as a script.
    unittest.main()