blob: c549ac4cc4255b9af91608108544c3d282e17b38 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Load two independent copies of functools so both implementations can be
# tested side by side: one with the _functools accelerator blocked (pure
# Python) and one that freshly imports _functools (C).  The C copy is
# guarded on truthiness wherever it is used, presumably because it can be
# None when _functools is unavailable -- confirm against test.support.
py_functools = support.import_fresh_module(
    'functools', blocked=['_functools'])
c_functools = support.import_fresh_module(
    'functools', fresh=['_functools'])

# A fresh decimal module backed by the C _decimal accelerator.
decimal = support.import_fresh_module(
    'decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo back whatever arguments were received.

    Returns an ``(args, kw)`` pair -- the positional arguments as a tuple
    and the keyword arguments as a dict -- so tests can inspect exactly
    what a wrapper forwarded.
    """
    received = (args, kw)
    return received
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Summarise a partial object's state as a comparable 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)``, so two partial
    objects can be compared for equivalent state with a single ``==``.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Behaviour that every partial implementation under test must share.

    This is a mixin: each concrete subclass also derives from
    unittest.TestCase and sets ``partial`` to the implementation to check.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected.  Both the
        # construction and the call sit inside the block, so either one
        # may be the step that raises.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # separate assertions report *which* part mismatched
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # dropping the only strong reference must invalidate the proxy
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        # Only "true" partial is optimized.  Report other implementations
        # (e.g. subclasses) as skipped rather than as a silent pass.
        if partial.__name__ != 'partial':
            self.skipTest('nested partial flattening applies only to '
                          'partial itself')
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))
143
Łukasz Langa6f692512013-06-05 12:20:24 +0200144
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus checks specific to the C implementation."""

    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # deleting the instance __dict__ must be refused with TypeError
        p = self.partial(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned down, so accept either
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        # extra instance attributes must survive the round trip as well
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            # subTest names the failing protocol instead of stopping at it
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000219
Łukasz Langa6f692512013-06-05 12:20:24 +0200220
class TestPartialPy(TestPartial, unittest.TestCase):
    # Runs the shared partial tests against the pure-Python implementation.
    # Wrapping in staticmethod means that, if the implementation is a plain
    # function, looking it up as self.partial does not bind it to the
    # test instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000223
Łukasz Langa6f692512013-06-05 12:20:24 +0200224
# A do-nothing subclass of the C partial type.  The shared tests are re-run
# against it to check that subclassing keeps partial's behaviour intact.
# It can only be defined when the C module was actually loaded.
if c_functools:
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100228
Łukasz Langa6f692512013-06-05 12:20:24 +0200229
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Repeat every C-partial test, this time through a subclass of the type.
    # PartialSubclass exists only when c_functools is truthy, so the
    # reference is guarded the same way its definition is.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000234
Łukasz Langa6f692512013-06-05 12:20:24 +0200235
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be *created by* ABCMeta (metaclass=...) for the ABC
        # machinery to run.  Subclassing ABCMeta, as this test used to,
        # only makes Abstract another metaclass and never exercises it.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must treat the partialmethod built from an abstract
        # method as abstract too.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
348
349
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries *wrapped*'s metadata.

        Every attribute named in *assigned* must be the very same object on
        both functions.  Every mapping named in *updated* must hold each of
        the wrapped function's entries, except ``__dict__['__wrapped__']``,
        which update_wrapper overwrites.  ``wrapper.__wrapped__`` must be
        *wrapped* itself.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for entry in source:
                skip = attr_name == "__dict__" and entry == "__wrapped__"
                if not skip:
                    self.assertIs(source[entry], target[entry])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # Return (wrapper, wrapped) after update_wrapper with default
        # arguments.  Each function has its own annotation and the wrapped
        # one carries a decoy __wrapped__ value.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # the wrapped function's annotations replace the wrapper's own
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        # with empty assigned/updated the wrapper keeps its own identity
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        to_assign, to_update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        to_assign, to_update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, to_assign, to_update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, to_assign, to_update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000462
Łukasz Langa6f692512013-06-05 12:20:24 +0200463
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks via the functools.wraps decorator."""

    def _default_update(self):
        # Return (wrapper, wrapped) where wrapper was decorated with
        # functools.wraps using default arguments.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        decorate = functools.wraps(f)
        def wrapper():
            pass
        wrapper = decorate(wrapper)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, assigned=(), updated=())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        # nothing copied: the wrapper keeps its own identity
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def with_empty_dict_attr(func):
            # give the wrapper the mapping that wraps() will update
            func.dict_attr = {}
            return func
        to_assign, to_update = ('attr',), ('dict_attr',)
        @functools.wraps(f, to_assign, to_update)
        @with_empty_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, to_assign, to_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
524
Łukasz Langa6f692512013-06-05 12:20:24 +0200525
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of squares 0, 1, 4, ... for indices below `limit`,
            # computed lazily and cached in `sofar` as items are requested.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, index):
                if index < 0 or index >= self.limit:
                    raise IndexError
                for n in range(len(self.sofar), index + 1):
                    self.sofar.append(n * n)
                return self.sofar[index]

        def plus(x, y):
            return x + y

        self.assertEqual(self.func(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(self.func(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a','c','d','w'])
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(self.func(lambda x, y: x*y, range(2,21), 1),
                         2432902008176640000)
        self.assertEqual(self.func(plus, Squares(10)), 285)
        self.assertEqual(self.func(plus, Squares(10), 0), 285)
        self.assertEqual(self.func(plus, Squares(0), 0), 0)

        # Bad arity, non-callable functions, non-iterables, and an empty
        # sequence without an initial value are all TypeErrors.
        bad_calls = [
            (),
            (42, 42),
            (42, 42, 42),
            (42, (42, 42)),
            (plus, []),
            (plus, ""),
            (plus, ()),
            (plus, object()),
        ]
        for bad_args in bad_calls:
            self.assertRaises(TypeError, self.func, *bad_args)

        # func is never called with one item
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")

        class ExplodingIterable:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, plus, ExplodingIterable())

        # an empty sequence just yields the initial value
        self.assertIsNone(self.func(plus, [], None))
        self.assertEqual(self.func(plus, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # old-style sequence protocol: 0 .. n-1, then IndexError
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d))
605
Łukasz Langa6f692512013-06-05 12:20:24 +0200606
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200607class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700608
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000609 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700610 def cmp1(x, y):
611 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100612 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700613 self.assertEqual(key(3), key(3))
614 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100615 self.assertGreaterEqual(key(3), key(3))
616
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700617 def cmp2(x, y):
618 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100619 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700620 self.assertEqual(key(4.0), key('4'))
621 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100622 self.assertLessEqual(key(2), key('35'))
623 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700624
625 def test_cmp_to_key_arguments(self):
626 def cmp1(x, y):
627 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100628 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700629 self.assertEqual(key(obj=3), key(obj=3))
630 self.assertGreater(key(obj=3), key(obj=1))
631 with self.assertRaises((TypeError, AttributeError)):
632 key(3) > 1 # rhs is not a K object
633 with self.assertRaises((TypeError, AttributeError)):
634 1 < key(3) # lhs is not a K object
635 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100636 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700637 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200638 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100639 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700640 with self.assertRaises(TypeError):
641 key() # too few args
642 with self.assertRaises(TypeError):
643 key(None, None) # too many args
644
645 def test_bad_cmp(self):
646 def cmp1(x, y):
647 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100648 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700649 with self.assertRaises(ZeroDivisionError):
650 key(3) > key(1)
651
652 class BadCmp:
653 def __lt__(self, other):
654 raise ZeroDivisionError
655 def cmp1(x, y):
656 return BadCmp()
657 with self.assertRaises(ZeroDivisionError):
658 key(3) > key(1)
659
660 def test_obj_field(self):
661 def cmp1(x, y):
662 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100663 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700664 self.assertEqual(key(50).obj, 50)
665
666 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000667 def mycmp(x, y):
668 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100669 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000670 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000671
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700672 def test_sort_int_str(self):
673 def mycmp(x, y):
674 x, y = int(x), int(y)
675 return (x > y) - (x < y)
676 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100677 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700678 self.assertEqual([int(value) for value in values],
679 [0, 1, 1, 2, 3, 4, 5, 7, 10])
680
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000681 def test_hash(self):
682 def mycmp(x, y):
683 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100684 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000685 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700686 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700687 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000688
Łukasz Langa6f692512013-06-05 12:20:24 +0200689
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C _functools module."""
    if c_functools:
        # Wrapped in staticmethod, like TestCmpToKeyPy, so that looking the
        # attribute up through ``self`` never binds the test case as an
        # argument, whatever kind of callable the implementation is.
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100694
Łukasz Langa6f692512013-06-05 12:20:24 +0200695
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python functools."""

    # py_functools.cmp_to_key is a plain Python function; staticmethod keeps
    # it from being bound as a method when accessed through ``self``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
698
Łukasz Langa6f692512013-06-05 12:20:24 +0200699
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000700class TestTotalOrdering(unittest.TestCase):
701
702 def test_total_ordering_lt(self):
703 @functools.total_ordering
704 class A:
705 def __init__(self, value):
706 self.value = value
707 def __lt__(self, other):
708 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000709 def __eq__(self, other):
710 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000711 self.assertTrue(A(1) < A(2))
712 self.assertTrue(A(2) > A(1))
713 self.assertTrue(A(1) <= A(2))
714 self.assertTrue(A(2) >= A(1))
715 self.assertTrue(A(2) <= A(2))
716 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000717 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000718
719 def test_total_ordering_le(self):
720 @functools.total_ordering
721 class A:
722 def __init__(self, value):
723 self.value = value
724 def __le__(self, other):
725 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000726 def __eq__(self, other):
727 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000728 self.assertTrue(A(1) < A(2))
729 self.assertTrue(A(2) > A(1))
730 self.assertTrue(A(1) <= A(2))
731 self.assertTrue(A(2) >= A(1))
732 self.assertTrue(A(2) <= A(2))
733 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000734 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000735
736 def test_total_ordering_gt(self):
737 @functools.total_ordering
738 class A:
739 def __init__(self, value):
740 self.value = value
741 def __gt__(self, other):
742 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000743 def __eq__(self, other):
744 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000745 self.assertTrue(A(1) < A(2))
746 self.assertTrue(A(2) > A(1))
747 self.assertTrue(A(1) <= A(2))
748 self.assertTrue(A(2) >= A(1))
749 self.assertTrue(A(2) <= A(2))
750 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000751 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000752
753 def test_total_ordering_ge(self):
754 @functools.total_ordering
755 class A:
756 def __init__(self, value):
757 self.value = value
758 def __ge__(self, other):
759 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000760 def __eq__(self, other):
761 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000762 self.assertTrue(A(1) < A(2))
763 self.assertTrue(A(2) > A(1))
764 self.assertTrue(A(1) <= A(2))
765 self.assertTrue(A(2) >= A(1))
766 self.assertTrue(A(2) <= A(2))
767 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000768 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000769
770 def test_total_ordering_no_overwrite(self):
771 # new methods should not overwrite existing
772 @functools.total_ordering
773 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000774 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000775 self.assertTrue(A(1) < A(2))
776 self.assertTrue(A(2) > A(1))
777 self.assertTrue(A(1) <= A(2))
778 self.assertTrue(A(2) >= A(1))
779 self.assertTrue(A(2) <= A(2))
780 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000781
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000782 def test_no_operations_defined(self):
783 with self.assertRaises(ValueError):
784 @functools.total_ordering
785 class A:
786 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000787
Nick Coghlanf05d9812013-10-02 00:02:03 +1000788 def test_type_error_when_not_implemented(self):
789 # bug 10042; ensure stack overflow does not occur
790 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000791 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000792 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000793 def __init__(self, value):
794 self.value = value
795 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000796 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000797 return self.value == other.value
798 return False
799 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000800 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000801 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000802 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000803
Nick Coghlanf05d9812013-10-02 00:02:03 +1000804 @functools.total_ordering
805 class ImplementsGreaterThan:
806 def __init__(self, value):
807 self.value = value
808 def __eq__(self, other):
809 if isinstance(other, ImplementsGreaterThan):
810 return self.value == other.value
811 return False
812 def __gt__(self, other):
813 if isinstance(other, ImplementsGreaterThan):
814 return self.value > other.value
815 return NotImplemented
816
817 @functools.total_ordering
818 class ImplementsLessThanEqualTo:
819 def __init__(self, value):
820 self.value = value
821 def __eq__(self, other):
822 if isinstance(other, ImplementsLessThanEqualTo):
823 return self.value == other.value
824 return False
825 def __le__(self, other):
826 if isinstance(other, ImplementsLessThanEqualTo):
827 return self.value <= other.value
828 return NotImplemented
829
830 @functools.total_ordering
831 class ImplementsGreaterThanEqualTo:
832 def __init__(self, value):
833 self.value = value
834 def __eq__(self, other):
835 if isinstance(other, ImplementsGreaterThanEqualTo):
836 return self.value == other.value
837 return False
838 def __ge__(self, other):
839 if isinstance(other, ImplementsGreaterThanEqualTo):
840 return self.value >= other.value
841 return NotImplemented
842
843 @functools.total_ordering
844 class ComparatorNotImplemented:
845 def __init__(self, value):
846 self.value = value
847 def __eq__(self, other):
848 if isinstance(other, ComparatorNotImplemented):
849 return self.value == other.value
850 return False
851 def __lt__(self, other):
852 return NotImplemented
853
854 with self.subTest("LT < 1"), self.assertRaises(TypeError):
855 ImplementsLessThan(-1) < 1
856
857 with self.subTest("LT < LE"), self.assertRaises(TypeError):
858 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
859
860 with self.subTest("LT < GT"), self.assertRaises(TypeError):
861 ImplementsLessThan(1) < ImplementsGreaterThan(1)
862
863 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
864 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
865
866 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
867 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
868
869 with self.subTest("GT > GE"), self.assertRaises(TypeError):
870 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
871
872 with self.subTest("GT > LT"), self.assertRaises(TypeError):
873 ImplementsGreaterThan(5) > ImplementsLessThan(5)
874
875 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
876 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
877
878 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
879 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
880
881 with self.subTest("GE when equal"):
882 a = ComparatorNotImplemented(8)
883 b = ComparatorNotImplemented(8)
884 self.assertEqual(a, b)
885 with self.assertRaises(TypeError):
886 a >= b
887
888 with self.subTest("LE when equal"):
889 a = ComparatorNotImplemented(9)
890 b = ComparatorNotImplemented(9)
891 self.assertEqual(a, b)
892 with self.assertRaises(TypeError):
893 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200894
Serhiy Storchaka697a5262015-01-01 15:23:12 +0200895 def test_pickle(self):
896 for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
897 for name in '__lt__', '__gt__', '__le__', '__ge__':
898 with self.subTest(method=name, proto=proto):
899 method = getattr(Orderable_LT, name)
900 method_copy = pickle.loads(pickle.dumps(method, proto))
901 self.assertIs(method_copy, method)
902
@functools.total_ordering
class Orderable_LT:
    """Value wrapper defining only ``==`` and ``<``.

    functools.total_ordering supplies the remaining comparisons.  The class
    lives at module level so TestTotalOrdering.test_pickle can pickle its
    comparison methods by reference.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        return self.value < other.value
911
912
Georg Brandl2e7346a2010-07-31 18:09:23 +0000913class TestLRU(unittest.TestCase):
914
915 def test_lru(self):
916 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100917 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000918 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000919 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000920 self.assertEqual(maxsize, 20)
921 self.assertEqual(currsize, 0)
922 self.assertEqual(hits, 0)
923 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000924
925 domain = range(5)
926 for i in range(1000):
927 x, y = choice(domain), choice(domain)
928 actual = f(x, y)
929 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000930 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000931 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000932 self.assertTrue(hits > misses)
933 self.assertEqual(hits + misses, 1000)
934 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000935
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000936 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000937 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000938 self.assertEqual(hits, 0)
939 self.assertEqual(misses, 0)
940 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000941 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000942 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000943 self.assertEqual(hits, 0)
944 self.assertEqual(misses, 1)
945 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000946
Nick Coghlan98876832010-08-17 06:17:18 +0000947 # Test bypassing the cache
948 self.assertIs(f.__wrapped__, orig)
949 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000950 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000951 self.assertEqual(hits, 0)
952 self.assertEqual(misses, 1)
953 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000954
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000955 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000956 @functools.lru_cache(0)
957 def f():
958 nonlocal f_cnt
959 f_cnt += 1
960 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000961 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000962 f_cnt = 0
963 for i in range(5):
964 self.assertEqual(f(), 20)
965 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000966 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000967 self.assertEqual(hits, 0)
968 self.assertEqual(misses, 5)
969 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000970
971 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000972 @functools.lru_cache(1)
973 def f():
974 nonlocal f_cnt
975 f_cnt += 1
976 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000977 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000978 f_cnt = 0
979 for i in range(5):
980 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000981 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000982 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000983 self.assertEqual(hits, 4)
984 self.assertEqual(misses, 1)
985 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000986
Raymond Hettingerf3098282010-08-15 03:30:45 +0000987 # test size two
988 @functools.lru_cache(2)
989 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000990 nonlocal f_cnt
991 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000992 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000993 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000994 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000995 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
996 # * * * *
997 self.assertEqual(f(x), x*10)
998 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000999 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001000 self.assertEqual(hits, 12)
1001 self.assertEqual(misses, 4)
1002 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001003
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001004 def test_lru_with_maxsize_none(self):
1005 @functools.lru_cache(maxsize=None)
1006 def fib(n):
1007 if n < 2:
1008 return n
1009 return fib(n-1) + fib(n-2)
1010 self.assertEqual([fib(n) for n in range(16)],
1011 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1012 self.assertEqual(fib.cache_info(),
1013 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1014 fib.cache_clear()
1015 self.assertEqual(fib.cache_info(),
1016 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1017
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001018 def test_lru_with_exceptions(self):
1019 # Verify that user_function exceptions get passed through without
1020 # creating a hard-to-read chained exception.
1021 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001022 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001023 @functools.lru_cache(maxsize)
1024 def func(i):
1025 return 'abc'[i]
1026 self.assertEqual(func(0), 'a')
1027 with self.assertRaises(IndexError) as cm:
1028 func(15)
1029 self.assertIsNone(cm.exception.__context__)
1030 # Verify that the previous exception did not result in a cached entry
1031 with self.assertRaises(IndexError):
1032 func(15)
1033
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001034 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001035 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001036 @functools.lru_cache(maxsize=maxsize, typed=True)
1037 def square(x):
1038 return x * x
1039 self.assertEqual(square(3), 9)
1040 self.assertEqual(type(square(3)), type(9))
1041 self.assertEqual(square(3.0), 9.0)
1042 self.assertEqual(type(square(3.0)), type(9.0))
1043 self.assertEqual(square(x=3), 9)
1044 self.assertEqual(type(square(x=3)), type(9))
1045 self.assertEqual(square(x=3.0), 9.0)
1046 self.assertEqual(type(square(x=3.0)), type(9.0))
1047 self.assertEqual(square.cache_info().hits, 4)
1048 self.assertEqual(square.cache_info().misses, 4)
1049
Antoine Pitroub5b37142012-11-13 21:35:40 +01001050 def test_lru_with_keyword_args(self):
1051 @functools.lru_cache()
1052 def fib(n):
1053 if n < 2:
1054 return n
1055 return fib(n=n-1) + fib(n=n-2)
1056 self.assertEqual(
1057 [fib(n=number) for number in range(16)],
1058 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1059 )
1060 self.assertEqual(fib.cache_info(),
1061 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1062 fib.cache_clear()
1063 self.assertEqual(fib.cache_info(),
1064 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1065
1066 def test_lru_with_keyword_args_maxsize_none(self):
1067 @functools.lru_cache(maxsize=None)
1068 def fib(n):
1069 if n < 2:
1070 return n
1071 return fib(n=n-1) + fib(n=n-2)
1072 self.assertEqual([fib(n=number) for number in range(16)],
1073 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1074 self.assertEqual(fib.cache_info(),
1075 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1076 fib.cache_clear()
1077 self.assertEqual(fib.cache_info(),
1078 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1079
Raymond Hettinger03923422013-03-04 02:52:50 -05001080 def test_need_for_rlock(self):
1081 # This will deadlock on an LRU cache that uses a regular lock
1082
1083 @functools.lru_cache(maxsize=10)
1084 def test_func(x):
1085 'Used to demonstrate a reentrant lru_cache call within a single thread'
1086 return x
1087
1088 class DoubleEq:
1089 'Demonstrate a reentrant lru_cache call within a single thread'
1090 def __init__(self, x):
1091 self.x = x
1092 def __hash__(self):
1093 return self.x
1094 def __eq__(self, other):
1095 if self.x == 2:
1096 test_func(DoubleEq(1))
1097 return self.x == other.x
1098
1099 test_func(DoubleEq(1)) # Load the cache
1100 test_func(DoubleEq(2)) # Load the cache
1101 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1102 DoubleEq(2)) # Verify the correct return value
1103
Raymond Hettinger4d588972014-08-12 12:44:52 -07001104 def test_early_detection_of_bad_call(self):
1105 # Issue #22184
1106 with self.assertRaises(TypeError):
1107 @functools.lru_cache
1108 def f():
1109 pass
1110
Raymond Hettinger03923422013-03-04 02:52:50 -05001111
Łukasz Langa6f692512013-06-05 12:20:24 +02001112class TestSingleDispatch(unittest.TestCase):
1113 def test_simple_overloads(self):
1114 @functools.singledispatch
1115 def g(obj):
1116 return "base"
1117 def g_int(i):
1118 return "integer"
1119 g.register(int, g_int)
1120 self.assertEqual(g("str"), "base")
1121 self.assertEqual(g(1), "integer")
1122 self.assertEqual(g([1,2,3]), "base")
1123
1124 def test_mro(self):
1125 @functools.singledispatch
1126 def g(obj):
1127 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001128 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001129 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001130 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001131 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001132 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001133 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001134 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001135 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001136 def g_A(a):
1137 return "A"
1138 def g_B(b):
1139 return "B"
1140 g.register(A, g_A)
1141 g.register(B, g_B)
1142 self.assertEqual(g(A()), "A")
1143 self.assertEqual(g(B()), "B")
1144 self.assertEqual(g(C()), "A")
1145 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001146
1147 def test_register_decorator(self):
1148 @functools.singledispatch
1149 def g(obj):
1150 return "base"
1151 @g.register(int)
1152 def g_int(i):
1153 return "int %s" % (i,)
1154 self.assertEqual(g(""), "base")
1155 self.assertEqual(g(12), "int 12")
1156 self.assertIs(g.dispatch(int), g_int)
1157 self.assertIs(g.dispatch(object), g.dispatch(str))
1158 # Note: in the assert above this is not g.
1159 # @singledispatch returns the wrapper.
1160
1161 def test_wrapping_attributes(self):
1162 @functools.singledispatch
1163 def g(obj):
1164 "Simple test"
1165 return "Test"
1166 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001167 if sys.flags.optimize < 2:
1168 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001169
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes from the C _decimal module.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def exception_args(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    # A more specific registration overrides only the matching subclass.
    @g.register(decimal.Subnormal)
    def too_small(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
1188
1189 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001190 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001191 c = collections
1192 mro = functools._compose_mro
1193 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1194 for haystack in permutations(bases):
1195 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001196 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1197 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001198 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1199 for haystack in permutations(bases):
1200 m = mro(c.ChainMap, haystack)
1201 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1202 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001203
1204 # If there's a generic function with implementations registered for
1205 # both Sized and Container, passing a defaultdict to it results in an
1206 # ambiguous dispatch which will cause a RuntimeError (see
1207 # test_mro_conflicts).
1208 bases = [c.Container, c.Sized, str]
1209 for haystack in permutations(bases):
1210 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1211 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1212 object])
1213
1214 # MutableSequence below is registered directly on D. In other words, it
1215 # preceeds MutableMapping which means single dispatch will always
1216 # choose MutableSequence here.
1217 class D(c.defaultdict):
1218 pass
1219 c.MutableSequence.register(D)
1220 bases = [c.MutableSequence, c.MutableMapping]
1221 for haystack in permutations(bases):
1222 m = mro(D, bases)
1223 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1224 c.defaultdict, dict, c.MutableMapping,
1225 c.Mapping, c.Sized, c.Iterable, c.Container,
1226 object])
1227
1228 # Container and Callable are registered on different base classes and
1229 # a generic function supporting both should always pick the Callable
1230 # implementation if a C instance is passed.
1231 class C(c.defaultdict):
1232 def __call__(self):
1233 pass
1234 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1235 for haystack in permutations(bases):
1236 m = mro(C, haystack)
1237 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1238 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001239
1240 def test_register_abc(self):
1241 c = collections
1242 d = {"a": "b"}
1243 l = [1, 2, 3]
1244 s = {object(), None}
1245 f = frozenset(s)
1246 t = (1, 2, 3)
1247 @functools.singledispatch
1248 def g(obj):
1249 return "base"
1250 self.assertEqual(g(d), "base")
1251 self.assertEqual(g(l), "base")
1252 self.assertEqual(g(s), "base")
1253 self.assertEqual(g(f), "base")
1254 self.assertEqual(g(t), "base")
1255 g.register(c.Sized, lambda obj: "sized")
1256 self.assertEqual(g(d), "sized")
1257 self.assertEqual(g(l), "sized")
1258 self.assertEqual(g(s), "sized")
1259 self.assertEqual(g(f), "sized")
1260 self.assertEqual(g(t), "sized")
1261 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1262 self.assertEqual(g(d), "mutablemapping")
1263 self.assertEqual(g(l), "sized")
1264 self.assertEqual(g(s), "sized")
1265 self.assertEqual(g(f), "sized")
1266 self.assertEqual(g(t), "sized")
1267 g.register(c.ChainMap, lambda obj: "chainmap")
1268 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1269 self.assertEqual(g(l), "sized")
1270 self.assertEqual(g(s), "sized")
1271 self.assertEqual(g(f), "sized")
1272 self.assertEqual(g(t), "sized")
1273 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1274 self.assertEqual(g(d), "mutablemapping")
1275 self.assertEqual(g(l), "mutablesequence")
1276 self.assertEqual(g(s), "sized")
1277 self.assertEqual(g(f), "sized")
1278 self.assertEqual(g(t), "sized")
1279 g.register(c.MutableSet, lambda obj: "mutableset")
1280 self.assertEqual(g(d), "mutablemapping")
1281 self.assertEqual(g(l), "mutablesequence")
1282 self.assertEqual(g(s), "mutableset")
1283 self.assertEqual(g(f), "sized")
1284 self.assertEqual(g(t), "sized")
1285 g.register(c.Mapping, lambda obj: "mapping")
1286 self.assertEqual(g(d), "mutablemapping") # not specific enough
1287 self.assertEqual(g(l), "mutablesequence")
1288 self.assertEqual(g(s), "mutableset")
1289 self.assertEqual(g(f), "sized")
1290 self.assertEqual(g(t), "sized")
1291 g.register(c.Sequence, lambda obj: "sequence")
1292 self.assertEqual(g(d), "mutablemapping")
1293 self.assertEqual(g(l), "mutablesequence")
1294 self.assertEqual(g(s), "mutableset")
1295 self.assertEqual(g(f), "sized")
1296 self.assertEqual(g(t), "sequence")
1297 g.register(c.Set, lambda obj: "set")
1298 self.assertEqual(g(d), "mutablemapping")
1299 self.assertEqual(g(l), "mutablesequence")
1300 self.assertEqual(g(s), "mutableset")
1301 self.assertEqual(g(f), "set")
1302 self.assertEqual(g(t), "sequence")
1303 g.register(dict, lambda obj: "dict")
1304 self.assertEqual(g(d), "dict")
1305 self.assertEqual(g(l), "mutablesequence")
1306 self.assertEqual(g(s), "mutableset")
1307 self.assertEqual(g(f), "set")
1308 self.assertEqual(g(t), "sequence")
1309 g.register(list, lambda obj: "list")
1310 self.assertEqual(g(d), "dict")
1311 self.assertEqual(g(l), "list")
1312 self.assertEqual(g(s), "mutableset")
1313 self.assertEqual(g(f), "set")
1314 self.assertEqual(g(t), "sequence")
1315 g.register(set, lambda obj: "concrete-set")
1316 self.assertEqual(g(d), "dict")
1317 self.assertEqual(g(l), "list")
1318 self.assertEqual(g(s), "concrete-set")
1319 self.assertEqual(g(f), "set")
1320 self.assertEqual(g(t), "sequence")
1321 g.register(frozenset, lambda obj: "frozen-set")
1322 self.assertEqual(g(d), "dict")
1323 self.assertEqual(g(l), "list")
1324 self.assertEqual(g(s), "concrete-set")
1325 self.assertEqual(g(f), "frozen-set")
1326 self.assertEqual(g(t), "sequence")
1327 g.register(tuple, lambda obj: "tuple")
1328 self.assertEqual(g(d), "dict")
1329 self.assertEqual(g(l), "list")
1330 self.assertEqual(g(s), "concrete-set")
1331 self.assertEqual(g(f), "frozen-set")
1332 self.assertEqual(g(t), "tuple")
1333
Łukasz Langa3720c772013-07-01 16:00:38 +02001334 def test_c3_abc(self):
1335 c = collections
1336 mro = functools._c3_mro
1337 class A(object):
1338 pass
1339 class B(A):
1340 def __len__(self):
1341 return 0 # implies Sized
1342 @c.Container.register
1343 class C(object):
1344 pass
1345 class D(object):
1346 pass # unrelated
1347 class X(D, C, B):
1348 def __call__(self):
1349 pass # implies Callable
1350 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1351 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1352 self.assertEqual(mro(X, abcs=abcs), expected)
1353 # unrelated ABCs don't appear in the resulting MRO
1354 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1355 self.assertEqual(mro(X, abcs=many_abcs), expected)
1356
Łukasz Langa6f692512013-06-05 12:20:24 +02001357 def test_mro_conflicts(self):
1358 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001359 @functools.singledispatch
1360 def g(arg):
1361 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001362 class O(c.Sized):
1363 def __len__(self):
1364 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001365 o = O()
1366 self.assertEqual(g(o), "base")
1367 g.register(c.Iterable, lambda arg: "iterable")
1368 g.register(c.Container, lambda arg: "container")
1369 g.register(c.Sized, lambda arg: "sized")
1370 g.register(c.Set, lambda arg: "set")
1371 self.assertEqual(g(o), "sized")
1372 c.Iterable.register(O)
1373 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1374 c.Container.register(O)
1375 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001376 c.Set.register(O)
1377 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1378 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001379 class P:
1380 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001381 p = P()
1382 self.assertEqual(g(p), "base")
1383 c.Iterable.register(P)
1384 self.assertEqual(g(p), "iterable")
1385 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001386 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001387 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001388 self.assertIn(
1389 str(re_one.exception),
1390 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1391 "or <class 'collections.abc.Iterable'>"),
1392 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1393 "or <class 'collections.abc.Container'>")),
1394 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001395 class Q(c.Sized):
1396 def __len__(self):
1397 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001398 q = Q()
1399 self.assertEqual(g(q), "sized")
1400 c.Iterable.register(Q)
1401 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1402 c.Set.register(Q)
1403 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001404 # c.Sized and c.Iterable
1405 @functools.singledispatch
1406 def h(arg):
1407 return "base"
1408 @h.register(c.Sized)
1409 def _(arg):
1410 return "sized"
1411 @h.register(c.Container)
1412 def _(arg):
1413 return "container"
1414 # Even though Sized and Container are explicit bases of MutableMapping,
1415 # this ABC is implicitly registered on defaultdict which makes all of
1416 # MutableMapping's bases implicit as well from defaultdict's
1417 # perspective.
1418 with self.assertRaises(RuntimeError) as re_two:
1419 h(c.defaultdict(lambda: 0))
1420 self.assertIn(
1421 str(re_two.exception),
1422 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1423 "or <class 'collections.abc.Sized'>"),
1424 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1425 "or <class 'collections.abc.Container'>")),
1426 )
1427 class R(c.defaultdict):
1428 pass
1429 c.MutableSequence.register(R)
1430 @functools.singledispatch
1431 def i(arg):
1432 return "base"
1433 @i.register(c.MutableMapping)
1434 def _(arg):
1435 return "mapping"
1436 @i.register(c.MutableSequence)
1437 def _(arg):
1438 return "sequence"
1439 r = R()
1440 self.assertEqual(i(r), "sequence")
1441 class S:
1442 pass
1443 class T(S, c.Sized):
1444 def __len__(self):
1445 return 0
1446 t = T()
1447 self.assertEqual(h(t), "sized")
1448 c.Container.register(T)
1449 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1450 class U:
1451 def __len__(self):
1452 return 0
1453 u = U()
1454 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1455 # from the existence of __len__()
1456 c.Container.register(U)
1457 # There is no preference for registered versus inferred ABCs.
1458 with self.assertRaises(RuntimeError) as re_three:
1459 h(u)
1460 self.assertIn(
1461 str(re_three.exception),
1462 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1463 "or <class 'collections.abc.Sized'>"),
1464 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1465 "or <class 'collections.abc.Container'>")),
1466 )
1467 class V(c.Sized, S):
1468 def __len__(self):
1469 return 0
1470 @functools.singledispatch
1471 def j(arg):
1472 return "base"
1473 @j.register(S)
1474 def _(arg):
1475 return "s"
1476 @j.register(c.Container)
1477 def _(arg):
1478 return "container"
1479 v = V()
1480 self.assertEqual(j(v), "s")
1481 c.Container.register(V)
1482 self.assertEqual(j(v), "container") # because it ends up right after
1483 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001484
def test_cache_invalidation(self):
    """Check when singledispatch's dispatch cache is filled and emptied.

    The cache object that ``functools.singledispatch`` creates is replaced
    by a ``TracingDict`` that records every key read (``get_ops``) and
    written (``set_ops``). The assertions pin down exactly when ``g``
    answers from the cache, when it stores a new entry, and when
    ``register()``, ABC registration and ``_clear_cache()`` empty it.
    """
    from collections import UserDict

    class TracingDict(UserDict):
        """UserDict that logs the keys passed to __getitem__/__setitem__."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []

        def __getitem__(self, key):
            # Look up first: a missing key raises KeyError before it is
            # logged, so get_ops only records successful cache hits.
            result = self.data[key]
            self.get_ops.append(key)
            return result

        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value

        def clear(self):
            # Clear the backing dict directly; this records nothing in
            # get_ops/set_ops.
            self.data.clear()

    _orig_wkd = functools.WeakKeyDictionary
    td = TracingDict()
    # singledispatch() obtains its cache by calling
    # functools.WeakKeyDictionary(); hand it the tracing dict instead.
    functools.WeakKeyDictionary = lambda: td
    try:
        c = collections

        @functools.singledispatch
        def g(arg):
            return "base"

        d = {}
        lst = []
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(lst), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated calls are served from the cache (get_ops grows,
        # set_ops does not).
        self.assertEqual(g(lst), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # Registering a new implementation empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))

        class X:
            pass

        c.MutableMapping.register(X)  # Will not invalidate the cache,
                                      # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(lst), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(lst), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() reads through the same cache.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)  # Will invalidate the cache.
        # Invalidation is lazy: the entries survive until the next call.
        self.assertEqual(len(td), 2)  # Stale cache.
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(lst), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(lst), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
    finally:
        # Always restore the real class, even when an assertion above
        # fails, so later tests do not run against the patched module.
        functools.WeakKeyDictionary = _orig_wkd
1585
1586
def test_main(verbose=None):
    """Run every test class of this module through support.run_unittest.

    If *verbose* is true and the interpreter exposes
    ``sys.gettotalrefcount`` (presumably a debug build), the suite is run
    five more times. After each run it collects garbage and records the
    total reference count, then prints the list. A count that keeps
    growing across runs points at a reference leak.
    """
    test_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001613
if __name__ == '__main__':
    # Executed as a script: run the whole suite in verbose mode.
    test_main(True)