blob: 36f154a7f568c7765ba3e1c19492335c99c62777 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
Antoine Pitroub5b37142012-11-13 21:35:40 +010013py_functools = support.import_fresh_module('functools', blocked=['_functools'])
14c_functools = support.import_fresh_module('functools', fresh=['_functools'])
15
Łukasz Langa6f692512013-06-05 12:20:24 +020016decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo the call back: return the ``(args, kw)`` pair exactly as received."""
    received = (args, kw)
    return received
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Summarize a partial object as a comparable tuple.

    The tuple is ``(func, args, keywords, __dict__)`` read from *part*,
    so two partials compare equal when all four pieces are equal.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    # Shared partial() tests.  This is a mixin: concrete subclasses also
    # derive from unittest.TestCase and set a ``partial`` class attribute to
    # the implementation under test.

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertions so a failure reports the actual values.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000135
Łukasz Langa6f692512013-06-05 12:20:24 +0200136
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    # Run the shared partial() tests against the C implementation, plus
    # checks specific to it: read-only attributes, repr, pickling and
    # __setstate__ validation.  The guard keeps the class body importable
    # when c_functools is unavailable (the skip decorator then applies).
    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept both.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        # A 4-item sequence whose items are only produced on indexing; it is
        # not a tuple, which __setstate__ must reject.
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000211
Łukasz Langa6f692512013-06-05 12:20:24 +0200212
class TestPartialPy(TestPartial, unittest.TestCase):
    # Run the shared partial() tests against the pure-Python functools.
    # staticmethod() keeps self.partial(...) from binding the test case as
    # the first argument of the Python-level partial.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000215
Łukasz Langa6f692512013-06-05 12:20:24 +0200216
if c_functools:
    # A trivial subclass of the C partial type.  It is defined at module
    # level so that pickle (exercised by the inherited test_pickle) can
    # locate it by name, and only when the C module is available.
    class PartialSubclass(c_functools.partial):
        pass


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-run every TestPartialC test with the subclass as the implementation.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000226
Łukasz Langa6f692512013-06-05 12:20:24 +0200227
class TestPartialMethod(unittest.TestCase):
    # Tests for functools.partialmethod used as a class attribute, wrapping
    # plain functions, other partialmethods, partials and static/class methods.

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be *created by* ABCMeta (metaclass=...), not be a
        # subclass of it; otherwise it is itself a metaclass and ABC
        # bookkeeping never runs for its methods.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # partialmethod propagates abstractness, so ABCMeta records both
        # names and refuses to instantiate the class.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        self.assertRaises(TypeError, Abstract)

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
340
341
class TestUpdateWrapper(unittest.TestCase):
    # Tests for functools.update_wrapper(); TestWraps reuses them via wraps().

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Assigned attributes are copied by reference: the very same objects.
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        # Updated attributes are mappings merged key by key.
        for attr in updated:
            merged = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                # A __wrapped__ entry in the wrapped __dict__ gets overwritten
                # by update_wrapper, so it is deliberately not compared.
                skip = attr == "__dict__" and key == "__wrapped__"
                if not skip:
                    self.assertIs(source[key], merged[key])
        # Finally, the wrapper must point back at what it wraps.
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Fixture: a documented, annotated function with an extra attribute
        # and a bogus __wrapped__, wrapped with default settings.
        # Returns (wrapper, wrapped).
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, assigned=(), updated=())
        self.check_wrapper(wrapper, f, assigned=(), updated=())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Attributes absent from the wrapped object are skipped silently ...
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # ... but the wrapper must itself provide every `updated` attribute;
        # both a missing one and a non-mapping one (an int) raise.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000454
Łukasz Langa6f692512013-06-05 12:20:24 +0200455
class TestWraps(TestUpdateWrapper):
    # Re-run the update_wrapper() checks through the functools.wraps()
    # decorator; inherited tests that are not overridden run unchanged.

    def _default_update(self):
        # Same fixture as the base class, built with @functools.wraps and
        # without annotations.  Returns (wrapper, wrapped).
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        decorate = functools.wraps(f)
        @decorate
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        @functools.wraps(f, assigned=(), updated=())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assigned=(), updated=())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Decorator giving the wrapper an empty mapping to be updated into.
        def with_empty_dict_attr(func):
            func.dict_attr = {}
            return func
        assign, update = ('attr',), ('dict_attr',)
        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
516
Łukasz Langa6f692512013-06-05 12:20:24 +0200517
class TestReduce(unittest.TestCase):
    # Tests for functools.reduce().  staticmethod() guarantees self.func(...)
    # never receives the test case as an extra first argument, whether reduce
    # is a builtin (not a descriptor) or a plain Python function -- the same
    # convention TestPartialPy and TestCmpToKeyPy use.
    func = staticmethod(functools.reduce)

    def test_reduce(self):
        # A lazily-filled sequence of squares; only supports indexing, so
        # reduce() has to drive it through the old sequence protocol.
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2, 21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1")  # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1")  # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, [])  # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
597
Łukasz Langa6f692512013-06-05 12:20:24 +0200598
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200599class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700600
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000601 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700602 def cmp1(x, y):
603 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100604 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700605 self.assertEqual(key(3), key(3))
606 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100607 self.assertGreaterEqual(key(3), key(3))
608
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700609 def cmp2(x, y):
610 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100611 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700612 self.assertEqual(key(4.0), key('4'))
613 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100614 self.assertLessEqual(key(2), key('35'))
615 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700616
617 def test_cmp_to_key_arguments(self):
618 def cmp1(x, y):
619 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100620 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700621 self.assertEqual(key(obj=3), key(obj=3))
622 self.assertGreater(key(obj=3), key(obj=1))
623 with self.assertRaises((TypeError, AttributeError)):
624 key(3) > 1 # rhs is not a K object
625 with self.assertRaises((TypeError, AttributeError)):
626 1 < key(3) # lhs is not a K object
627 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100628 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700629 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200630 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100631 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700632 with self.assertRaises(TypeError):
633 key() # too few args
634 with self.assertRaises(TypeError):
635 key(None, None) # too many args
636
637 def test_bad_cmp(self):
638 def cmp1(x, y):
639 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100640 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700641 with self.assertRaises(ZeroDivisionError):
642 key(3) > key(1)
643
644 class BadCmp:
645 def __lt__(self, other):
646 raise ZeroDivisionError
647 def cmp1(x, y):
648 return BadCmp()
649 with self.assertRaises(ZeroDivisionError):
650 key(3) > key(1)
651
652 def test_obj_field(self):
653 def cmp1(x, y):
654 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100655 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700656 self.assertEqual(key(50).obj, 50)
657
658 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000659 def mycmp(x, y):
660 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100661 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000662 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000663
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700664 def test_sort_int_str(self):
665 def mycmp(x, y):
666 x, y = int(x), int(y)
667 return (x > y) - (x < y)
668 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100669 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700670 self.assertEqual([int(value) for value in values],
671 [0, 1, 1, 2, 3, 4, 5, 7, 10])
672
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000673 def test_hash(self):
674 def mycmp(x, y):
675 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100676 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000677 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700678 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700679 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000680
Łukasz Langa6f692512013-06-05 12:20:24 +0200681
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the C implementation."""

    # Guarded so the class body still evaluates when c_functools is missing;
    # the whole class is skipped in that case.  staticmethod() is a no-op
    # for a builtin, but mirrors the pure-Python variant below.
    if c_functools:
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100686
Łukasz Langa6f692512013-06-05 12:20:24 +0200687
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the pure-Python version."""

    # staticmethod stops the plain Python function from being bound as a
    # method, so self.cmp_to_key(f) passes f as the comparison function.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
690
Łukasz Langa6f692512013-06-05 12:20:24 +0200691
class TestTotalOrdering(unittest.TestCase):
    """Tests for functools.total_ordering."""

    def _check_full_ordering(self, cls):
        """Assert that all four ordering operators of *cls* are consistent.

        *cls* is called with the ints 1 and 2 and must order its instances
        like the wrapped values.  Every operator -- whether the class defines
        it or total_ordering derives it -- is checked for both a True and a
        False outcome on an ordered pair, and on a pair of equal values.
        """
        low, high, high_again = cls(1), cls(2), cls(2)
        # Strictly ordered pair: every operator, in both directions.
        self.assertTrue(low < high)
        self.assertFalse(high < low)
        self.assertTrue(high > low)
        self.assertFalse(low > high)
        self.assertTrue(low <= high)
        self.assertFalse(high <= low)
        self.assertTrue(high >= low)
        self.assertFalse(low >= high)
        # Equal values: only the non-strict operators hold.
        self.assertFalse(high < high_again)
        self.assertFalse(high > high_again)
        self.assertTrue(high <= high_again)
        self.assertTrue(high >= high_again)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_full_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_full_ordering(A)
        # int already supplies every comparison, so the decorator must not
        # have attached any of its own to the class.
        for name in ('__lt__', '__gt__', '__le__', '__ge__'):
            self.assertNotIn(name, vars(A))

    def test_no_operations_defined(self):
        # A class with no ordering method at all is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        # The root comparison always returns NotImplemented.
        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        # Mixed-type comparisons must end in TypeError, not recursion.
        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        # Equal objects must not short-circuit past a NotImplemented root.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200886
Georg Brandl2e7346a2010-07-31 18:09:23 +0000887class TestLRU(unittest.TestCase):
888
889 def test_lru(self):
890 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100891 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000892 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000893 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000894 self.assertEqual(maxsize, 20)
895 self.assertEqual(currsize, 0)
896 self.assertEqual(hits, 0)
897 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000898
899 domain = range(5)
900 for i in range(1000):
901 x, y = choice(domain), choice(domain)
902 actual = f(x, y)
903 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000904 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000905 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000906 self.assertTrue(hits > misses)
907 self.assertEqual(hits + misses, 1000)
908 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000909
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000910 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000911 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000912 self.assertEqual(hits, 0)
913 self.assertEqual(misses, 0)
914 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000915 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000916 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000917 self.assertEqual(hits, 0)
918 self.assertEqual(misses, 1)
919 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000920
Nick Coghlan98876832010-08-17 06:17:18 +0000921 # Test bypassing the cache
922 self.assertIs(f.__wrapped__, orig)
923 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000924 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000925 self.assertEqual(hits, 0)
926 self.assertEqual(misses, 1)
927 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000928
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000929 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000930 @functools.lru_cache(0)
931 def f():
932 nonlocal f_cnt
933 f_cnt += 1
934 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000935 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000936 f_cnt = 0
937 for i in range(5):
938 self.assertEqual(f(), 20)
939 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000940 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000941 self.assertEqual(hits, 0)
942 self.assertEqual(misses, 5)
943 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000944
945 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000946 @functools.lru_cache(1)
947 def f():
948 nonlocal f_cnt
949 f_cnt += 1
950 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000951 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000952 f_cnt = 0
953 for i in range(5):
954 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000955 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000956 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000957 self.assertEqual(hits, 4)
958 self.assertEqual(misses, 1)
959 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000960
Raymond Hettingerf3098282010-08-15 03:30:45 +0000961 # test size two
962 @functools.lru_cache(2)
963 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000964 nonlocal f_cnt
965 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000966 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000967 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000968 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000969 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
970 # * * * *
971 self.assertEqual(f(x), x*10)
972 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000973 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000974 self.assertEqual(hits, 12)
975 self.assertEqual(misses, 4)
976 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000977
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000978 def test_lru_with_maxsize_none(self):
979 @functools.lru_cache(maxsize=None)
980 def fib(n):
981 if n < 2:
982 return n
983 return fib(n-1) + fib(n-2)
984 self.assertEqual([fib(n) for n in range(16)],
985 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
986 self.assertEqual(fib.cache_info(),
987 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
988 fib.cache_clear()
989 self.assertEqual(fib.cache_info(),
990 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
991
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700992 def test_lru_with_exceptions(self):
993 # Verify that user_function exceptions get passed through without
994 # creating a hard-to-read chained exception.
995 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +0100996 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700997 @functools.lru_cache(maxsize)
998 def func(i):
999 return 'abc'[i]
1000 self.assertEqual(func(0), 'a')
1001 with self.assertRaises(IndexError) as cm:
1002 func(15)
1003 self.assertIsNone(cm.exception.__context__)
1004 # Verify that the previous exception did not result in a cached entry
1005 with self.assertRaises(IndexError):
1006 func(15)
1007
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001008 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001009 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001010 @functools.lru_cache(maxsize=maxsize, typed=True)
1011 def square(x):
1012 return x * x
1013 self.assertEqual(square(3), 9)
1014 self.assertEqual(type(square(3)), type(9))
1015 self.assertEqual(square(3.0), 9.0)
1016 self.assertEqual(type(square(3.0)), type(9.0))
1017 self.assertEqual(square(x=3), 9)
1018 self.assertEqual(type(square(x=3)), type(9))
1019 self.assertEqual(square(x=3.0), 9.0)
1020 self.assertEqual(type(square(x=3.0)), type(9.0))
1021 self.assertEqual(square.cache_info().hits, 4)
1022 self.assertEqual(square.cache_info().misses, 4)
1023
Antoine Pitroub5b37142012-11-13 21:35:40 +01001024 def test_lru_with_keyword_args(self):
1025 @functools.lru_cache()
1026 def fib(n):
1027 if n < 2:
1028 return n
1029 return fib(n=n-1) + fib(n=n-2)
1030 self.assertEqual(
1031 [fib(n=number) for number in range(16)],
1032 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1033 )
1034 self.assertEqual(fib.cache_info(),
1035 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1036 fib.cache_clear()
1037 self.assertEqual(fib.cache_info(),
1038 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1039
1040 def test_lru_with_keyword_args_maxsize_none(self):
1041 @functools.lru_cache(maxsize=None)
1042 def fib(n):
1043 if n < 2:
1044 return n
1045 return fib(n=n-1) + fib(n=n-2)
1046 self.assertEqual([fib(n=number) for number in range(16)],
1047 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1048 self.assertEqual(fib.cache_info(),
1049 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1050 fib.cache_clear()
1051 self.assertEqual(fib.cache_info(),
1052 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1053
Raymond Hettinger03923422013-03-04 02:52:50 -05001054 def test_need_for_rlock(self):
1055 # This will deadlock on an LRU cache that uses a regular lock
1056
1057 @functools.lru_cache(maxsize=10)
1058 def test_func(x):
1059 'Used to demonstrate a reentrant lru_cache call within a single thread'
1060 return x
1061
1062 class DoubleEq:
1063 'Demonstrate a reentrant lru_cache call within a single thread'
1064 def __init__(self, x):
1065 self.x = x
1066 def __hash__(self):
1067 return self.x
1068 def __eq__(self, other):
1069 if self.x == 2:
1070 test_func(DoubleEq(1))
1071 return self.x == other.x
1072
1073 test_func(DoubleEq(1)) # Load the cache
1074 test_func(DoubleEq(2)) # Load the cache
1075 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1076 DoubleEq(2)) # Verify the correct return value
1077
Raymond Hettinger4d588972014-08-12 12:44:52 -07001078 def test_early_detection_of_bad_call(self):
1079 # Issue #22184
1080 with self.assertRaises(TypeError):
1081 @functools.lru_cache
1082 def f():
1083 pass
1084
Raymond Hettinger03923422013-03-04 02:52:50 -05001085
Łukasz Langa6f692512013-06-05 12:20:24 +02001086class TestSingleDispatch(unittest.TestCase):
1087 def test_simple_overloads(self):
1088 @functools.singledispatch
1089 def g(obj):
1090 return "base"
1091 def g_int(i):
1092 return "integer"
1093 g.register(int, g_int)
1094 self.assertEqual(g("str"), "base")
1095 self.assertEqual(g(1), "integer")
1096 self.assertEqual(g([1,2,3]), "base")
1097
1098 def test_mro(self):
1099 @functools.singledispatch
1100 def g(obj):
1101 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001102 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001103 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001104 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001105 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001106 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001107 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001108 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001109 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001110 def g_A(a):
1111 return "A"
1112 def g_B(b):
1113 return "B"
1114 g.register(A, g_A)
1115 g.register(B, g_B)
1116 self.assertEqual(g(A()), "A")
1117 self.assertEqual(g(B()), "B")
1118 self.assertEqual(g(C()), "A")
1119 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001120
1121 def test_register_decorator(self):
1122 @functools.singledispatch
1123 def g(obj):
1124 return "base"
1125 @g.register(int)
1126 def g_int(i):
1127 return "int %s" % (i,)
1128 self.assertEqual(g(""), "base")
1129 self.assertEqual(g(12), "int 12")
1130 self.assertIs(g.dispatch(int), g_int)
1131 self.assertIs(g.dispatch(object), g.dispatch(str))
1132 # Note: in the assert above this is not g.
1133 # @singledispatch returns the wrapper.
1134
1135 def test_wrapping_attributes(self):
1136 @functools.singledispatch
1137 def g(obj):
1138 "Simple test"
1139 return "Test"
1140 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001141 if sys.flags.optimize < 2:
1142 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001143
1144 @unittest.skipUnless(decimal, 'requires _decimal')
1145 @support.cpython_only
1146 def test_c_classes(self):
1147 @functools.singledispatch
1148 def g(obj):
1149 return "base"
1150 @g.register(decimal.DecimalException)
1151 def _(obj):
1152 return obj.args
1153 subn = decimal.Subnormal("Exponent < Emin")
1154 rnd = decimal.Rounded("Number got rounded")
1155 self.assertEqual(g(subn), ("Exponent < Emin",))
1156 self.assertEqual(g(rnd), ("Number got rounded",))
1157 @g.register(decimal.Subnormal)
1158 def _(obj):
1159 return "Too small to care."
1160 self.assertEqual(g(subn), "Too small to care.")
1161 self.assertEqual(g(rnd), ("Number got rounded",))
1162
1163 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001164 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001165 c = collections
1166 mro = functools._compose_mro
1167 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1168 for haystack in permutations(bases):
1169 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001170 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1171 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001172 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1173 for haystack in permutations(bases):
1174 m = mro(c.ChainMap, haystack)
1175 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1176 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001177
1178 # If there's a generic function with implementations registered for
1179 # both Sized and Container, passing a defaultdict to it results in an
1180 # ambiguous dispatch which will cause a RuntimeError (see
1181 # test_mro_conflicts).
1182 bases = [c.Container, c.Sized, str]
1183 for haystack in permutations(bases):
1184 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1185 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1186 object])
1187
1188 # MutableSequence below is registered directly on D. In other words, it
1189 # preceeds MutableMapping which means single dispatch will always
1190 # choose MutableSequence here.
1191 class D(c.defaultdict):
1192 pass
1193 c.MutableSequence.register(D)
1194 bases = [c.MutableSequence, c.MutableMapping]
1195 for haystack in permutations(bases):
1196 m = mro(D, bases)
1197 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1198 c.defaultdict, dict, c.MutableMapping,
1199 c.Mapping, c.Sized, c.Iterable, c.Container,
1200 object])
1201
1202 # Container and Callable are registered on different base classes and
1203 # a generic function supporting both should always pick the Callable
1204 # implementation if a C instance is passed.
1205 class C(c.defaultdict):
1206 def __call__(self):
1207 pass
1208 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1209 for haystack in permutations(bases):
1210 m = mro(C, haystack)
1211 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1212 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001213
1214 def test_register_abc(self):
1215 c = collections
1216 d = {"a": "b"}
1217 l = [1, 2, 3]
1218 s = {object(), None}
1219 f = frozenset(s)
1220 t = (1, 2, 3)
1221 @functools.singledispatch
1222 def g(obj):
1223 return "base"
1224 self.assertEqual(g(d), "base")
1225 self.assertEqual(g(l), "base")
1226 self.assertEqual(g(s), "base")
1227 self.assertEqual(g(f), "base")
1228 self.assertEqual(g(t), "base")
1229 g.register(c.Sized, lambda obj: "sized")
1230 self.assertEqual(g(d), "sized")
1231 self.assertEqual(g(l), "sized")
1232 self.assertEqual(g(s), "sized")
1233 self.assertEqual(g(f), "sized")
1234 self.assertEqual(g(t), "sized")
1235 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1236 self.assertEqual(g(d), "mutablemapping")
1237 self.assertEqual(g(l), "sized")
1238 self.assertEqual(g(s), "sized")
1239 self.assertEqual(g(f), "sized")
1240 self.assertEqual(g(t), "sized")
1241 g.register(c.ChainMap, lambda obj: "chainmap")
1242 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1243 self.assertEqual(g(l), "sized")
1244 self.assertEqual(g(s), "sized")
1245 self.assertEqual(g(f), "sized")
1246 self.assertEqual(g(t), "sized")
1247 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1248 self.assertEqual(g(d), "mutablemapping")
1249 self.assertEqual(g(l), "mutablesequence")
1250 self.assertEqual(g(s), "sized")
1251 self.assertEqual(g(f), "sized")
1252 self.assertEqual(g(t), "sized")
1253 g.register(c.MutableSet, lambda obj: "mutableset")
1254 self.assertEqual(g(d), "mutablemapping")
1255 self.assertEqual(g(l), "mutablesequence")
1256 self.assertEqual(g(s), "mutableset")
1257 self.assertEqual(g(f), "sized")
1258 self.assertEqual(g(t), "sized")
1259 g.register(c.Mapping, lambda obj: "mapping")
1260 self.assertEqual(g(d), "mutablemapping") # not specific enough
1261 self.assertEqual(g(l), "mutablesequence")
1262 self.assertEqual(g(s), "mutableset")
1263 self.assertEqual(g(f), "sized")
1264 self.assertEqual(g(t), "sized")
1265 g.register(c.Sequence, lambda obj: "sequence")
1266 self.assertEqual(g(d), "mutablemapping")
1267 self.assertEqual(g(l), "mutablesequence")
1268 self.assertEqual(g(s), "mutableset")
1269 self.assertEqual(g(f), "sized")
1270 self.assertEqual(g(t), "sequence")
1271 g.register(c.Set, lambda obj: "set")
1272 self.assertEqual(g(d), "mutablemapping")
1273 self.assertEqual(g(l), "mutablesequence")
1274 self.assertEqual(g(s), "mutableset")
1275 self.assertEqual(g(f), "set")
1276 self.assertEqual(g(t), "sequence")
1277 g.register(dict, lambda obj: "dict")
1278 self.assertEqual(g(d), "dict")
1279 self.assertEqual(g(l), "mutablesequence")
1280 self.assertEqual(g(s), "mutableset")
1281 self.assertEqual(g(f), "set")
1282 self.assertEqual(g(t), "sequence")
1283 g.register(list, lambda obj: "list")
1284 self.assertEqual(g(d), "dict")
1285 self.assertEqual(g(l), "list")
1286 self.assertEqual(g(s), "mutableset")
1287 self.assertEqual(g(f), "set")
1288 self.assertEqual(g(t), "sequence")
1289 g.register(set, lambda obj: "concrete-set")
1290 self.assertEqual(g(d), "dict")
1291 self.assertEqual(g(l), "list")
1292 self.assertEqual(g(s), "concrete-set")
1293 self.assertEqual(g(f), "set")
1294 self.assertEqual(g(t), "sequence")
1295 g.register(frozenset, lambda obj: "frozen-set")
1296 self.assertEqual(g(d), "dict")
1297 self.assertEqual(g(l), "list")
1298 self.assertEqual(g(s), "concrete-set")
1299 self.assertEqual(g(f), "frozen-set")
1300 self.assertEqual(g(t), "sequence")
1301 g.register(tuple, lambda obj: "tuple")
1302 self.assertEqual(g(d), "dict")
1303 self.assertEqual(g(l), "list")
1304 self.assertEqual(g(s), "concrete-set")
1305 self.assertEqual(g(f), "frozen-set")
1306 self.assertEqual(g(t), "tuple")
1307
Łukasz Langa3720c772013-07-01 16:00:38 +02001308 def test_c3_abc(self):
1309 c = collections
1310 mro = functools._c3_mro
1311 class A(object):
1312 pass
1313 class B(A):
1314 def __len__(self):
1315 return 0 # implies Sized
1316 @c.Container.register
1317 class C(object):
1318 pass
1319 class D(object):
1320 pass # unrelated
1321 class X(D, C, B):
1322 def __call__(self):
1323 pass # implies Callable
1324 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1325 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1326 self.assertEqual(mro(X, abcs=abcs), expected)
1327 # unrelated ABCs don't appear in the resulting MRO
1328 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1329 self.assertEqual(mro(X, abcs=many_abcs), expected)
1330
Łukasz Langa6f692512013-06-05 12:20:24 +02001331 def test_mro_conflicts(self):
1332 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001333 @functools.singledispatch
1334 def g(arg):
1335 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001336 class O(c.Sized):
1337 def __len__(self):
1338 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001339 o = O()
1340 self.assertEqual(g(o), "base")
1341 g.register(c.Iterable, lambda arg: "iterable")
1342 g.register(c.Container, lambda arg: "container")
1343 g.register(c.Sized, lambda arg: "sized")
1344 g.register(c.Set, lambda arg: "set")
1345 self.assertEqual(g(o), "sized")
1346 c.Iterable.register(O)
1347 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1348 c.Container.register(O)
1349 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001350 c.Set.register(O)
1351 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1352 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001353 class P:
1354 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001355 p = P()
1356 self.assertEqual(g(p), "base")
1357 c.Iterable.register(P)
1358 self.assertEqual(g(p), "iterable")
1359 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001360 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001361 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001362 self.assertIn(
1363 str(re_one.exception),
1364 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1365 "or <class 'collections.abc.Iterable'>"),
1366 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1367 "or <class 'collections.abc.Container'>")),
1368 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001369 class Q(c.Sized):
1370 def __len__(self):
1371 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001372 q = Q()
1373 self.assertEqual(g(q), "sized")
1374 c.Iterable.register(Q)
1375 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1376 c.Set.register(Q)
1377 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001378 # c.Sized and c.Iterable
1379 @functools.singledispatch
1380 def h(arg):
1381 return "base"
1382 @h.register(c.Sized)
1383 def _(arg):
1384 return "sized"
1385 @h.register(c.Container)
1386 def _(arg):
1387 return "container"
1388 # Even though Sized and Container are explicit bases of MutableMapping,
1389 # this ABC is implicitly registered on defaultdict which makes all of
1390 # MutableMapping's bases implicit as well from defaultdict's
1391 # perspective.
1392 with self.assertRaises(RuntimeError) as re_two:
1393 h(c.defaultdict(lambda: 0))
1394 self.assertIn(
1395 str(re_two.exception),
1396 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1397 "or <class 'collections.abc.Sized'>"),
1398 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1399 "or <class 'collections.abc.Container'>")),
1400 )
1401 class R(c.defaultdict):
1402 pass
1403 c.MutableSequence.register(R)
1404 @functools.singledispatch
1405 def i(arg):
1406 return "base"
1407 @i.register(c.MutableMapping)
1408 def _(arg):
1409 return "mapping"
1410 @i.register(c.MutableSequence)
1411 def _(arg):
1412 return "sequence"
1413 r = R()
1414 self.assertEqual(i(r), "sequence")
1415 class S:
1416 pass
1417 class T(S, c.Sized):
1418 def __len__(self):
1419 return 0
1420 t = T()
1421 self.assertEqual(h(t), "sized")
1422 c.Container.register(T)
1423 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1424 class U:
1425 def __len__(self):
1426 return 0
1427 u = U()
1428 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1429 # from the existence of __len__()
1430 c.Container.register(U)
1431 # There is no preference for registered versus inferred ABCs.
1432 with self.assertRaises(RuntimeError) as re_three:
1433 h(u)
1434 self.assertIn(
1435 str(re_three.exception),
1436 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1437 "or <class 'collections.abc.Sized'>"),
1438 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1439 "or <class 'collections.abc.Container'>")),
1440 )
1441 class V(c.Sized, S):
1442 def __len__(self):
1443 return 0
1444 @functools.singledispatch
1445 def j(arg):
1446 return "base"
1447 @j.register(S)
1448 def _(arg):
1449 return "s"
1450 @j.register(c.Container)
1451 def _(arg):
1452 return "container"
1453 v = V()
1454 self.assertEqual(j(v), "s")
1455 c.Container.register(V)
1456 self.assertEqual(j(v), "container") # because it ends up right after
1457 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001458
1459 def test_cache_invalidation(self):
1460 from collections import UserDict
1461 class TracingDict(UserDict):
1462 def __init__(self, *args, **kwargs):
1463 super(TracingDict, self).__init__(*args, **kwargs)
1464 self.set_ops = []
1465 self.get_ops = []
1466 def __getitem__(self, key):
1467 result = self.data[key]
1468 self.get_ops.append(key)
1469 return result
1470 def __setitem__(self, key, value):
1471 self.set_ops.append(key)
1472 self.data[key] = value
1473 def clear(self):
1474 self.data.clear()
1475 _orig_wkd = functools.WeakKeyDictionary
1476 td = TracingDict()
1477 functools.WeakKeyDictionary = lambda: td
1478 c = collections
1479 @functools.singledispatch
1480 def g(arg):
1481 return "base"
1482 d = {}
1483 l = []
1484 self.assertEqual(len(td), 0)
1485 self.assertEqual(g(d), "base")
1486 self.assertEqual(len(td), 1)
1487 self.assertEqual(td.get_ops, [])
1488 self.assertEqual(td.set_ops, [dict])
1489 self.assertEqual(td.data[dict], g.registry[object])
1490 self.assertEqual(g(l), "base")
1491 self.assertEqual(len(td), 2)
1492 self.assertEqual(td.get_ops, [])
1493 self.assertEqual(td.set_ops, [dict, list])
1494 self.assertEqual(td.data[dict], g.registry[object])
1495 self.assertEqual(td.data[list], g.registry[object])
1496 self.assertEqual(td.data[dict], td.data[list])
1497 self.assertEqual(g(l), "base")
1498 self.assertEqual(g(d), "base")
1499 self.assertEqual(td.get_ops, [list, dict])
1500 self.assertEqual(td.set_ops, [dict, list])
1501 g.register(list, lambda arg: "list")
1502 self.assertEqual(td.get_ops, [list, dict])
1503 self.assertEqual(len(td), 0)
1504 self.assertEqual(g(d), "base")
1505 self.assertEqual(len(td), 1)
1506 self.assertEqual(td.get_ops, [list, dict])
1507 self.assertEqual(td.set_ops, [dict, list, dict])
1508 self.assertEqual(td.data[dict],
1509 functools._find_impl(dict, g.registry))
1510 self.assertEqual(g(l), "list")
1511 self.assertEqual(len(td), 2)
1512 self.assertEqual(td.get_ops, [list, dict])
1513 self.assertEqual(td.set_ops, [dict, list, dict, list])
1514 self.assertEqual(td.data[list],
1515 functools._find_impl(list, g.registry))
1516 class X:
1517 pass
1518 c.MutableMapping.register(X) # Will not invalidate the cache,
1519 # not using ABCs yet.
1520 self.assertEqual(g(d), "base")
1521 self.assertEqual(g(l), "list")
1522 self.assertEqual(td.get_ops, [list, dict, dict, list])
1523 self.assertEqual(td.set_ops, [dict, list, dict, list])
1524 g.register(c.Sized, lambda arg: "sized")
1525 self.assertEqual(len(td), 0)
1526 self.assertEqual(g(d), "sized")
1527 self.assertEqual(len(td), 1)
1528 self.assertEqual(td.get_ops, [list, dict, dict, list])
1529 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1530 self.assertEqual(g(l), "list")
1531 self.assertEqual(len(td), 2)
1532 self.assertEqual(td.get_ops, [list, dict, dict, list])
1533 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1534 self.assertEqual(g(l), "list")
1535 self.assertEqual(g(d), "sized")
1536 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1537 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1538 g.dispatch(list)
1539 g.dispatch(dict)
1540 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1541 list, dict])
1542 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1543 c.MutableSet.register(X) # Will invalidate the cache.
1544 self.assertEqual(len(td), 2) # Stale cache.
1545 self.assertEqual(g(l), "list")
1546 self.assertEqual(len(td), 1)
1547 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1548 self.assertEqual(len(td), 0)
1549 self.assertEqual(g(d), "mutablemapping")
1550 self.assertEqual(len(td), 1)
1551 self.assertEqual(g(l), "list")
1552 self.assertEqual(len(td), 2)
1553 g.register(dict, lambda arg: "dict")
1554 self.assertEqual(g(d), "dict")
1555 self.assertEqual(g(l), "list")
1556 g._clear_cache()
1557 self.assertEqual(len(td), 0)
1558 functools.WeakKeyDictionary = _orig_wkd
1559
1560
def test_main(verbose=None):
    """Run all test case classes of this module through test.support.

    If *verbose* is true and ``sys.gettotalrefcount`` exists (debug builds),
    the whole suite is run five more times. The total reference count after
    each run is printed, so a steadily growing sequence can reveal a leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            # Collect cyclic garbage first so it does not inflate the count.
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001587
if __name__ == "__main__":
    # Run as a script: execute the suite with verbose=True, which also
    # reports reference counts on interpreters exposing gettotalrefcount.
    test_main(verbose=True)