blob: 10120530b81ec3e0cd2f0ee58c4e0e7f9d5ad4a0 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Two separately imported copies of functools so each test class can target
# one implementation: py_functools has the C accelerator (_functools) blocked,
# forcing the pure-Python code paths; c_functools is imported fresh together
# with _functools.  c_functools may be falsy -- the classes below guard on
# `if c_functools:` -- presumably when _functools is unavailable.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# decimal imported fresh along with its C accelerator (_decimal); not used in
# this part of the file -- presumably needed by tests further down.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Return the call's arguments unchanged as an ``(args, kwargs)`` pair.

    Serves as the target callable for partial objects so tests can inspect
    exactly which positional and keyword arguments reached the call.
    """
    captured = (args, kw)
    return captured
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Describe a partial object as a comparable tuple.

    Returns ``(func, args, keywords, __dict__)`` so two partial objects can be
    compared for equality, e.g. before and after a pickle round trip.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Implementation-independent tests for partial().

    Mixin: concrete subclasses combine it with unittest.TestCase and set the
    ``partial`` class attribute to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial object is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly; separate
        # assertEqual calls report which part differed on failure
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        self.assertEqual(p(1, b=2), ((0, 1), {'a': 1, 'b': 2}))
        self.assertEqual(p(), ((0,), {'a': 1}))

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000133
Łukasz Langa6f692512013-06-05 12:20:24 +0200134
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """partial() tests specific to the C implementation (_functools)."""
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The keyword part of the repr follows the iteration order of the
        # partial's keywords dict, which need not match the order in which
        # the keywords were passed, so accept every ordering.
        kwargs_reprs = [', '.join('%s=%r' % (k, kwargs[k]) for k in order)
                        for order in permutations(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000207
Łukasz Langa6f692512013-06-05 12:20:24 +0200208
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial() tests against the pure-Python functools."""

    # staticmethod keeps the implementation from being bound as a method when
    # it is looked up through ``self``, whatever kind of callable it is.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000211
Łukasz Langa6f692512013-06-05 12:20:24 +0200212
# A do-nothing subclass of the C partial type, defined only when the C
# accelerator was importable; TestPartialCSubclass reruns the C tests on it.
if c_functools:
    class PartialSubclass(c_functools.partial):
        pass
215 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100216
Łukasz Langa6f692512013-06-05 12:20:24 +0200217
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun every C partial() test against a Python-level subclass."""

    # Guarded like the class body of TestPartialC: PartialSubclass only
    # exists when c_functools was importable.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000222
Łukasz Langa6f692512013-06-05 12:20:24 +0200223
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod() used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass*; subclassing ABCMeta would define a
        # new metaclass and never exercise the abstract-method bookkeeping.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # A partialmethod wrapping an abstract method is itself abstract, so
        # ABCMeta records it and refuses to instantiate the class.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
336
337
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received the expected attributes of *wrapped*.

        Each attribute named in *assigned* must be the very same object on
        both functions; each mapping attribute named in *updated* must hold
        every entry of the wrapped function's mapping (identity-compared).
        Finally, ``wrapper.__wrapped__`` must point back at *wrapped*.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for entry in source:
                # update_wrapper() deliberately overwrites __wrapped__ in the
                # instance dict, so that one entry is not copied verbatim.
                overwritten = attr_name == "__dict__" and entry == "__wrapped__"
                if not overwritten:
                    self.assertIs(source[entry], target[entry])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Build a wrapper with update_wrapper()'s default assigned/updated
        # attributes; f carries a decoy __wrapped__ that must be replaced.
        def f(a: 'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b: 'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Attributes missing from the wrapped object are silently skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must already have each attribute to update...
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ...and it must be a mapping with an update() method.
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin must not fail.
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000450
Łukasz Langa6f692512013-06-05 12:20:24 +0200451
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the wraps() decorator."""

    def _default_update(self):
        # Decorate with wraps()'s defaults; f carries a decoy __wrapped__
        # that the decorator must replace with f itself.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        nothing = ()
        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def with_empty_dict_attr(func):
            # wraps() updates dict_attr in place, so it must already exist
            # on the wrapper before the wraps() decorator runs.
            func.dict_attr = {}
            return func
        assign, update = ('attr',), ('dict_attr',)
        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
512
Łukasz Langa6f692512013-06-05 12:20:24 +0200513
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce()."""
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the first `max` squares, computed lazily on demand;
            # __len__ reports only how many have been computed so far.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                start = len(self.sofar)
                self.sofar.extend(n * n for n in range(start, i + 1))
                return self.sofar[i]

        def plus(x, y):
            return x + y

        reduce = self.func

        # Successful reductions, with and without an initial value.
        self.assertEqual(reduce(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x * y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(plus, Squares(10)), 285)
        self.assertEqual(reduce(plus, Squares(10), 0), 285)
        self.assertEqual(reduce(plus, Squares(0), 0), 0)

        # The function is never called when only one item is available.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        bad_calls = [
            (),                 # missing arguments
            (42, 42),           # second argument is not iterable
            (42, 42, 42),
            (42, (42, 42)),     # function is not callable
            (plus, []),         # empty sequence and no initial value
            (plus, ""),
            (plus, ()),
            (plus, object()),   # not iterable
        ]
        for call_args in bad_calls:
            with self.subTest(args=call_args):
                self.assertRaises(TypeError, reduce, *call_args)

        class FailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, plus, FailingIter())

        # An empty sequence yields the initial value unchanged.
        self.assertIsNone(reduce(plus, [], None))
        self.assertEqual(reduce(plus, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # reduce() accepts any iterable, including old-style __getitem__ sequences.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
593
Łukasz Langa6f692512013-06-05 12:20:24 +0200594
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200595class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700596
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000597 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700598 def cmp1(x, y):
599 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100600 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700601 self.assertEqual(key(3), key(3))
602 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100603 self.assertGreaterEqual(key(3), key(3))
604
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700605 def cmp2(x, y):
606 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100607 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700608 self.assertEqual(key(4.0), key('4'))
609 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100610 self.assertLessEqual(key(2), key('35'))
611 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700612
613 def test_cmp_to_key_arguments(self):
614 def cmp1(x, y):
615 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100616 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700617 self.assertEqual(key(obj=3), key(obj=3))
618 self.assertGreater(key(obj=3), key(obj=1))
619 with self.assertRaises((TypeError, AttributeError)):
620 key(3) > 1 # rhs is not a K object
621 with self.assertRaises((TypeError, AttributeError)):
622 1 < key(3) # lhs is not a K object
623 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100624 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700625 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200626 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100627 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700628 with self.assertRaises(TypeError):
629 key() # too few args
630 with self.assertRaises(TypeError):
631 key(None, None) # too many args
632
633 def test_bad_cmp(self):
634 def cmp1(x, y):
635 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100636 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700637 with self.assertRaises(ZeroDivisionError):
638 key(3) > key(1)
639
640 class BadCmp:
641 def __lt__(self, other):
642 raise ZeroDivisionError
643 def cmp1(x, y):
644 return BadCmp()
645 with self.assertRaises(ZeroDivisionError):
646 key(3) > key(1)
647
648 def test_obj_field(self):
649 def cmp1(x, y):
650 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100651 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700652 self.assertEqual(key(50).obj, 50)
653
654 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000655 def mycmp(x, y):
656 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100657 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000658 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000659
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700660 def test_sort_int_str(self):
661 def mycmp(x, y):
662 x, y = int(x), int(y)
663 return (x > y) - (x < y)
664 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100665 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700666 self.assertEqual([int(value) for value in values],
667 [0, 1, 1, 2, 3, 4, 5, 7, 10])
668
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000669 def test_hash(self):
670 def mycmp(x, y):
671 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100672 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000673 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700674 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700675 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000676
Łukasz Langa6f692512013-06-05 12:20:24 +0200677
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Shared cmp_to_key tests, run against the C implementation."""
    # Only bind the attribute when the accelerator was importable; otherwise
    # the skipUnless decorator skips the whole class anyway.
    if c_functools is not None:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100682
Łukasz Langa6f692512013-06-05 12:20:24 +0200683
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Shared cmp_to_key tests, run against the pure-Python implementation."""
    # py_functools.cmp_to_key is a plain Python function; wrapping it in
    # staticmethod prevents it from being bound with the test case as self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator.

    Most tests decorate a small class that defines ``__eq__`` plus a single
    ordering method and check that the remaining comparisons are derived
    correctly.
    """

    def test_total_ordering_lt(self):
        # Derive >, <= and >= from a user-supplied __lt__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        # Derive <, > and >= from a user-supplied __le__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        # Derive <, <= and >= from a user-supplied __gt__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        # Derive <, > and <= from a user-supplied __ge__.
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        # The results above would be the same even if the decorator had
        # replaced int's methods, so check directly that it added none:
        # int already supplies every ordering method.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            with self.subTest(name=name):
                self.assertNotIn(name, vars(A))

    def test_no_operations_defined(self):
        # A class without any ordering method cannot be completed.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        # Even when the operands are equal, a derived operator must not
        # short-circuit on __eq__ if the root comparison is NotImplemented.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200882
Georg Brandl2e7346a2010-07-31 18:09:23 +0000883class TestLRU(unittest.TestCase):
884
885 def test_lru(self):
886 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100887 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000888 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000889 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000890 self.assertEqual(maxsize, 20)
891 self.assertEqual(currsize, 0)
892 self.assertEqual(hits, 0)
893 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000894
895 domain = range(5)
896 for i in range(1000):
897 x, y = choice(domain), choice(domain)
898 actual = f(x, y)
899 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000900 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000901 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000902 self.assertTrue(hits > misses)
903 self.assertEqual(hits + misses, 1000)
904 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000905
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000906 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000907 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000908 self.assertEqual(hits, 0)
909 self.assertEqual(misses, 0)
910 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000911 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000912 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000913 self.assertEqual(hits, 0)
914 self.assertEqual(misses, 1)
915 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000916
Nick Coghlan98876832010-08-17 06:17:18 +0000917 # Test bypassing the cache
918 self.assertIs(f.__wrapped__, orig)
919 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000920 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000921 self.assertEqual(hits, 0)
922 self.assertEqual(misses, 1)
923 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000924
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000925 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000926 @functools.lru_cache(0)
927 def f():
928 nonlocal f_cnt
929 f_cnt += 1
930 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000931 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000932 f_cnt = 0
933 for i in range(5):
934 self.assertEqual(f(), 20)
935 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000936 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000937 self.assertEqual(hits, 0)
938 self.assertEqual(misses, 5)
939 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000940
941 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000942 @functools.lru_cache(1)
943 def f():
944 nonlocal f_cnt
945 f_cnt += 1
946 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000947 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000948 f_cnt = 0
949 for i in range(5):
950 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000951 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000952 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000953 self.assertEqual(hits, 4)
954 self.assertEqual(misses, 1)
955 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000956
Raymond Hettingerf3098282010-08-15 03:30:45 +0000957 # test size two
958 @functools.lru_cache(2)
959 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000960 nonlocal f_cnt
961 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000962 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000963 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000964 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000965 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
966 # * * * *
967 self.assertEqual(f(x), x*10)
968 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000969 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000970 self.assertEqual(hits, 12)
971 self.assertEqual(misses, 4)
972 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000973
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000974 def test_lru_with_maxsize_none(self):
975 @functools.lru_cache(maxsize=None)
976 def fib(n):
977 if n < 2:
978 return n
979 return fib(n-1) + fib(n-2)
980 self.assertEqual([fib(n) for n in range(16)],
981 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
982 self.assertEqual(fib.cache_info(),
983 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
984 fib.cache_clear()
985 self.assertEqual(fib.cache_info(),
986 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
987
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700988 def test_lru_with_exceptions(self):
989 # Verify that user_function exceptions get passed through without
990 # creating a hard-to-read chained exception.
991 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +0100992 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700993 @functools.lru_cache(maxsize)
994 def func(i):
995 return 'abc'[i]
996 self.assertEqual(func(0), 'a')
997 with self.assertRaises(IndexError) as cm:
998 func(15)
999 self.assertIsNone(cm.exception.__context__)
1000 # Verify that the previous exception did not result in a cached entry
1001 with self.assertRaises(IndexError):
1002 func(15)
1003
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001004 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001005 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001006 @functools.lru_cache(maxsize=maxsize, typed=True)
1007 def square(x):
1008 return x * x
1009 self.assertEqual(square(3), 9)
1010 self.assertEqual(type(square(3)), type(9))
1011 self.assertEqual(square(3.0), 9.0)
1012 self.assertEqual(type(square(3.0)), type(9.0))
1013 self.assertEqual(square(x=3), 9)
1014 self.assertEqual(type(square(x=3)), type(9))
1015 self.assertEqual(square(x=3.0), 9.0)
1016 self.assertEqual(type(square(x=3.0)), type(9.0))
1017 self.assertEqual(square.cache_info().hits, 4)
1018 self.assertEqual(square.cache_info().misses, 4)
1019
Antoine Pitroub5b37142012-11-13 21:35:40 +01001020 def test_lru_with_keyword_args(self):
1021 @functools.lru_cache()
1022 def fib(n):
1023 if n < 2:
1024 return n
1025 return fib(n=n-1) + fib(n=n-2)
1026 self.assertEqual(
1027 [fib(n=number) for number in range(16)],
1028 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1029 )
1030 self.assertEqual(fib.cache_info(),
1031 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1032 fib.cache_clear()
1033 self.assertEqual(fib.cache_info(),
1034 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1035
1036 def test_lru_with_keyword_args_maxsize_none(self):
1037 @functools.lru_cache(maxsize=None)
1038 def fib(n):
1039 if n < 2:
1040 return n
1041 return fib(n=n-1) + fib(n=n-2)
1042 self.assertEqual([fib(n=number) for number in range(16)],
1043 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1044 self.assertEqual(fib.cache_info(),
1045 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1046 fib.cache_clear()
1047 self.assertEqual(fib.cache_info(),
1048 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1049
Raymond Hettinger03923422013-03-04 02:52:50 -05001050 def test_need_for_rlock(self):
1051 # This will deadlock on an LRU cache that uses a regular lock
1052
1053 @functools.lru_cache(maxsize=10)
1054 def test_func(x):
1055 'Used to demonstrate a reentrant lru_cache call within a single thread'
1056 return x
1057
1058 class DoubleEq:
1059 'Demonstrate a reentrant lru_cache call within a single thread'
1060 def __init__(self, x):
1061 self.x = x
1062 def __hash__(self):
1063 return self.x
1064 def __eq__(self, other):
1065 if self.x == 2:
1066 test_func(DoubleEq(1))
1067 return self.x == other.x
1068
1069 test_func(DoubleEq(1)) # Load the cache
1070 test_func(DoubleEq(2)) # Load the cache
1071 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1072 DoubleEq(2)) # Verify the correct return value
1073
Raymond Hettinger4d588972014-08-12 12:44:52 -07001074 def test_early_detection_of_bad_call(self):
1075 # Issue #22184
1076 with self.assertRaises(TypeError):
1077 @functools.lru_cache
1078 def f():
1079 pass
1080
Raymond Hettinger03923422013-03-04 02:52:50 -05001081
Łukasz Langa6f692512013-06-05 12:20:24 +02001082class TestSingleDispatch(unittest.TestCase):
1083 def test_simple_overloads(self):
1084 @functools.singledispatch
1085 def g(obj):
1086 return "base"
1087 def g_int(i):
1088 return "integer"
1089 g.register(int, g_int)
1090 self.assertEqual(g("str"), "base")
1091 self.assertEqual(g(1), "integer")
1092 self.assertEqual(g([1,2,3]), "base")
1093
1094 def test_mro(self):
1095 @functools.singledispatch
1096 def g(obj):
1097 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001098 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001099 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001100 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001101 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001102 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001103 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001104 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001105 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001106 def g_A(a):
1107 return "A"
1108 def g_B(b):
1109 return "B"
1110 g.register(A, g_A)
1111 g.register(B, g_B)
1112 self.assertEqual(g(A()), "A")
1113 self.assertEqual(g(B()), "B")
1114 self.assertEqual(g(C()), "A")
1115 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001116
1117 def test_register_decorator(self):
1118 @functools.singledispatch
1119 def g(obj):
1120 return "base"
1121 @g.register(int)
1122 def g_int(i):
1123 return "int %s" % (i,)
1124 self.assertEqual(g(""), "base")
1125 self.assertEqual(g(12), "int 12")
1126 self.assertIs(g.dispatch(int), g_int)
1127 self.assertIs(g.dispatch(object), g.dispatch(str))
1128 # Note: in the assert above this is not g.
1129 # @singledispatch returns the wrapper.
1130
1131 def test_wrapping_attributes(self):
1132 @functools.singledispatch
1133 def g(obj):
1134 "Simple test"
1135 return "Test"
1136 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001137 if sys.flags.optimize < 2:
1138 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001139
1140 @unittest.skipUnless(decimal, 'requires _decimal')
1141 @support.cpython_only
1142 def test_c_classes(self):
1143 @functools.singledispatch
1144 def g(obj):
1145 return "base"
1146 @g.register(decimal.DecimalException)
1147 def _(obj):
1148 return obj.args
1149 subn = decimal.Subnormal("Exponent < Emin")
1150 rnd = decimal.Rounded("Number got rounded")
1151 self.assertEqual(g(subn), ("Exponent < Emin",))
1152 self.assertEqual(g(rnd), ("Number got rounded",))
1153 @g.register(decimal.Subnormal)
1154 def _(obj):
1155 return "Too small to care."
1156 self.assertEqual(g(subn), "Too small to care.")
1157 self.assertEqual(g(rnd), ("Number got rounded",))
1158
1159 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001160 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001161 c = collections
1162 mro = functools._compose_mro
1163 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1164 for haystack in permutations(bases):
1165 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001166 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1167 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001168 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1169 for haystack in permutations(bases):
1170 m = mro(c.ChainMap, haystack)
1171 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1172 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001173
1174 # If there's a generic function with implementations registered for
1175 # both Sized and Container, passing a defaultdict to it results in an
1176 # ambiguous dispatch which will cause a RuntimeError (see
1177 # test_mro_conflicts).
1178 bases = [c.Container, c.Sized, str]
1179 for haystack in permutations(bases):
1180 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1181 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1182 object])
1183
1184 # MutableSequence below is registered directly on D. In other words, it
1185 # preceeds MutableMapping which means single dispatch will always
1186 # choose MutableSequence here.
1187 class D(c.defaultdict):
1188 pass
1189 c.MutableSequence.register(D)
1190 bases = [c.MutableSequence, c.MutableMapping]
1191 for haystack in permutations(bases):
1192 m = mro(D, bases)
1193 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1194 c.defaultdict, dict, c.MutableMapping,
1195 c.Mapping, c.Sized, c.Iterable, c.Container,
1196 object])
1197
1198 # Container and Callable are registered on different base classes and
1199 # a generic function supporting both should always pick the Callable
1200 # implementation if a C instance is passed.
1201 class C(c.defaultdict):
1202 def __call__(self):
1203 pass
1204 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1205 for haystack in permutations(bases):
1206 m = mro(C, haystack)
1207 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1208 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001209
1210 def test_register_abc(self):
1211 c = collections
1212 d = {"a": "b"}
1213 l = [1, 2, 3]
1214 s = {object(), None}
1215 f = frozenset(s)
1216 t = (1, 2, 3)
1217 @functools.singledispatch
1218 def g(obj):
1219 return "base"
1220 self.assertEqual(g(d), "base")
1221 self.assertEqual(g(l), "base")
1222 self.assertEqual(g(s), "base")
1223 self.assertEqual(g(f), "base")
1224 self.assertEqual(g(t), "base")
1225 g.register(c.Sized, lambda obj: "sized")
1226 self.assertEqual(g(d), "sized")
1227 self.assertEqual(g(l), "sized")
1228 self.assertEqual(g(s), "sized")
1229 self.assertEqual(g(f), "sized")
1230 self.assertEqual(g(t), "sized")
1231 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1232 self.assertEqual(g(d), "mutablemapping")
1233 self.assertEqual(g(l), "sized")
1234 self.assertEqual(g(s), "sized")
1235 self.assertEqual(g(f), "sized")
1236 self.assertEqual(g(t), "sized")
1237 g.register(c.ChainMap, lambda obj: "chainmap")
1238 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1239 self.assertEqual(g(l), "sized")
1240 self.assertEqual(g(s), "sized")
1241 self.assertEqual(g(f), "sized")
1242 self.assertEqual(g(t), "sized")
1243 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1244 self.assertEqual(g(d), "mutablemapping")
1245 self.assertEqual(g(l), "mutablesequence")
1246 self.assertEqual(g(s), "sized")
1247 self.assertEqual(g(f), "sized")
1248 self.assertEqual(g(t), "sized")
1249 g.register(c.MutableSet, lambda obj: "mutableset")
1250 self.assertEqual(g(d), "mutablemapping")
1251 self.assertEqual(g(l), "mutablesequence")
1252 self.assertEqual(g(s), "mutableset")
1253 self.assertEqual(g(f), "sized")
1254 self.assertEqual(g(t), "sized")
1255 g.register(c.Mapping, lambda obj: "mapping")
1256 self.assertEqual(g(d), "mutablemapping") # not specific enough
1257 self.assertEqual(g(l), "mutablesequence")
1258 self.assertEqual(g(s), "mutableset")
1259 self.assertEqual(g(f), "sized")
1260 self.assertEqual(g(t), "sized")
1261 g.register(c.Sequence, lambda obj: "sequence")
1262 self.assertEqual(g(d), "mutablemapping")
1263 self.assertEqual(g(l), "mutablesequence")
1264 self.assertEqual(g(s), "mutableset")
1265 self.assertEqual(g(f), "sized")
1266 self.assertEqual(g(t), "sequence")
1267 g.register(c.Set, lambda obj: "set")
1268 self.assertEqual(g(d), "mutablemapping")
1269 self.assertEqual(g(l), "mutablesequence")
1270 self.assertEqual(g(s), "mutableset")
1271 self.assertEqual(g(f), "set")
1272 self.assertEqual(g(t), "sequence")
1273 g.register(dict, lambda obj: "dict")
1274 self.assertEqual(g(d), "dict")
1275 self.assertEqual(g(l), "mutablesequence")
1276 self.assertEqual(g(s), "mutableset")
1277 self.assertEqual(g(f), "set")
1278 self.assertEqual(g(t), "sequence")
1279 g.register(list, lambda obj: "list")
1280 self.assertEqual(g(d), "dict")
1281 self.assertEqual(g(l), "list")
1282 self.assertEqual(g(s), "mutableset")
1283 self.assertEqual(g(f), "set")
1284 self.assertEqual(g(t), "sequence")
1285 g.register(set, lambda obj: "concrete-set")
1286 self.assertEqual(g(d), "dict")
1287 self.assertEqual(g(l), "list")
1288 self.assertEqual(g(s), "concrete-set")
1289 self.assertEqual(g(f), "set")
1290 self.assertEqual(g(t), "sequence")
1291 g.register(frozenset, lambda obj: "frozen-set")
1292 self.assertEqual(g(d), "dict")
1293 self.assertEqual(g(l), "list")
1294 self.assertEqual(g(s), "concrete-set")
1295 self.assertEqual(g(f), "frozen-set")
1296 self.assertEqual(g(t), "sequence")
1297 g.register(tuple, lambda obj: "tuple")
1298 self.assertEqual(g(d), "dict")
1299 self.assertEqual(g(l), "list")
1300 self.assertEqual(g(s), "concrete-set")
1301 self.assertEqual(g(f), "frozen-set")
1302 self.assertEqual(g(t), "tuple")
1303
Łukasz Langa3720c772013-07-01 16:00:38 +02001304 def test_c3_abc(self):
1305 c = collections
1306 mro = functools._c3_mro
1307 class A(object):
1308 pass
1309 class B(A):
1310 def __len__(self):
1311 return 0 # implies Sized
1312 @c.Container.register
1313 class C(object):
1314 pass
1315 class D(object):
1316 pass # unrelated
1317 class X(D, C, B):
1318 def __call__(self):
1319 pass # implies Callable
1320 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1321 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1322 self.assertEqual(mro(X, abcs=abcs), expected)
1323 # unrelated ABCs don't appear in the resulting MRO
1324 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1325 self.assertEqual(mro(X, abcs=many_abcs), expected)
1326
    def test_mro_conflicts(self):
        # Exercise how singledispatch resolves, or refuses to resolve, ABC
        # implementations.  The relevant ABCs reach a class either through its
        # __mro__, through ABC.register() or through implied special methods.
        # NOTE: every ``c.<ABC>.register(...)`` call below mutates the global
        # collections ABC registry for a locally defined class, and later
        # assertions depend on the exact order of these registrations.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no explicit ABC bases: registering two unrelated ABCs on it
        # makes dispatch between their implementations ambiguous.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two ABCs may be reported in either order.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on R wins over one that R only inherits
        # implicitly through defaultdict.
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # A registered ABC is placed in the composed MRO relative to the
        # explicit bases, so it can outrank a plain base class (S here).
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001454
1455 def test_cache_invalidation(self):
1456 from collections import UserDict
1457 class TracingDict(UserDict):
1458 def __init__(self, *args, **kwargs):
1459 super(TracingDict, self).__init__(*args, **kwargs)
1460 self.set_ops = []
1461 self.get_ops = []
1462 def __getitem__(self, key):
1463 result = self.data[key]
1464 self.get_ops.append(key)
1465 return result
1466 def __setitem__(self, key, value):
1467 self.set_ops.append(key)
1468 self.data[key] = value
1469 def clear(self):
1470 self.data.clear()
1471 _orig_wkd = functools.WeakKeyDictionary
1472 td = TracingDict()
1473 functools.WeakKeyDictionary = lambda: td
1474 c = collections
1475 @functools.singledispatch
1476 def g(arg):
1477 return "base"
1478 d = {}
1479 l = []
1480 self.assertEqual(len(td), 0)
1481 self.assertEqual(g(d), "base")
1482 self.assertEqual(len(td), 1)
1483 self.assertEqual(td.get_ops, [])
1484 self.assertEqual(td.set_ops, [dict])
1485 self.assertEqual(td.data[dict], g.registry[object])
1486 self.assertEqual(g(l), "base")
1487 self.assertEqual(len(td), 2)
1488 self.assertEqual(td.get_ops, [])
1489 self.assertEqual(td.set_ops, [dict, list])
1490 self.assertEqual(td.data[dict], g.registry[object])
1491 self.assertEqual(td.data[list], g.registry[object])
1492 self.assertEqual(td.data[dict], td.data[list])
1493 self.assertEqual(g(l), "base")
1494 self.assertEqual(g(d), "base")
1495 self.assertEqual(td.get_ops, [list, dict])
1496 self.assertEqual(td.set_ops, [dict, list])
1497 g.register(list, lambda arg: "list")
1498 self.assertEqual(td.get_ops, [list, dict])
1499 self.assertEqual(len(td), 0)
1500 self.assertEqual(g(d), "base")
1501 self.assertEqual(len(td), 1)
1502 self.assertEqual(td.get_ops, [list, dict])
1503 self.assertEqual(td.set_ops, [dict, list, dict])
1504 self.assertEqual(td.data[dict],
1505 functools._find_impl(dict, g.registry))
1506 self.assertEqual(g(l), "list")
1507 self.assertEqual(len(td), 2)
1508 self.assertEqual(td.get_ops, [list, dict])
1509 self.assertEqual(td.set_ops, [dict, list, dict, list])
1510 self.assertEqual(td.data[list],
1511 functools._find_impl(list, g.registry))
1512 class X:
1513 pass
1514 c.MutableMapping.register(X) # Will not invalidate the cache,
1515 # not using ABCs yet.
1516 self.assertEqual(g(d), "base")
1517 self.assertEqual(g(l), "list")
1518 self.assertEqual(td.get_ops, [list, dict, dict, list])
1519 self.assertEqual(td.set_ops, [dict, list, dict, list])
1520 g.register(c.Sized, lambda arg: "sized")
1521 self.assertEqual(len(td), 0)
1522 self.assertEqual(g(d), "sized")
1523 self.assertEqual(len(td), 1)
1524 self.assertEqual(td.get_ops, [list, dict, dict, list])
1525 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1526 self.assertEqual(g(l), "list")
1527 self.assertEqual(len(td), 2)
1528 self.assertEqual(td.get_ops, [list, dict, dict, list])
1529 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1530 self.assertEqual(g(l), "list")
1531 self.assertEqual(g(d), "sized")
1532 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1533 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1534 g.dispatch(list)
1535 g.dispatch(dict)
1536 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1537 list, dict])
1538 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1539 c.MutableSet.register(X) # Will invalidate the cache.
1540 self.assertEqual(len(td), 2) # Stale cache.
1541 self.assertEqual(g(l), "list")
1542 self.assertEqual(len(td), 1)
1543 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1544 self.assertEqual(len(td), 0)
1545 self.assertEqual(g(d), "mutablemapping")
1546 self.assertEqual(len(td), 1)
1547 self.assertEqual(g(l), "list")
1548 self.assertEqual(len(td), 2)
1549 g.register(dict, lambda arg: "dict")
1550 self.assertEqual(g(d), "dict")
1551 self.assertEqual(g(l), "list")
1552 g._clear_cache()
1553 self.assertEqual(len(td), 0)
1554 functools.WeakKeyDictionary = _orig_wkd
1555
1556
def test_main(verbose=None):
    """Run every test class of this module through ``support.run_unittest``.

    When *verbose* is true and ``sys.gettotalrefcount`` is available
    (CPython debug builds), the whole suite is run five more times and the
    total reference count after each run is printed.  A count that keeps
    growing across runs points at a reference leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            # Collect cycles first so the count reflects only live objects.
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001583
if __name__ == '__main__':
    # Script entry point: run the whole suite verbosely, which also enables
    # the repeated refcount check when sys.gettotalrefcount is available.
    test_main(True)