blob: fbb43e43e65813b86c41524be26ee4f28e869f6b [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Load two independent copies of functools so the tests can run against both
# implementations: one with the _functools C accelerator blocked (pure-Python
# code paths) and one with _functools imported fresh.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
# Used as a truth value by the skipUnless / `if c_functools:` guards on the
# C-specific test classes, i.e. expected to be falsy when the C accelerator
# cannot be imported.
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal loaded with its _decimal C module imported fresh.
# NOTE(review): not referenced in the portion of the file reviewed here --
# presumably used by later tests; confirm before removing.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo a call back as a ``(positional_tuple, keyword_dict)`` pair."""
    captured = (args, kw)
    return captured
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Describe partial object *part* as ``(func, args, keywords, __dict__)``."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Implementation-independent tests for ``partial``.

    This is a mixin: concrete subclasses also inherit from
    ``unittest.TestCase`` and set ``partial`` to the implementation under
    test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertions so a failure shows which part differs.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000133
Łukasz Langa6f692512013-06-05 12:20:24 +0200134
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C implementation, plus checks
    specific to it: read-only attributes, repr, pickling and __setstate__."""
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        # Only one keyword argument is used so the expected text cannot depend
        # on keyword ordering.  NOTE(review): a two-keyword variant used to sit
        # here commented out; confirm ordering is deterministic before
        # extending this.
        kwargs = {'a': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            # Subclasses are expected to show their own class name.
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.partial(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        # Instance __dict__ contents must survive the round trip as well.
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000207
Łukasz Langa6f692512013-06-05 12:20:24 +0200208
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""
    # staticmethod keeps the callable from being bound to the test instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000211
Łukasz Langa6f692512013-06-05 12:20:24 +0200212
if c_functools:
    # Trivial subclass of the C partial, used to re-run the C tests and
    # check that subclassing does not break anything.
    class PartialSubclass(c_functools.partial):
        ...
Antoine Pitrou33543272012-11-13 21:36:21 +0100216
Łukasz Langa6f692512013-06-05 12:20:24 +0200217
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Every C partial test, re-run against a trivial subclass."""
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000222
Łukasz Langa6f692512013-06-05 12:20:24 +0200223
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a descriptor on class ``A``."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # partialmethod wrapping another partialmethod
        nested = functools.partialmethod(positional, 5)

        # partialmethod wrapping a functools.partial object
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethod wrapping staticmethod / classmethod descriptors
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta as its metaclass.  Subclassing ABCMeta
        # (as an earlier version of this test did) turns Abstract into a
        # metaclass, so the ABC machinery was never exercised.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must count the partialmethod wrapper as abstract too,
        # which in turn blocks instantiation.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
336
337
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper.

    _default_update() builds the (wrapper, wrapped) pair checked by the
    default-update tests; subclasses may override it to reach the same
    machinery through a different entry point.
    """

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Each name in *assigned* must be the identical object on both; each
        key of every mapping named in *updated* must map to the identical
        object; and ``wrapper.__wrapped__`` must be *wrapped* itself.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            dest = getattr(wrapper, attr_name)
            src = getattr(wrapped, attr_name)
            for key in src:
                # __wrapped__ is overwritten by the update code, so the
                # wrapped object's own entry is not expected to survive.
                skip = attr_name == "__dict__" and key == "__wrapped__"
                if not skip:
                    self.assertIs(src[key], dest[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Return ``(wrapper, f)`` after a default update_wrapper() call.

        ``f`` carries an annotation, a docstring, a custom attribute and a
        bogus ``__wrapped__`` that must not end up on the wrapper.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        expected_attrs = (('__name__', 'f'),
                          ('__qualname__', f.__qualname__),
                          ('attr', 'This is also a test'))
        for attr_name, expected in expected_attrs:
            self.assertEqual(getattr(wrapper, attr_name), expected)
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated tuples: nothing but __wrapped__ is copied.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: both a missing
        # dict_attr and a non-mapping dict_attr are rejected.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000450
Łukasz Langa6f692512013-06-05 12:20:24 +0200451
class TestWraps(TestUpdateWrapper):
    """update_wrapper tests re-done through the functools.wraps decorator.

    Only the overridden tests go through wraps(); the remaining inherited
    tests run unchanged.
    """

    def _default_update(self):
        """Return ``(wrapper, f)`` where *wrapper* was decorated with wraps(f)."""
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        expected_attrs = (('__name__', 'f'),
                          ('__qualname__', f.__qualname__),
                          ('attr', 'This is also a test'))
        for attr_name, expected in expected_attrs:
            self.assertEqual(getattr(wrapper, attr_name), expected)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Empty assigned/updated tuples: nothing but __wrapped__ is copied.
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # wraps() needs the target mapping to exist before updating it.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
512
Łukasz Langa6f692512013-06-05 12:20:24 +0200513
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, accessed through the ``func`` attribute."""
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily computed squares 0, 1, 4, ... for indices below
            # ``limit``; iterable only via the legacy __getitem__ protocol.
            def __init__(self, limit):
                self.max = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                self.sofar.extend(k * k for k in range(len(self.sofar), i + 1))
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x*y

        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(self.func(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(self.func(mul, range(2, 8), 1), 5040)
        self.assertEqual(self.func(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)

        # Wrong arity or non-iterable sequence argument.
        for bad_args in [(), (42, 42), (42, 42, 42)]:
            self.assertRaises(TypeError, self.func, *bad_args)
        # func is never called with one item
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        # An empty sequence with no initial value is an error, as is a
        # non-iterable second argument.
        for seq in ([], "", (), object()):
            self.assertRaises(TypeError, self.func, add, seq)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        # With an empty sequence the initial value is returned as-is.
        self.assertIsNone(self.func(add, [], None))
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    def test_iterator_usage(self):
        # Test reduce()'s use of iterators on a __getitem__-only sequence.
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # Reducing a dict walks its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d))
593
Łukasz Langa6f692512013-06-05 12:20:24 +0200594
class TestCmpToKey:
    """Implementation-independent tests for ``cmp_to_key``.

    Mixin: concrete subclasses also inherit from ``unittest.TestCase`` and
    set ``cmp_to_key`` to the implementation under test.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                # too few args
        with self.assertRaises(TypeError):
            key(None, None)                      # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the cmp function's *result*
        # against zero also propagates.  The key is rebuilt from the new cmp
        # function (previously the old key was reused, so BadCmp was never
        # exercised).  "<" is used because BadCmp only defines __lt__:
        # key(3) < key(1) evaluates cmp2(3, 1) < 0, i.e. BadCmp().__lt__(0).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = self.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # The ABC lives in collections.abc; the collections.Hashable alias
        # is deprecated and was removed in Python 3.10.
        from collections.abc import Hashable

        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000676
Łukasz Langa6f692512013-06-05 12:20:24 +0200677
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C implementation."""

    # Bind the attribute only when the C module was importable; otherwise
    # the skipUnless decorator keeps this class from running at all.
    # Unlike TestCmpToKeyPy, no staticmethod wrapper is used here.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100682
Łukasz Langa6f692512013-06-05 12:20:24 +0200683
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""

    # staticmethod keeps the plain Python function from being bound to the
    # test instance, so ``self.cmp_to_key(f)`` receives ``f`` as mycmp.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000688class TestTotalOrdering(unittest.TestCase):
689
690 def test_total_ordering_lt(self):
691 @functools.total_ordering
692 class A:
693 def __init__(self, value):
694 self.value = value
695 def __lt__(self, other):
696 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000697 def __eq__(self, other):
698 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000699 self.assertTrue(A(1) < A(2))
700 self.assertTrue(A(2) > A(1))
701 self.assertTrue(A(1) <= A(2))
702 self.assertTrue(A(2) >= A(1))
703 self.assertTrue(A(2) <= A(2))
704 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000705 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000706
707 def test_total_ordering_le(self):
708 @functools.total_ordering
709 class A:
710 def __init__(self, value):
711 self.value = value
712 def __le__(self, other):
713 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000714 def __eq__(self, other):
715 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000716 self.assertTrue(A(1) < A(2))
717 self.assertTrue(A(2) > A(1))
718 self.assertTrue(A(1) <= A(2))
719 self.assertTrue(A(2) >= A(1))
720 self.assertTrue(A(2) <= A(2))
721 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000722 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000723
724 def test_total_ordering_gt(self):
725 @functools.total_ordering
726 class A:
727 def __init__(self, value):
728 self.value = value
729 def __gt__(self, other):
730 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000731 def __eq__(self, other):
732 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000733 self.assertTrue(A(1) < A(2))
734 self.assertTrue(A(2) > A(1))
735 self.assertTrue(A(1) <= A(2))
736 self.assertTrue(A(2) >= A(1))
737 self.assertTrue(A(2) <= A(2))
738 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000739 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000740
741 def test_total_ordering_ge(self):
742 @functools.total_ordering
743 class A:
744 def __init__(self, value):
745 self.value = value
746 def __ge__(self, other):
747 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000748 def __eq__(self, other):
749 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000750 self.assertTrue(A(1) < A(2))
751 self.assertTrue(A(2) > A(1))
752 self.assertTrue(A(1) <= A(2))
753 self.assertTrue(A(2) >= A(1))
754 self.assertTrue(A(2) <= A(2))
755 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000756 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000757
758 def test_total_ordering_no_overwrite(self):
759 # new methods should not overwrite existing
760 @functools.total_ordering
761 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000762 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000763 self.assertTrue(A(1) < A(2))
764 self.assertTrue(A(2) > A(1))
765 self.assertTrue(A(1) <= A(2))
766 self.assertTrue(A(2) >= A(1))
767 self.assertTrue(A(2) <= A(2))
768 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000769
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000770 def test_no_operations_defined(self):
771 with self.assertRaises(ValueError):
772 @functools.total_ordering
773 class A:
774 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000775
Nick Coghlanf05d9812013-10-02 00:02:03 +1000776 def test_type_error_when_not_implemented(self):
777 # bug 10042; ensure stack overflow does not occur
778 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000779 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000780 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000781 def __init__(self, value):
782 self.value = value
783 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000784 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000785 return self.value == other.value
786 return False
787 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000788 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000789 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000790 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000791
Nick Coghlanf05d9812013-10-02 00:02:03 +1000792 @functools.total_ordering
793 class ImplementsGreaterThan:
794 def __init__(self, value):
795 self.value = value
796 def __eq__(self, other):
797 if isinstance(other, ImplementsGreaterThan):
798 return self.value == other.value
799 return False
800 def __gt__(self, other):
801 if isinstance(other, ImplementsGreaterThan):
802 return self.value > other.value
803 return NotImplemented
804
805 @functools.total_ordering
806 class ImplementsLessThanEqualTo:
807 def __init__(self, value):
808 self.value = value
809 def __eq__(self, other):
810 if isinstance(other, ImplementsLessThanEqualTo):
811 return self.value == other.value
812 return False
813 def __le__(self, other):
814 if isinstance(other, ImplementsLessThanEqualTo):
815 return self.value <= other.value
816 return NotImplemented
817
818 @functools.total_ordering
819 class ImplementsGreaterThanEqualTo:
820 def __init__(self, value):
821 self.value = value
822 def __eq__(self, other):
823 if isinstance(other, ImplementsGreaterThanEqualTo):
824 return self.value == other.value
825 return False
826 def __ge__(self, other):
827 if isinstance(other, ImplementsGreaterThanEqualTo):
828 return self.value >= other.value
829 return NotImplemented
830
831 @functools.total_ordering
832 class ComparatorNotImplemented:
833 def __init__(self, value):
834 self.value = value
835 def __eq__(self, other):
836 if isinstance(other, ComparatorNotImplemented):
837 return self.value == other.value
838 return False
839 def __lt__(self, other):
840 return NotImplemented
841
842 with self.subTest("LT < 1"), self.assertRaises(TypeError):
843 ImplementsLessThan(-1) < 1
844
845 with self.subTest("LT < LE"), self.assertRaises(TypeError):
846 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
847
848 with self.subTest("LT < GT"), self.assertRaises(TypeError):
849 ImplementsLessThan(1) < ImplementsGreaterThan(1)
850
851 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
852 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
853
854 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
855 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
856
857 with self.subTest("GT > GE"), self.assertRaises(TypeError):
858 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
859
860 with self.subTest("GT > LT"), self.assertRaises(TypeError):
861 ImplementsGreaterThan(5) > ImplementsLessThan(5)
862
863 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
864 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
865
866 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
867 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
868
869 with self.subTest("GE when equal"):
870 a = ComparatorNotImplemented(8)
871 b = ComparatorNotImplemented(8)
872 self.assertEqual(a, b)
873 with self.assertRaises(TypeError):
874 a >= b
875
876 with self.subTest("LE when equal"):
877 a = ComparatorNotImplemented(9)
878 b = ComparatorNotImplemented(9)
879 self.assertEqual(a, b)
880 with self.assertRaises(TypeError):
881 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200882
Serhiy Storchaka697a5262015-01-01 15:23:12 +0200883 def test_pickle(self):
884 for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
885 for name in '__lt__', '__gt__', '__le__', '__ge__':
886 with self.subTest(method=name, proto=proto):
887 method = getattr(Orderable_LT, name)
888 method_copy = pickle.loads(pickle.dumps(method, proto))
889 self.assertIs(method_copy, method)
890
@functools.total_ordering
class Orderable_LT:
    """Value wrapper ordered via __lt__/__eq__ plus total_ordering.

    Defined at module level because TestTotalOrdering.test_pickle pickles
    its generated comparison methods, which needs an importable class.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
899
900
Georg Brandl2e7346a2010-07-31 18:09:23 +0000901class TestLRU(unittest.TestCase):
902
903 def test_lru(self):
904 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100905 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000906 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000907 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000908 self.assertEqual(maxsize, 20)
909 self.assertEqual(currsize, 0)
910 self.assertEqual(hits, 0)
911 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000912
913 domain = range(5)
914 for i in range(1000):
915 x, y = choice(domain), choice(domain)
916 actual = f(x, y)
917 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000918 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000919 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000920 self.assertTrue(hits > misses)
921 self.assertEqual(hits + misses, 1000)
922 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000923
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000924 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000925 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000926 self.assertEqual(hits, 0)
927 self.assertEqual(misses, 0)
928 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000929 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000930 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000931 self.assertEqual(hits, 0)
932 self.assertEqual(misses, 1)
933 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000934
Nick Coghlan98876832010-08-17 06:17:18 +0000935 # Test bypassing the cache
936 self.assertIs(f.__wrapped__, orig)
937 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000938 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000939 self.assertEqual(hits, 0)
940 self.assertEqual(misses, 1)
941 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000942
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000943 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000944 @functools.lru_cache(0)
945 def f():
946 nonlocal f_cnt
947 f_cnt += 1
948 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000949 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000950 f_cnt = 0
951 for i in range(5):
952 self.assertEqual(f(), 20)
953 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000954 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000955 self.assertEqual(hits, 0)
956 self.assertEqual(misses, 5)
957 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000958
959 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000960 @functools.lru_cache(1)
961 def f():
962 nonlocal f_cnt
963 f_cnt += 1
964 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000965 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000966 f_cnt = 0
967 for i in range(5):
968 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000969 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000970 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000971 self.assertEqual(hits, 4)
972 self.assertEqual(misses, 1)
973 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000974
Raymond Hettingerf3098282010-08-15 03:30:45 +0000975 # test size two
976 @functools.lru_cache(2)
977 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000978 nonlocal f_cnt
979 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000980 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000981 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000982 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000983 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
984 # * * * *
985 self.assertEqual(f(x), x*10)
986 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000987 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000988 self.assertEqual(hits, 12)
989 self.assertEqual(misses, 4)
990 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000991
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000992 def test_lru_with_maxsize_none(self):
993 @functools.lru_cache(maxsize=None)
994 def fib(n):
995 if n < 2:
996 return n
997 return fib(n-1) + fib(n-2)
998 self.assertEqual([fib(n) for n in range(16)],
999 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1000 self.assertEqual(fib.cache_info(),
1001 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1002 fib.cache_clear()
1003 self.assertEqual(fib.cache_info(),
1004 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1005
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001006 def test_lru_with_exceptions(self):
1007 # Verify that user_function exceptions get passed through without
1008 # creating a hard-to-read chained exception.
1009 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001010 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001011 @functools.lru_cache(maxsize)
1012 def func(i):
1013 return 'abc'[i]
1014 self.assertEqual(func(0), 'a')
1015 with self.assertRaises(IndexError) as cm:
1016 func(15)
1017 self.assertIsNone(cm.exception.__context__)
1018 # Verify that the previous exception did not result in a cached entry
1019 with self.assertRaises(IndexError):
1020 func(15)
1021
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001022 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001023 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001024 @functools.lru_cache(maxsize=maxsize, typed=True)
1025 def square(x):
1026 return x * x
1027 self.assertEqual(square(3), 9)
1028 self.assertEqual(type(square(3)), type(9))
1029 self.assertEqual(square(3.0), 9.0)
1030 self.assertEqual(type(square(3.0)), type(9.0))
1031 self.assertEqual(square(x=3), 9)
1032 self.assertEqual(type(square(x=3)), type(9))
1033 self.assertEqual(square(x=3.0), 9.0)
1034 self.assertEqual(type(square(x=3.0)), type(9.0))
1035 self.assertEqual(square.cache_info().hits, 4)
1036 self.assertEqual(square.cache_info().misses, 4)
1037
Antoine Pitroub5b37142012-11-13 21:35:40 +01001038 def test_lru_with_keyword_args(self):
1039 @functools.lru_cache()
1040 def fib(n):
1041 if n < 2:
1042 return n
1043 return fib(n=n-1) + fib(n=n-2)
1044 self.assertEqual(
1045 [fib(n=number) for number in range(16)],
1046 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1047 )
1048 self.assertEqual(fib.cache_info(),
1049 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1050 fib.cache_clear()
1051 self.assertEqual(fib.cache_info(),
1052 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1053
1054 def test_lru_with_keyword_args_maxsize_none(self):
1055 @functools.lru_cache(maxsize=None)
1056 def fib(n):
1057 if n < 2:
1058 return n
1059 return fib(n=n-1) + fib(n=n-2)
1060 self.assertEqual([fib(n=number) for number in range(16)],
1061 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1062 self.assertEqual(fib.cache_info(),
1063 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1064 fib.cache_clear()
1065 self.assertEqual(fib.cache_info(),
1066 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1067
Raymond Hettinger03923422013-03-04 02:52:50 -05001068 def test_need_for_rlock(self):
1069 # This will deadlock on an LRU cache that uses a regular lock
1070
1071 @functools.lru_cache(maxsize=10)
1072 def test_func(x):
1073 'Used to demonstrate a reentrant lru_cache call within a single thread'
1074 return x
1075
1076 class DoubleEq:
1077 'Demonstrate a reentrant lru_cache call within a single thread'
1078 def __init__(self, x):
1079 self.x = x
1080 def __hash__(self):
1081 return self.x
1082 def __eq__(self, other):
1083 if self.x == 2:
1084 test_func(DoubleEq(1))
1085 return self.x == other.x
1086
1087 test_func(DoubleEq(1)) # Load the cache
1088 test_func(DoubleEq(2)) # Load the cache
1089 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1090 DoubleEq(2)) # Verify the correct return value
1091
Raymond Hettinger4d588972014-08-12 12:44:52 -07001092 def test_early_detection_of_bad_call(self):
1093 # Issue #22184
1094 with self.assertRaises(TypeError):
1095 @functools.lru_cache
1096 def f():
1097 pass
1098
Raymond Hettinger03923422013-03-04 02:52:50 -05001099
Łukasz Langa6f692512013-06-05 12:20:24 +02001100class TestSingleDispatch(unittest.TestCase):
1101 def test_simple_overloads(self):
1102 @functools.singledispatch
1103 def g(obj):
1104 return "base"
1105 def g_int(i):
1106 return "integer"
1107 g.register(int, g_int)
1108 self.assertEqual(g("str"), "base")
1109 self.assertEqual(g(1), "integer")
1110 self.assertEqual(g([1,2,3]), "base")
1111
1112 def test_mro(self):
1113 @functools.singledispatch
1114 def g(obj):
1115 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001116 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001117 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001118 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001119 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001120 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001121 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001122 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001123 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001124 def g_A(a):
1125 return "A"
1126 def g_B(b):
1127 return "B"
1128 g.register(A, g_A)
1129 g.register(B, g_B)
1130 self.assertEqual(g(A()), "A")
1131 self.assertEqual(g(B()), "B")
1132 self.assertEqual(g(C()), "A")
1133 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001134
1135 def test_register_decorator(self):
1136 @functools.singledispatch
1137 def g(obj):
1138 return "base"
1139 @g.register(int)
1140 def g_int(i):
1141 return "int %s" % (i,)
1142 self.assertEqual(g(""), "base")
1143 self.assertEqual(g(12), "int 12")
1144 self.assertIs(g.dispatch(int), g_int)
1145 self.assertIs(g.dispatch(object), g.dispatch(str))
1146 # Note: in the assert above this is not g.
1147 # @singledispatch returns the wrapper.
1148
1149 def test_wrapping_attributes(self):
1150 @functools.singledispatch
1151 def g(obj):
1152 "Simple test"
1153 return "Test"
1154 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001155 if sys.flags.optimize < 2:
1156 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001157
1158 @unittest.skipUnless(decimal, 'requires _decimal')
1159 @support.cpython_only
1160 def test_c_classes(self):
1161 @functools.singledispatch
1162 def g(obj):
1163 return "base"
1164 @g.register(decimal.DecimalException)
1165 def _(obj):
1166 return obj.args
1167 subn = decimal.Subnormal("Exponent < Emin")
1168 rnd = decimal.Rounded("Number got rounded")
1169 self.assertEqual(g(subn), ("Exponent < Emin",))
1170 self.assertEqual(g(rnd), ("Number got rounded",))
1171 @g.register(decimal.Subnormal)
1172 def _(obj):
1173 return "Too small to care."
1174 self.assertEqual(g(subn), "Too small to care.")
1175 self.assertEqual(g(rnd), ("Number got rounded",))
1176
1177 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001178 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001179 c = collections
1180 mro = functools._compose_mro
1181 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1182 for haystack in permutations(bases):
1183 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001184 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1185 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001186 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1187 for haystack in permutations(bases):
1188 m = mro(c.ChainMap, haystack)
1189 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1190 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001191
1192 # If there's a generic function with implementations registered for
1193 # both Sized and Container, passing a defaultdict to it results in an
1194 # ambiguous dispatch which will cause a RuntimeError (see
1195 # test_mro_conflicts).
1196 bases = [c.Container, c.Sized, str]
1197 for haystack in permutations(bases):
1198 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1199 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1200 object])
1201
1202 # MutableSequence below is registered directly on D. In other words, it
1203 # preceeds MutableMapping which means single dispatch will always
1204 # choose MutableSequence here.
1205 class D(c.defaultdict):
1206 pass
1207 c.MutableSequence.register(D)
1208 bases = [c.MutableSequence, c.MutableMapping]
1209 for haystack in permutations(bases):
1210 m = mro(D, bases)
1211 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1212 c.defaultdict, dict, c.MutableMapping,
1213 c.Mapping, c.Sized, c.Iterable, c.Container,
1214 object])
1215
1216 # Container and Callable are registered on different base classes and
1217 # a generic function supporting both should always pick the Callable
1218 # implementation if a C instance is passed.
1219 class C(c.defaultdict):
1220 def __call__(self):
1221 pass
1222 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1223 for haystack in permutations(bases):
1224 m = mro(C, haystack)
1225 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1226 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001227
1228 def test_register_abc(self):
1229 c = collections
1230 d = {"a": "b"}
1231 l = [1, 2, 3]
1232 s = {object(), None}
1233 f = frozenset(s)
1234 t = (1, 2, 3)
1235 @functools.singledispatch
1236 def g(obj):
1237 return "base"
1238 self.assertEqual(g(d), "base")
1239 self.assertEqual(g(l), "base")
1240 self.assertEqual(g(s), "base")
1241 self.assertEqual(g(f), "base")
1242 self.assertEqual(g(t), "base")
1243 g.register(c.Sized, lambda obj: "sized")
1244 self.assertEqual(g(d), "sized")
1245 self.assertEqual(g(l), "sized")
1246 self.assertEqual(g(s), "sized")
1247 self.assertEqual(g(f), "sized")
1248 self.assertEqual(g(t), "sized")
1249 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1250 self.assertEqual(g(d), "mutablemapping")
1251 self.assertEqual(g(l), "sized")
1252 self.assertEqual(g(s), "sized")
1253 self.assertEqual(g(f), "sized")
1254 self.assertEqual(g(t), "sized")
1255 g.register(c.ChainMap, lambda obj: "chainmap")
1256 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1257 self.assertEqual(g(l), "sized")
1258 self.assertEqual(g(s), "sized")
1259 self.assertEqual(g(f), "sized")
1260 self.assertEqual(g(t), "sized")
1261 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1262 self.assertEqual(g(d), "mutablemapping")
1263 self.assertEqual(g(l), "mutablesequence")
1264 self.assertEqual(g(s), "sized")
1265 self.assertEqual(g(f), "sized")
1266 self.assertEqual(g(t), "sized")
1267 g.register(c.MutableSet, lambda obj: "mutableset")
1268 self.assertEqual(g(d), "mutablemapping")
1269 self.assertEqual(g(l), "mutablesequence")
1270 self.assertEqual(g(s), "mutableset")
1271 self.assertEqual(g(f), "sized")
1272 self.assertEqual(g(t), "sized")
1273 g.register(c.Mapping, lambda obj: "mapping")
1274 self.assertEqual(g(d), "mutablemapping") # not specific enough
1275 self.assertEqual(g(l), "mutablesequence")
1276 self.assertEqual(g(s), "mutableset")
1277 self.assertEqual(g(f), "sized")
1278 self.assertEqual(g(t), "sized")
1279 g.register(c.Sequence, lambda obj: "sequence")
1280 self.assertEqual(g(d), "mutablemapping")
1281 self.assertEqual(g(l), "mutablesequence")
1282 self.assertEqual(g(s), "mutableset")
1283 self.assertEqual(g(f), "sized")
1284 self.assertEqual(g(t), "sequence")
1285 g.register(c.Set, lambda obj: "set")
1286 self.assertEqual(g(d), "mutablemapping")
1287 self.assertEqual(g(l), "mutablesequence")
1288 self.assertEqual(g(s), "mutableset")
1289 self.assertEqual(g(f), "set")
1290 self.assertEqual(g(t), "sequence")
1291 g.register(dict, lambda obj: "dict")
1292 self.assertEqual(g(d), "dict")
1293 self.assertEqual(g(l), "mutablesequence")
1294 self.assertEqual(g(s), "mutableset")
1295 self.assertEqual(g(f), "set")
1296 self.assertEqual(g(t), "sequence")
1297 g.register(list, lambda obj: "list")
1298 self.assertEqual(g(d), "dict")
1299 self.assertEqual(g(l), "list")
1300 self.assertEqual(g(s), "mutableset")
1301 self.assertEqual(g(f), "set")
1302 self.assertEqual(g(t), "sequence")
1303 g.register(set, lambda obj: "concrete-set")
1304 self.assertEqual(g(d), "dict")
1305 self.assertEqual(g(l), "list")
1306 self.assertEqual(g(s), "concrete-set")
1307 self.assertEqual(g(f), "set")
1308 self.assertEqual(g(t), "sequence")
1309 g.register(frozenset, lambda obj: "frozen-set")
1310 self.assertEqual(g(d), "dict")
1311 self.assertEqual(g(l), "list")
1312 self.assertEqual(g(s), "concrete-set")
1313 self.assertEqual(g(f), "frozen-set")
1314 self.assertEqual(g(t), "sequence")
1315 g.register(tuple, lambda obj: "tuple")
1316 self.assertEqual(g(d), "dict")
1317 self.assertEqual(g(l), "list")
1318 self.assertEqual(g(s), "concrete-set")
1319 self.assertEqual(g(f), "frozen-set")
1320 self.assertEqual(g(t), "tuple")
1321
Łukasz Langa3720c772013-07-01 16:00:38 +02001322 def test_c3_abc(self):
1323 c = collections
1324 mro = functools._c3_mro
1325 class A(object):
1326 pass
1327 class B(A):
1328 def __len__(self):
1329 return 0 # implies Sized
1330 @c.Container.register
1331 class C(object):
1332 pass
1333 class D(object):
1334 pass # unrelated
1335 class X(D, C, B):
1336 def __call__(self):
1337 pass # implies Callable
1338 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1339 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1340 self.assertEqual(mro(X, abcs=abcs), expected)
1341 # unrelated ABCs don't appear in the resulting MRO
1342 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1343 self.assertEqual(mro(X, abcs=many_abcs), expected)
1344
Łukasz Langa6f692512013-06-05 12:20:24 +02001345 def test_mro_conflicts(self):
1346 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001347 @functools.singledispatch
1348 def g(arg):
1349 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001350 class O(c.Sized):
1351 def __len__(self):
1352 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001353 o = O()
1354 self.assertEqual(g(o), "base")
1355 g.register(c.Iterable, lambda arg: "iterable")
1356 g.register(c.Container, lambda arg: "container")
1357 g.register(c.Sized, lambda arg: "sized")
1358 g.register(c.Set, lambda arg: "set")
1359 self.assertEqual(g(o), "sized")
1360 c.Iterable.register(O)
1361 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1362 c.Container.register(O)
1363 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001364 c.Set.register(O)
1365 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1366 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001367 class P:
1368 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001369 p = P()
1370 self.assertEqual(g(p), "base")
1371 c.Iterable.register(P)
1372 self.assertEqual(g(p), "iterable")
1373 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001374 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001375 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001376 self.assertIn(
1377 str(re_one.exception),
1378 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1379 "or <class 'collections.abc.Iterable'>"),
1380 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1381 "or <class 'collections.abc.Container'>")),
1382 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001383 class Q(c.Sized):
1384 def __len__(self):
1385 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001386 q = Q()
1387 self.assertEqual(g(q), "sized")
1388 c.Iterable.register(Q)
1389 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1390 c.Set.register(Q)
1391 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001392 # c.Sized and c.Iterable
1393 @functools.singledispatch
1394 def h(arg):
1395 return "base"
1396 @h.register(c.Sized)
1397 def _(arg):
1398 return "sized"
1399 @h.register(c.Container)
1400 def _(arg):
1401 return "container"
1402 # Even though Sized and Container are explicit bases of MutableMapping,
1403 # this ABC is implicitly registered on defaultdict which makes all of
1404 # MutableMapping's bases implicit as well from defaultdict's
1405 # perspective.
1406 with self.assertRaises(RuntimeError) as re_two:
1407 h(c.defaultdict(lambda: 0))
1408 self.assertIn(
1409 str(re_two.exception),
1410 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1411 "or <class 'collections.abc.Sized'>"),
1412 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1413 "or <class 'collections.abc.Container'>")),
1414 )
1415 class R(c.defaultdict):
1416 pass
1417 c.MutableSequence.register(R)
1418 @functools.singledispatch
1419 def i(arg):
1420 return "base"
1421 @i.register(c.MutableMapping)
1422 def _(arg):
1423 return "mapping"
1424 @i.register(c.MutableSequence)
1425 def _(arg):
1426 return "sequence"
1427 r = R()
1428 self.assertEqual(i(r), "sequence")
1429 class S:
1430 pass
1431 class T(S, c.Sized):
1432 def __len__(self):
1433 return 0
1434 t = T()
1435 self.assertEqual(h(t), "sized")
1436 c.Container.register(T)
1437 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1438 class U:
1439 def __len__(self):
1440 return 0
1441 u = U()
1442 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1443 # from the existence of __len__()
1444 c.Container.register(U)
1445 # There is no preference for registered versus inferred ABCs.
1446 with self.assertRaises(RuntimeError) as re_three:
1447 h(u)
1448 self.assertIn(
1449 str(re_three.exception),
1450 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1451 "or <class 'collections.abc.Sized'>"),
1452 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1453 "or <class 'collections.abc.Container'>")),
1454 )
1455 class V(c.Sized, S):
1456 def __len__(self):
1457 return 0
1458 @functools.singledispatch
1459 def j(arg):
1460 return "base"
1461 @j.register(S)
1462 def _(arg):
1463 return "s"
1464 @j.register(c.Container)
1465 def _(arg):
1466 return "container"
1467 v = V()
1468 self.assertEqual(j(v), "s")
1469 c.Container.register(V)
1470 self.assertEqual(j(v), "container") # because it ends up right after
1471 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001472
1473 def test_cache_invalidation(self):
1474 from collections import UserDict
1475 class TracingDict(UserDict):
1476 def __init__(self, *args, **kwargs):
1477 super(TracingDict, self).__init__(*args, **kwargs)
1478 self.set_ops = []
1479 self.get_ops = []
1480 def __getitem__(self, key):
1481 result = self.data[key]
1482 self.get_ops.append(key)
1483 return result
1484 def __setitem__(self, key, value):
1485 self.set_ops.append(key)
1486 self.data[key] = value
1487 def clear(self):
1488 self.data.clear()
1489 _orig_wkd = functools.WeakKeyDictionary
1490 td = TracingDict()
1491 functools.WeakKeyDictionary = lambda: td
1492 c = collections
1493 @functools.singledispatch
1494 def g(arg):
1495 return "base"
1496 d = {}
1497 l = []
1498 self.assertEqual(len(td), 0)
1499 self.assertEqual(g(d), "base")
1500 self.assertEqual(len(td), 1)
1501 self.assertEqual(td.get_ops, [])
1502 self.assertEqual(td.set_ops, [dict])
1503 self.assertEqual(td.data[dict], g.registry[object])
1504 self.assertEqual(g(l), "base")
1505 self.assertEqual(len(td), 2)
1506 self.assertEqual(td.get_ops, [])
1507 self.assertEqual(td.set_ops, [dict, list])
1508 self.assertEqual(td.data[dict], g.registry[object])
1509 self.assertEqual(td.data[list], g.registry[object])
1510 self.assertEqual(td.data[dict], td.data[list])
1511 self.assertEqual(g(l), "base")
1512 self.assertEqual(g(d), "base")
1513 self.assertEqual(td.get_ops, [list, dict])
1514 self.assertEqual(td.set_ops, [dict, list])
1515 g.register(list, lambda arg: "list")
1516 self.assertEqual(td.get_ops, [list, dict])
1517 self.assertEqual(len(td), 0)
1518 self.assertEqual(g(d), "base")
1519 self.assertEqual(len(td), 1)
1520 self.assertEqual(td.get_ops, [list, dict])
1521 self.assertEqual(td.set_ops, [dict, list, dict])
1522 self.assertEqual(td.data[dict],
1523 functools._find_impl(dict, g.registry))
1524 self.assertEqual(g(l), "list")
1525 self.assertEqual(len(td), 2)
1526 self.assertEqual(td.get_ops, [list, dict])
1527 self.assertEqual(td.set_ops, [dict, list, dict, list])
1528 self.assertEqual(td.data[list],
1529 functools._find_impl(list, g.registry))
1530 class X:
1531 pass
1532 c.MutableMapping.register(X) # Will not invalidate the cache,
1533 # not using ABCs yet.
1534 self.assertEqual(g(d), "base")
1535 self.assertEqual(g(l), "list")
1536 self.assertEqual(td.get_ops, [list, dict, dict, list])
1537 self.assertEqual(td.set_ops, [dict, list, dict, list])
1538 g.register(c.Sized, lambda arg: "sized")
1539 self.assertEqual(len(td), 0)
1540 self.assertEqual(g(d), "sized")
1541 self.assertEqual(len(td), 1)
1542 self.assertEqual(td.get_ops, [list, dict, dict, list])
1543 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1544 self.assertEqual(g(l), "list")
1545 self.assertEqual(len(td), 2)
1546 self.assertEqual(td.get_ops, [list, dict, dict, list])
1547 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1548 self.assertEqual(g(l), "list")
1549 self.assertEqual(g(d), "sized")
1550 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1551 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1552 g.dispatch(list)
1553 g.dispatch(dict)
1554 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1555 list, dict])
1556 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1557 c.MutableSet.register(X) # Will invalidate the cache.
1558 self.assertEqual(len(td), 2) # Stale cache.
1559 self.assertEqual(g(l), "list")
1560 self.assertEqual(len(td), 1)
1561 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1562 self.assertEqual(len(td), 0)
1563 self.assertEqual(g(d), "mutablemapping")
1564 self.assertEqual(len(td), 1)
1565 self.assertEqual(g(l), "list")
1566 self.assertEqual(len(td), 2)
1567 g.register(dict, lambda arg: "dict")
1568 self.assertEqual(g(d), "dict")
1569 self.assertEqual(g(l), "list")
1570 g._clear_cache()
1571 self.assertEqual(len(td), 0)
1572 functools.WeakKeyDictionary = _orig_wkd
1573
1574
def test_main(verbose=None):
    """Run every test class of this module through support.run_unittest.

    When *verbose* is true and the interpreter provides
    sys.gettotalrefcount() (CPython debug builds), the whole suite is run
    five more times and the total reference count sampled after each run is
    printed; a steadily growing sequence points at a reference leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            # Collect cyclic garbage before sampling so it does not
            # inflate the reference count.
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001601
if __name__ == '__main__':
    # Executed directly as a script (not collected by a test runner): run the
    # suite with verbose=True, which also turns on test_main's repeated-run
    # reference-count report where sys.gettotalrefcount is available.
    test_main(verbose=True)