blob: 38c9713bf37aae9c59f542a15de19fc48f540e0c [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Two independent copies of functools are loaded so that each test mixin can
# be run against both implementations:
#   * py_functools has the _functools accelerator blocked (pure Python);
#   * c_functools is backed by a freshly imported _functools.  It is falsy
#     when the accelerator is unavailable, which is what the
#     skipUnless(c_functools, ...) guards on the C test classes rely on.
py_functools = support.import_fresh_module(
    'functools', blocked=['_functools'])
c_functools = support.import_fresh_module(
    'functools', fresh=['_functools'])

# A fresh decimal module backed by the _decimal accelerator.
decimal = support.import_fresh_module(
    'decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo the call's arguments back as an ``(args, kwargs)`` pair.

    Serves as the wrapped callable in the partial tests, so the exact
    positional and keyword arguments forwarded by a partial can be compared.
    """
    captured = (args, kw)
    return captured
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Return a comparable snapshot of a partial object's state.

    The snapshot is ``(func, args, keywords, __dict__)``: the wrapped
    callable, the frozen positional and keyword arguments, and the
    instance attribute dictionary.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Behaviour shared by every partial implementation under test.

    Mixin: concrete subclasses also derive from unittest.TestCase and set
    ``partial`` to the callable being tested.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # The read-only checks below are skipped when the implementation
        # under test is not a type.
        if not isinstance(self.partial, type):
            return
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is created or when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # assertEqual (rather than assertTrue on a combined
                # condition) reports the mismatching values on failure.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000147
Łukasz Langa6f692512013-06-05 12:20:24 +0200148
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C implementation, plus checks
    of its repr, pickling and __setstate__."""

    if c_functools:
        partial = c_functools.partial

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(map(repr, args))
        # NOTE(review): a two-keyword variant was commented out upstream,
        # presumably because keyword order in the repr was not guaranteed;
        # only a single keyword is used here.
        kwargs = {'a': object()}
        kwargs_repr = ', '.join('%s=%r' % item for item in kwargs.items())
        name = ('functools.partial' if self.partial is c_functools.partial
                else self.partial.__name__)

        cases = [
            ((), {}, '{}({!r})'.format(name, capture)),
            (args, {}, '{}({!r}, {})'.format(name, capture, args_repr)),
            ((), kwargs, '{}({!r}, {})'.format(name, capture, kwargs_repr)),
            (args, kwargs,
             '{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)),
        ]
        for call_args, call_kwargs, expected in cases:
            f = self.partial(capture, *call_args, **call_kwargs)
            self.assertEqual(expected, repr(f))

    def test_pickle(self):
        original = self.partial(signature, 'asdf', bar=True)
        original.add_something_to__dict__ = True
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(signature(original), signature(restored))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Has a length and items like a 4-element state, but is not a
            # tuple; __setstate__ is expected to reject it with SystemError.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                if key == 1:
                    return tuple(range(1000000))
                if key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(
            SystemError,
            "new style getargs format but argument is not a tuple",
            f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000205
Łukasz Langa6f692512013-06-05 12:20:24 +0200206
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python functools."""

    # staticmethod stops a function-based implementation from being bound
    # as a method when it is looked up through self.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000209
Łukasz Langa6f692512013-06-05 12:20:24 +0200210
if c_functools:
    # Trivial Python-level subclass of the C partial type; it is only
    # defined when the C module could be imported.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100214
Łukasz Langa6f692512013-06-05 12:20:24 +0200215
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Re-run every C partial test through a Python subclass of the C type."""

    # Guarded like the decorator: PartialSubclass only exists when the C
    # module is available.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000220
Łukasz Langa6f692512013-06-05 12:20:24 +0200221
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must *use* ABCMeta as its metaclass.  Subclassing ABCMeta
        # would instead make Abstract itself a metaclass, and the
        # abstract-method bookkeeping would never be exercised.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must recognise the partialmethod wrapper as abstract too,
        # and the class must refuse to be instantiated.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
334
335
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each name in *assigned* must refer to the identical object on both.
        Each mapping named in *updated* must contain, on the wrapper, the
        identical value for every key of the wrapped mapping (except
        ``__dict__['__wrapped__']``).  ``wrapper.__wrapped__`` must be
        *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            wrapper_map = getattr(wrapper, attr_name)
            wrapped_map = getattr(wrapped, attr_name)
            for key in wrapped_map:
                if attr_name == "__dict__" and key == "__wrapped__":
                    # update_wrapper sets __wrapped__ itself, replacing the
                    # value copied over from the wrapped __dict__.
                    continue
                self.assertIs(wrapped_map[key], wrapper_map[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Wrap an annotated, documented function with default settings.

        Returns ``(wrapper, wrapped)``.
        """
        def f(a: 'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertIs(wrapper.__wrapped__, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        # With nothing assigned or updated, the wrapper keeps its own
        # identity and gains none of f's attributes.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        to_assign = ('attr',)
        to_update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        to_assign = ('attr',)
        to_update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating: a missing
        # attribute, or an int in place of a mapping, raises AttributeError.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000448
Łukasz Langa6f692512013-06-05 12:20:24 +0200449
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a documented function with wraps(); return (wrapper, wrapped)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, wrapped.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        # Nothing assigned or updated: the wrapper keeps its own identity.
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the wrapper the empty mapping that wraps() will update.
            func.dict_attr = {}
            return func

        to_assign = ('attr',)
        to_update = ('dict_attr',)

        @functools.wraps(f, to_assign, to_update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
510
Łukasz Langa6f692512013-06-05 12:20:24 +0200511
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the squares of 0, 1, 2, ... (indices below max),
            # computed on demand; __len__ reports only how many squares
            # have been computed so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func

        # Ordinary reductions, with and without an initial value.
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        # Wrong arity, a non-callable func that must be called, a
        # non-iterable sequence, and an empty sequence with no initial
        # value are all TypeErrors.
        for bad_args in [(), (42, 42), (42, 42, 42), (42, (42, 42)),
                         (add, []), (add, ""), (add, ()), (add, object())]:
            self.assertRaises(TypeError, reduce, *bad_args)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        # An empty sequence yields the initial value unchanged.
        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: yields 0 .. n-1, then IndexError.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, words), "".join(words.keys()))
591
Łukasz Langa6f692512013-06-05 12:20:24 +0200592
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200593class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700594
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000595 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700596 def cmp1(x, y):
597 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100598 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700599 self.assertEqual(key(3), key(3))
600 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100601 self.assertGreaterEqual(key(3), key(3))
602
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700603 def cmp2(x, y):
604 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100605 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700606 self.assertEqual(key(4.0), key('4'))
607 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100608 self.assertLessEqual(key(2), key('35'))
609 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700610
611 def test_cmp_to_key_arguments(self):
612 def cmp1(x, y):
613 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100614 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700615 self.assertEqual(key(obj=3), key(obj=3))
616 self.assertGreater(key(obj=3), key(obj=1))
617 with self.assertRaises((TypeError, AttributeError)):
618 key(3) > 1 # rhs is not a K object
619 with self.assertRaises((TypeError, AttributeError)):
620 1 < key(3) # lhs is not a K object
621 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100622 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700623 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200624 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100625 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700626 with self.assertRaises(TypeError):
627 key() # too few args
628 with self.assertRaises(TypeError):
629 key(None, None) # too many args
630
631 def test_bad_cmp(self):
632 def cmp1(x, y):
633 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100634 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700635 with self.assertRaises(ZeroDivisionError):
636 key(3) > key(1)
637
638 class BadCmp:
639 def __lt__(self, other):
640 raise ZeroDivisionError
641 def cmp1(x, y):
642 return BadCmp()
643 with self.assertRaises(ZeroDivisionError):
644 key(3) > key(1)
645
646 def test_obj_field(self):
647 def cmp1(x, y):
648 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100649 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700650 self.assertEqual(key(50).obj, 50)
651
652 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000653 def mycmp(x, y):
654 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100655 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000656 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000657
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700658 def test_sort_int_str(self):
659 def mycmp(x, y):
660 x, y = int(x), int(y)
661 return (x > y) - (x < y)
662 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100663 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700664 self.assertEqual([int(value) for value in values],
665 [0, 1, 1, 2, 3, 4, 5, 7, 10])
666
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000667 def test_hash(self):
668 def mycmp(x, y):
669 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100670 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000671 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700672 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700673 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000674
Łukasz Langa6f692512013-06-05 12:20:24 +0200675
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200676@unittest.skipUnless(c_functools, 'requires the C _functools module')
677class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
678 if c_functools:
679 cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100680
Łukasz Langa6f692512013-06-05 12:20:24 +0200681
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200682class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100683 cmp_to_key = staticmethod(py_functools.cmp_to_key)
684
Łukasz Langa6f692512013-06-05 12:20:24 +0200685
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000686class TestTotalOrdering(unittest.TestCase):
687
688 def test_total_ordering_lt(self):
689 @functools.total_ordering
690 class A:
691 def __init__(self, value):
692 self.value = value
693 def __lt__(self, other):
694 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000695 def __eq__(self, other):
696 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000697 self.assertTrue(A(1) < A(2))
698 self.assertTrue(A(2) > A(1))
699 self.assertTrue(A(1) <= A(2))
700 self.assertTrue(A(2) >= A(1))
701 self.assertTrue(A(2) <= A(2))
702 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000703 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000704
705 def test_total_ordering_le(self):
706 @functools.total_ordering
707 class A:
708 def __init__(self, value):
709 self.value = value
710 def __le__(self, other):
711 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000712 def __eq__(self, other):
713 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000714 self.assertTrue(A(1) < A(2))
715 self.assertTrue(A(2) > A(1))
716 self.assertTrue(A(1) <= A(2))
717 self.assertTrue(A(2) >= A(1))
718 self.assertTrue(A(2) <= A(2))
719 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000720 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000721
722 def test_total_ordering_gt(self):
723 @functools.total_ordering
724 class A:
725 def __init__(self, value):
726 self.value = value
727 def __gt__(self, other):
728 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000729 def __eq__(self, other):
730 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000731 self.assertTrue(A(1) < A(2))
732 self.assertTrue(A(2) > A(1))
733 self.assertTrue(A(1) <= A(2))
734 self.assertTrue(A(2) >= A(1))
735 self.assertTrue(A(2) <= A(2))
736 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000737 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000738
739 def test_total_ordering_ge(self):
740 @functools.total_ordering
741 class A:
742 def __init__(self, value):
743 self.value = value
744 def __ge__(self, other):
745 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000746 def __eq__(self, other):
747 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000748 self.assertTrue(A(1) < A(2))
749 self.assertTrue(A(2) > A(1))
750 self.assertTrue(A(1) <= A(2))
751 self.assertTrue(A(2) >= A(1))
752 self.assertTrue(A(2) <= A(2))
753 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000754 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000755
756 def test_total_ordering_no_overwrite(self):
757 # new methods should not overwrite existing
758 @functools.total_ordering
759 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000760 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000761 self.assertTrue(A(1) < A(2))
762 self.assertTrue(A(2) > A(1))
763 self.assertTrue(A(1) <= A(2))
764 self.assertTrue(A(2) >= A(1))
765 self.assertTrue(A(2) <= A(2))
766 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000767
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000768 def test_no_operations_defined(self):
769 with self.assertRaises(ValueError):
770 @functools.total_ordering
771 class A:
772 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000773
Nick Coghlanf05d9812013-10-02 00:02:03 +1000774 def test_type_error_when_not_implemented(self):
775 # bug 10042; ensure stack overflow does not occur
776 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000777 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000778 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000779 def __init__(self, value):
780 self.value = value
781 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000782 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000783 return self.value == other.value
784 return False
785 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000786 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000787 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000788 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000789
Nick Coghlanf05d9812013-10-02 00:02:03 +1000790 @functools.total_ordering
791 class ImplementsGreaterThan:
792 def __init__(self, value):
793 self.value = value
794 def __eq__(self, other):
795 if isinstance(other, ImplementsGreaterThan):
796 return self.value == other.value
797 return False
798 def __gt__(self, other):
799 if isinstance(other, ImplementsGreaterThan):
800 return self.value > other.value
801 return NotImplemented
802
803 @functools.total_ordering
804 class ImplementsLessThanEqualTo:
805 def __init__(self, value):
806 self.value = value
807 def __eq__(self, other):
808 if isinstance(other, ImplementsLessThanEqualTo):
809 return self.value == other.value
810 return False
811 def __le__(self, other):
812 if isinstance(other, ImplementsLessThanEqualTo):
813 return self.value <= other.value
814 return NotImplemented
815
816 @functools.total_ordering
817 class ImplementsGreaterThanEqualTo:
818 def __init__(self, value):
819 self.value = value
820 def __eq__(self, other):
821 if isinstance(other, ImplementsGreaterThanEqualTo):
822 return self.value == other.value
823 return False
824 def __ge__(self, other):
825 if isinstance(other, ImplementsGreaterThanEqualTo):
826 return self.value >= other.value
827 return NotImplemented
828
829 @functools.total_ordering
830 class ComparatorNotImplemented:
831 def __init__(self, value):
832 self.value = value
833 def __eq__(self, other):
834 if isinstance(other, ComparatorNotImplemented):
835 return self.value == other.value
836 return False
837 def __lt__(self, other):
838 return NotImplemented
839
840 with self.subTest("LT < 1"), self.assertRaises(TypeError):
841 ImplementsLessThan(-1) < 1
842
843 with self.subTest("LT < LE"), self.assertRaises(TypeError):
844 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
845
846 with self.subTest("LT < GT"), self.assertRaises(TypeError):
847 ImplementsLessThan(1) < ImplementsGreaterThan(1)
848
849 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
850 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
851
852 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
853 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
854
855 with self.subTest("GT > GE"), self.assertRaises(TypeError):
856 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
857
858 with self.subTest("GT > LT"), self.assertRaises(TypeError):
859 ImplementsGreaterThan(5) > ImplementsLessThan(5)
860
861 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
862 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
863
864 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
865 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
866
867 with self.subTest("GE when equal"):
868 a = ComparatorNotImplemented(8)
869 b = ComparatorNotImplemented(8)
870 self.assertEqual(a, b)
871 with self.assertRaises(TypeError):
872 a >= b
873
874 with self.subTest("LE when equal"):
875 a = ComparatorNotImplemented(9)
876 b = ComparatorNotImplemented(9)
877 self.assertEqual(a, b)
878 with self.assertRaises(TypeError):
879 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200880
Georg Brandl2e7346a2010-07-31 18:09:23 +0000881class TestLRU(unittest.TestCase):
882
883 def test_lru(self):
884 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100885 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000886 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000887 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000888 self.assertEqual(maxsize, 20)
889 self.assertEqual(currsize, 0)
890 self.assertEqual(hits, 0)
891 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000892
893 domain = range(5)
894 for i in range(1000):
895 x, y = choice(domain), choice(domain)
896 actual = f(x, y)
897 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000898 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000899 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000900 self.assertTrue(hits > misses)
901 self.assertEqual(hits + misses, 1000)
902 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000903
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000904 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000905 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000906 self.assertEqual(hits, 0)
907 self.assertEqual(misses, 0)
908 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000909 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000910 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000911 self.assertEqual(hits, 0)
912 self.assertEqual(misses, 1)
913 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000914
Nick Coghlan98876832010-08-17 06:17:18 +0000915 # Test bypassing the cache
916 self.assertIs(f.__wrapped__, orig)
917 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000918 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000919 self.assertEqual(hits, 0)
920 self.assertEqual(misses, 1)
921 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000922
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000923 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000924 @functools.lru_cache(0)
925 def f():
926 nonlocal f_cnt
927 f_cnt += 1
928 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000929 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000930 f_cnt = 0
931 for i in range(5):
932 self.assertEqual(f(), 20)
933 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000934 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000935 self.assertEqual(hits, 0)
936 self.assertEqual(misses, 5)
937 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000938
939 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000940 @functools.lru_cache(1)
941 def f():
942 nonlocal f_cnt
943 f_cnt += 1
944 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000945 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000946 f_cnt = 0
947 for i in range(5):
948 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000949 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000950 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000951 self.assertEqual(hits, 4)
952 self.assertEqual(misses, 1)
953 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000954
Raymond Hettingerf3098282010-08-15 03:30:45 +0000955 # test size two
956 @functools.lru_cache(2)
957 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000958 nonlocal f_cnt
959 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000960 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000961 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000962 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000963 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
964 # * * * *
965 self.assertEqual(f(x), x*10)
966 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000967 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000968 self.assertEqual(hits, 12)
969 self.assertEqual(misses, 4)
970 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000971
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000972 def test_lru_with_maxsize_none(self):
973 @functools.lru_cache(maxsize=None)
974 def fib(n):
975 if n < 2:
976 return n
977 return fib(n-1) + fib(n-2)
978 self.assertEqual([fib(n) for n in range(16)],
979 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
980 self.assertEqual(fib.cache_info(),
981 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
982 fib.cache_clear()
983 self.assertEqual(fib.cache_info(),
984 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
985
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700986 def test_lru_with_exceptions(self):
987 # Verify that user_function exceptions get passed through without
988 # creating a hard-to-read chained exception.
989 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +0100990 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700991 @functools.lru_cache(maxsize)
992 def func(i):
993 return 'abc'[i]
994 self.assertEqual(func(0), 'a')
995 with self.assertRaises(IndexError) as cm:
996 func(15)
997 self.assertIsNone(cm.exception.__context__)
998 # Verify that the previous exception did not result in a cached entry
999 with self.assertRaises(IndexError):
1000 func(15)
1001
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001002 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001003 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001004 @functools.lru_cache(maxsize=maxsize, typed=True)
1005 def square(x):
1006 return x * x
1007 self.assertEqual(square(3), 9)
1008 self.assertEqual(type(square(3)), type(9))
1009 self.assertEqual(square(3.0), 9.0)
1010 self.assertEqual(type(square(3.0)), type(9.0))
1011 self.assertEqual(square(x=3), 9)
1012 self.assertEqual(type(square(x=3)), type(9))
1013 self.assertEqual(square(x=3.0), 9.0)
1014 self.assertEqual(type(square(x=3.0)), type(9.0))
1015 self.assertEqual(square.cache_info().hits, 4)
1016 self.assertEqual(square.cache_info().misses, 4)
1017
Antoine Pitroub5b37142012-11-13 21:35:40 +01001018 def test_lru_with_keyword_args(self):
1019 @functools.lru_cache()
1020 def fib(n):
1021 if n < 2:
1022 return n
1023 return fib(n=n-1) + fib(n=n-2)
1024 self.assertEqual(
1025 [fib(n=number) for number in range(16)],
1026 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1027 )
1028 self.assertEqual(fib.cache_info(),
1029 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1030 fib.cache_clear()
1031 self.assertEqual(fib.cache_info(),
1032 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1033
1034 def test_lru_with_keyword_args_maxsize_none(self):
1035 @functools.lru_cache(maxsize=None)
1036 def fib(n):
1037 if n < 2:
1038 return n
1039 return fib(n=n-1) + fib(n=n-2)
1040 self.assertEqual([fib(n=number) for number in range(16)],
1041 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1042 self.assertEqual(fib.cache_info(),
1043 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1044 fib.cache_clear()
1045 self.assertEqual(fib.cache_info(),
1046 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1047
Raymond Hettinger03923422013-03-04 02:52:50 -05001048 def test_need_for_rlock(self):
1049 # This will deadlock on an LRU cache that uses a regular lock
1050
1051 @functools.lru_cache(maxsize=10)
1052 def test_func(x):
1053 'Used to demonstrate a reentrant lru_cache call within a single thread'
1054 return x
1055
1056 class DoubleEq:
1057 'Demonstrate a reentrant lru_cache call within a single thread'
1058 def __init__(self, x):
1059 self.x = x
1060 def __hash__(self):
1061 return self.x
1062 def __eq__(self, other):
1063 if self.x == 2:
1064 test_func(DoubleEq(1))
1065 return self.x == other.x
1066
1067 test_func(DoubleEq(1)) # Load the cache
1068 test_func(DoubleEq(2)) # Load the cache
1069 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1070 DoubleEq(2)) # Verify the correct return value
1071
1072
Łukasz Langa6f692512013-06-05 12:20:24 +02001073class TestSingleDispatch(unittest.TestCase):
1074 def test_simple_overloads(self):
1075 @functools.singledispatch
1076 def g(obj):
1077 return "base"
1078 def g_int(i):
1079 return "integer"
1080 g.register(int, g_int)
1081 self.assertEqual(g("str"), "base")
1082 self.assertEqual(g(1), "integer")
1083 self.assertEqual(g([1,2,3]), "base")
1084
1085 def test_mro(self):
1086 @functools.singledispatch
1087 def g(obj):
1088 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001089 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001090 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001091 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001092 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001093 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001094 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001095 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001096 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001097 def g_A(a):
1098 return "A"
1099 def g_B(b):
1100 return "B"
1101 g.register(A, g_A)
1102 g.register(B, g_B)
1103 self.assertEqual(g(A()), "A")
1104 self.assertEqual(g(B()), "B")
1105 self.assertEqual(g(C()), "A")
1106 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001107
1108 def test_register_decorator(self):
1109 @functools.singledispatch
1110 def g(obj):
1111 return "base"
1112 @g.register(int)
1113 def g_int(i):
1114 return "int %s" % (i,)
1115 self.assertEqual(g(""), "base")
1116 self.assertEqual(g(12), "int 12")
1117 self.assertIs(g.dispatch(int), g_int)
1118 self.assertIs(g.dispatch(object), g.dispatch(str))
1119 # Note: in the assert above this is not g.
1120 # @singledispatch returns the wrapper.
1121
1122 def test_wrapping_attributes(self):
1123 @functools.singledispatch
1124 def g(obj):
1125 "Simple test"
1126 return "Test"
1127 self.assertEqual(g.__name__, "g")
1128 self.assertEqual(g.__doc__, "Simple test")
1129
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch works on exception classes from the C _decimal module, and a
    # later, more specific registration takes precedence for its subclass.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def _exception_args(obj):
        return obj.args

    subn = decimal.Subnormal("Exponent < Emin")
    rnd = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subn), ("Exponent < Emin",))
    self.assertEqual(g(rnd), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def _subnormal(obj):
        return "Too small to care."

    # Only Subnormal is affected; Rounded still uses the generic handler.
    self.assertEqual(g(subn), "Too small to care.")
    self.assertEqual(g(rnd), ("Number got rounded",))
1148
def test_compose_mro(self):
    # Check the MROs _compose_mro builds from a class and a haystack of
    # candidate ABCs. Unless noted otherwise, each example is expected to be
    # independent of haystack order, so every permutation is tried.
    c = collections
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
                             c.Iterable, c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
    for haystack in permutations(bases):
        m = mro(c.ChainMap, haystack)
        self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
                             c.Sized, c.Iterable, c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    # NOTE(review): this example uses a single fixed haystack order (the
    # previous permutation loop never used its loop variable). Confirm the
    # result is order-independent before permuting this haystack.
    m = mro(c.defaultdict, [c.Sized, c.Container, str])
    self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
                         object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(c.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        # Pass the permutation itself; passing ``bases`` would test only
        # one ordering, however many times the loop runs.
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
                             c.defaultdict, dict, c.MutableMapping,
                             c.Mapping, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(c.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
                             c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001199
def test_register_abc(self):
    # Register progressively more specific implementations and, after every
    # step, check the dispatch result for one dict, list, set, frozenset and
    # tuple instance (always dispatched in that order).
    c = collections
    s = {object(), None}
    samples = ({"a": "b"}, [1, 2, 3], s, frozenset(s), (1, 2, 3))

    @functools.singledispatch
    def g(obj):
        return "base"

    def returning(label):
        # Build an implementation that ignores its argument.
        return lambda obj: label

    def check(expected):
        for sample, label in zip(samples, expected):
            self.assertEqual(g(sample), label)

    check(("base", "base", "base", "base", "base"))

    # (class to register, its label, expected results for the five samples)
    steps = [
        (c.Sized, "sized",
         ("sized", "sized", "sized", "sized", "sized")),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # No sample is a ChainMap, so nothing changes.
        (c.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        # Mapping is not specific enough to replace MutableMapping for dict.
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized",
          "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set",
          "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    ]
    for cls, label, expected in steps:
        g.register(cls, returning(label))
        check(expected)
1293
Łukasz Langa3720c772013-07-01 16:00:38 +02001294 def test_c3_abc(self):
1295 c = collections
1296 mro = functools._c3_mro
1297 class A(object):
1298 pass
1299 class B(A):
1300 def __len__(self):
1301 return 0 # implies Sized
1302 @c.Container.register
1303 class C(object):
1304 pass
1305 class D(object):
1306 pass # unrelated
1307 class X(D, C, B):
1308 def __call__(self):
1309 pass # implies Callable
1310 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1311 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1312 self.assertEqual(mro(X, abcs=abcs), expected)
1313 # unrelated ABCs don't appear in the resulting MRO
1314 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1315 self.assertEqual(mro(X, abcs=many_abcs), expected)
1316
Łukasz Langa6f692512013-06-05 12:20:24 +02001317 def test_mro_conflicts(self):
1318 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001319 @functools.singledispatch
1320 def g(arg):
1321 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001322 class O(c.Sized):
1323 def __len__(self):
1324 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001325 o = O()
1326 self.assertEqual(g(o), "base")
1327 g.register(c.Iterable, lambda arg: "iterable")
1328 g.register(c.Container, lambda arg: "container")
1329 g.register(c.Sized, lambda arg: "sized")
1330 g.register(c.Set, lambda arg: "set")
1331 self.assertEqual(g(o), "sized")
1332 c.Iterable.register(O)
1333 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1334 c.Container.register(O)
1335 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001336 c.Set.register(O)
1337 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1338 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001339 class P:
1340 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001341 p = P()
1342 self.assertEqual(g(p), "base")
1343 c.Iterable.register(P)
1344 self.assertEqual(g(p), "iterable")
1345 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001346 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001347 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001348 self.assertIn(
1349 str(re_one.exception),
1350 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1351 "or <class 'collections.abc.Iterable'>"),
1352 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1353 "or <class 'collections.abc.Container'>")),
1354 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001355 class Q(c.Sized):
1356 def __len__(self):
1357 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001358 q = Q()
1359 self.assertEqual(g(q), "sized")
1360 c.Iterable.register(Q)
1361 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1362 c.Set.register(Q)
1363 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001364 # c.Sized and c.Iterable
1365 @functools.singledispatch
1366 def h(arg):
1367 return "base"
1368 @h.register(c.Sized)
1369 def _(arg):
1370 return "sized"
1371 @h.register(c.Container)
1372 def _(arg):
1373 return "container"
1374 # Even though Sized and Container are explicit bases of MutableMapping,
1375 # this ABC is implicitly registered on defaultdict which makes all of
1376 # MutableMapping's bases implicit as well from defaultdict's
1377 # perspective.
1378 with self.assertRaises(RuntimeError) as re_two:
1379 h(c.defaultdict(lambda: 0))
1380 self.assertIn(
1381 str(re_two.exception),
1382 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1383 "or <class 'collections.abc.Sized'>"),
1384 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1385 "or <class 'collections.abc.Container'>")),
1386 )
1387 class R(c.defaultdict):
1388 pass
1389 c.MutableSequence.register(R)
1390 @functools.singledispatch
1391 def i(arg):
1392 return "base"
1393 @i.register(c.MutableMapping)
1394 def _(arg):
1395 return "mapping"
1396 @i.register(c.MutableSequence)
1397 def _(arg):
1398 return "sequence"
1399 r = R()
1400 self.assertEqual(i(r), "sequence")
1401 class S:
1402 pass
1403 class T(S, c.Sized):
1404 def __len__(self):
1405 return 0
1406 t = T()
1407 self.assertEqual(h(t), "sized")
1408 c.Container.register(T)
1409 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1410 class U:
1411 def __len__(self):
1412 return 0
1413 u = U()
1414 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1415 # from the existence of __len__()
1416 c.Container.register(U)
1417 # There is no preference for registered versus inferred ABCs.
1418 with self.assertRaises(RuntimeError) as re_three:
1419 h(u)
1420 self.assertIn(
1421 str(re_three.exception),
1422 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1423 "or <class 'collections.abc.Sized'>"),
1424 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1425 "or <class 'collections.abc.Container'>")),
1426 )
1427 class V(c.Sized, S):
1428 def __len__(self):
1429 return 0
1430 @functools.singledispatch
1431 def j(arg):
1432 return "base"
1433 @j.register(S)
1434 def _(arg):
1435 return "s"
1436 @j.register(c.Container)
1437 def _(arg):
1438 return "container"
1439 v = V()
1440 self.assertEqual(j(v), "s")
1441 c.Container.register(V)
1442 self.assertEqual(j(v), "container") # because it ends up right after
1443 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001444
def test_cache_invalidation(self):
    """Check when singledispatch's dispatch cache is filled, read and cleared.

    ``functools.WeakKeyDictionary`` is temporarily replaced by a factory that
    returns a single TracingDict ``td``.  The test relies on singledispatch()
    obtaining its dispatch cache through that module-level name, so ``td``
    becomes the cache of ``g``.  ``td`` logs every successful read
    (``get_ops``) and every write (``set_ops``), letting the assertions check
    cache hits, misses and invalidation after register(), ABC registration
    and _clear_cache().

    The original ``functools.WeakKeyDictionary`` is restored in a ``finally``
    clause, so a failing assertion cannot leak the patch into later tests.
    """
    from collections import UserDict

    class TracingDict(UserDict):
        """UserDict that records the keys it is read from and written to."""

        def __init__(self, *args, **kwargs):
            # Create the logs before UserDict.__init__ runs.  Any initial
            # items it stores go through __setitem__ below, which appends
            # to set_ops.
            self.set_ops = []
            self.get_ops = []
            super().__init__(*args, **kwargs)

        def __getitem__(self, key):
            # The lookup happens before the append, so a missing key raises
            # KeyError without being logged.  Cache misses therefore never
            # show up in get_ops.
            result = self.data[key]
            self.get_ops.append(key)
            return result

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            # Empty the storage directly.  Presumably this keeps
            # MutableMapping.clear() (popitem() -> self[key]) from adding
            # entries to get_ops; the get_ops assertions after each
            # register() call depend on clear() not logging anything.
            self.data.clear()

    _orig_wkd = functools.WeakKeyDictionary
    td = TracingDict()
    # Must be in place before singledispatch() is called below.
    functools.WeakKeyDictionary = lambda: td
    try:
        c = collections

        @functools.singledispatch
        def g(arg):
            return "base"

        d = {}
        l = []
        # First dispatch per type is a miss: resolved and stored, not read.
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated dispatch is a hit: read from the cache, nothing stored.
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # register() empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))

        class X:
            pass

        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() consults the same cache as calling g.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # Now that g has an ABC registered, ABC registration invalidates
        # the cache.  The entries stay until the next dispatch, which
        # clears them and stores only the type it resolves.
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
    finally:
        # Always undo the monkeypatch, even when an assertion above fails.
        functools.WeakKeyDictionary = _orig_wkd
1545
1546
def test_main(verbose=None):
    """Run every test class of this module via support.run_unittest().

    When *verbose* is true and the interpreter provides
    sys.gettotalrefcount() (a CPython debug-build feature), the whole suite
    is run five more times.  After each run, a gc.collect() is done and the
    total reference count is recorded; the list of five counts is then
    printed, presumably so a steadily growing count can be spotted as a
    reference leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001573
if __name__ == '__main__':
    # Running this file directly always asks test_main() for verbose mode,
    # which also enables its repeated-run reference-count check when
    # sys.gettotalrefcount is available.
    test_main(verbose=True)