blob: 0375601c518cebf8986b2189410840edea77efa5 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Two separately imported copies of functools so the shared tests can run
# against both implementations: py_functools is imported with the
# _functools C accelerator blocked (pure-Python code paths), c_functools
# re-imports _functools fresh.  c_functools may be falsy when the
# accelerator is unavailable; the C-specific test classes guard on it with
# `if c_functools:` and `unittest.skipUnless(c_functools, ...)`.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh import of decimal backed by the C _decimal module; presumably used by
# tests later in this file (not visible in this section) -- confirm.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo the call back: return ``(args, kw)`` exactly as received."""
    received = args, kw
    return received
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``.

    The pickle round-trip test compares these summaries for the original
    and the unpickled object.
    """
    summary = [part.func, part.args, part.keywords, part.__dict__]
    return tuple(summary)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Behaviour common to every ``partial`` implementation.

    Mixin: concrete test cases combine it with unittest.TestCase and set the
    ``partial`` class attribute to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        try:
            self.partial(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertEqual calls report the mismatching values on
            # failure; assertTrue(a == b and c == d) would only say "False".
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000133
Łukasz Langa6f692512013-06-05 12:20:24 +0200134
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific ones, run on _functools.partial."""

    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        for attr, value in (('func', map),
                            ('args', (1, 2)),
                            ('keywords', dict(a=1, b=2))):
            self.assertRaises(AttributeError, setattr, p, attr, value)

        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(map(repr, args))
        kwargs = {'a': object(), 'b': object()}
        # The test accepts the keywords in either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = ('functools.partial' if self.partial is c_functools.partial
                else self.partial.__name__)

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        expected = ['{}({!r}, {})'.format(name, capture, kw)
                    for kw in kwargs_reprs]
        self.assertIn(repr(f), expected)

        f = self.partial(capture, *args, **kwargs)
        expected = ['{}({!r}, {}, {})'.format(name, capture, args_repr, kw)
                    for kw in kwargs_reprs]
        self.assertIn(repr(f), expected)

    def test_pickle(self):
        # Round-trip through every pickle protocol, including an instance
        # attribute stored in the partial's __dict__.
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(f, protocol))
            self.assertEqual(signature(f), signature(restored))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Pretends to be a 4-item state sequence but is not a tuple.
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                if key == 1:
                    return tuple(range(1000000))
                if key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000209
Łukasz Langa6f692512013-06-05 12:20:24 +0200210
class TestPartialPy(TestPartial, unittest.TestCase):
    # Run the shared partial tests against the pure-Python implementation
    # (functools imported with _functools blocked).  staticmethod() keeps the
    # implementation from being bound as a method when looked up via
    # self.partial, in case it is a plain function.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000213
Łukasz Langa6f692512013-06-05 12:20:24 +0200214
if c_functools:
    # Trivial Python-level subclass of the C partial type, used by
    # TestPartialCSubclass to re-run the C tests on a subclass.  Only defined
    # when the C accelerator imported successfully.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100218
Łukasz Langa6f692512013-06-05 12:20:24 +0200219
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    # Re-run every TestPartialC test (and thereby the shared TestPartial
    # tests) with a Python subclass of the C partial type.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000224
Łukasz Langa6f692512013-06-05 12:20:24 +0200225
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be *created by* ABCMeta (metaclass=...), not derive
        # from it: subclassing ABCMeta would make Abstract a metaclass built
        # by plain `type`, so ABCMeta never inspects its namespace and the
        # abstract-method bookkeeping under test never runs.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta itself must also treat the partialmethod as abstract.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
338
339
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (TestWraps reuses them for wraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the very same object on
        both functions; every mapping named in *updated* must contain each
        entry of the wrapped function's mapping (the ``__wrapped__`` entry of
        ``__dict__`` is skipped because update_wrapper overwrites it); and
        ``wrapper.__wrapped__`` must be *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            checked = (k for k in source
                       if not (attr == "__dict__" and k == "__wrapped__"))
            for key in checked:
                self.assertIs(source[key], target[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        # f carries a docstring, an annotation, a custom attribute and a
        # misleading __wrapped__; wrapper starts with its own annotation.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign, update = ('attr',), ('dict_attr',)
        # Attributes absent from the wrapped function are skipped.
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # The wrapper, however, must own every mapping named in `updated`;
        # a missing attribute or a non-mapping (an int) raises AttributeError.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000452
Łukasz Langa6f692512013-06-05 12:20:24 +0200453
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        decorate = functools.wraps(f)
        @decorate
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        nothing = ()
        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def with_empty_dict_attr(func):
            # Give the wrapper an (empty) mapping for wraps() to update.
            func.dict_attr = {}
            return func
        assign, update = ('attr',), ('dict_attr',)
        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
514
Łukasz Langa6f692512013-06-05 12:20:24 +0200515
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, exposed as the ``func`` class attribute."""
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily materialised squares 0, 1, 4, ... for indexes below
            # `limit`; iterated through __getitem__ (it defines no __iter__).
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        reduce = self.func
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            reduce(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Bad argument counts / types.
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
        self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty iterable with no initial value is an error.
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
595
Łukasz Langa6f692512013-06-05 12:20:24 +0200596
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200597class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700598
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000599 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700600 def cmp1(x, y):
601 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100602 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700603 self.assertEqual(key(3), key(3))
604 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100605 self.assertGreaterEqual(key(3), key(3))
606
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700607 def cmp2(x, y):
608 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100609 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700610 self.assertEqual(key(4.0), key('4'))
611 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100612 self.assertLessEqual(key(2), key('35'))
613 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700614
615 def test_cmp_to_key_arguments(self):
616 def cmp1(x, y):
617 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100618 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700619 self.assertEqual(key(obj=3), key(obj=3))
620 self.assertGreater(key(obj=3), key(obj=1))
621 with self.assertRaises((TypeError, AttributeError)):
622 key(3) > 1 # rhs is not a K object
623 with self.assertRaises((TypeError, AttributeError)):
624 1 < key(3) # lhs is not a K object
625 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100626 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700627 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200628 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100629 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700630 with self.assertRaises(TypeError):
631 key() # too few args
632 with self.assertRaises(TypeError):
633 key(None, None) # too many args
634
635 def test_bad_cmp(self):
636 def cmp1(x, y):
637 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100638 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700639 with self.assertRaises(ZeroDivisionError):
640 key(3) > key(1)
641
642 class BadCmp:
643 def __lt__(self, other):
644 raise ZeroDivisionError
645 def cmp1(x, y):
646 return BadCmp()
647 with self.assertRaises(ZeroDivisionError):
648 key(3) > key(1)
649
650 def test_obj_field(self):
651 def cmp1(x, y):
652 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100653 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700654 self.assertEqual(key(50).obj, 50)
655
656 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000657 def mycmp(x, y):
658 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100659 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000660 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000661
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700662 def test_sort_int_str(self):
663 def mycmp(x, y):
664 x, y = int(x), int(y)
665 return (x > y) - (x < y)
666 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100667 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700668 self.assertEqual([int(value) for value in values],
669 [0, 1, 1, 2, 3, 4, 5, 7, 10])
670
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000671 def test_hash(self):
672 def mycmp(x, y):
673 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100674 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000675 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700676 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700677 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000678
Łukasz Langa6f692512013-06-05 12:20:24 +0200679
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C implementation."""

    # The class body runs even when the class is skipped, so only look up
    # the attribute when the C module was actually loaded. A builtin
    # function does not bind as a method, so no staticmethod is needed.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100684
Łukasz Langa6f692512013-06-05 12:20:24 +0200685
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python implementation."""

    # staticmethod stops the plain Python function from binding to the
    # test instance when accessed as ``self.cmp_to_key``.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
688
Łukasz Langa6f692512013-06-05 12:20:24 +0200689
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _assert_total_order(self, cls):
        """Exercise every rich comparison of *cls* on 1 < 2 and on 2 == 2.

        Each operator is checked for both its true and its false outcome,
        so a derived method that inverts a result or mishandles equality
        is detected.
        """
        lo, hi, same = cls(1), cls(2), cls(2)
        # <
        self.assertTrue(lo < hi)
        self.assertFalse(hi < lo)
        self.assertFalse(hi < same)
        # >
        self.assertTrue(hi > lo)
        self.assertFalse(lo > hi)
        self.assertFalse(hi > same)
        # <=
        self.assertTrue(lo <= hi)
        self.assertFalse(hi <= lo)
        self.assertTrue(hi <= same)
        # >=
        self.assertTrue(hi >= lo)
        self.assertFalse(lo >= hi)
        self.assertTrue(hi >= same)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._assert_total_order(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing: int already defines
        # all four comparisons, which must keep working unchanged.
        @functools.total_ordering
        class A(int):
            pass
        self._assert_total_order(A)

    def test_no_operations_defined(self):
        # A class with no ordering method at all is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        # Even equal operands must not turn a NotImplemented __lt__ into
        # a result for the derived operators.
        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200884
Georg Brandl2e7346a2010-07-31 18:09:23 +0000885class TestLRU(unittest.TestCase):
886
887 def test_lru(self):
888 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100889 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000890 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000891 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000892 self.assertEqual(maxsize, 20)
893 self.assertEqual(currsize, 0)
894 self.assertEqual(hits, 0)
895 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000896
897 domain = range(5)
898 for i in range(1000):
899 x, y = choice(domain), choice(domain)
900 actual = f(x, y)
901 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000902 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000903 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000904 self.assertTrue(hits > misses)
905 self.assertEqual(hits + misses, 1000)
906 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000907
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000908 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000909 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000910 self.assertEqual(hits, 0)
911 self.assertEqual(misses, 0)
912 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000913 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000914 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000915 self.assertEqual(hits, 0)
916 self.assertEqual(misses, 1)
917 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000918
Nick Coghlan98876832010-08-17 06:17:18 +0000919 # Test bypassing the cache
920 self.assertIs(f.__wrapped__, orig)
921 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000922 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000923 self.assertEqual(hits, 0)
924 self.assertEqual(misses, 1)
925 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000926
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000927 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000928 @functools.lru_cache(0)
929 def f():
930 nonlocal f_cnt
931 f_cnt += 1
932 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000933 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000934 f_cnt = 0
935 for i in range(5):
936 self.assertEqual(f(), 20)
937 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000938 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000939 self.assertEqual(hits, 0)
940 self.assertEqual(misses, 5)
941 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000942
943 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000944 @functools.lru_cache(1)
945 def f():
946 nonlocal f_cnt
947 f_cnt += 1
948 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000949 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000950 f_cnt = 0
951 for i in range(5):
952 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000953 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000954 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000955 self.assertEqual(hits, 4)
956 self.assertEqual(misses, 1)
957 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000958
Raymond Hettingerf3098282010-08-15 03:30:45 +0000959 # test size two
960 @functools.lru_cache(2)
961 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000962 nonlocal f_cnt
963 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000964 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000965 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000966 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000967 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
968 # * * * *
969 self.assertEqual(f(x), x*10)
970 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000971 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000972 self.assertEqual(hits, 12)
973 self.assertEqual(misses, 4)
974 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000975
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000976 def test_lru_with_maxsize_none(self):
977 @functools.lru_cache(maxsize=None)
978 def fib(n):
979 if n < 2:
980 return n
981 return fib(n-1) + fib(n-2)
982 self.assertEqual([fib(n) for n in range(16)],
983 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
984 self.assertEqual(fib.cache_info(),
985 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
986 fib.cache_clear()
987 self.assertEqual(fib.cache_info(),
988 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
989
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700990 def test_lru_with_exceptions(self):
991 # Verify that user_function exceptions get passed through without
992 # creating a hard-to-read chained exception.
993 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +0100994 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -0700995 @functools.lru_cache(maxsize)
996 def func(i):
997 return 'abc'[i]
998 self.assertEqual(func(0), 'a')
999 with self.assertRaises(IndexError) as cm:
1000 func(15)
1001 self.assertIsNone(cm.exception.__context__)
1002 # Verify that the previous exception did not result in a cached entry
1003 with self.assertRaises(IndexError):
1004 func(15)
1005
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001006 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001007 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001008 @functools.lru_cache(maxsize=maxsize, typed=True)
1009 def square(x):
1010 return x * x
1011 self.assertEqual(square(3), 9)
1012 self.assertEqual(type(square(3)), type(9))
1013 self.assertEqual(square(3.0), 9.0)
1014 self.assertEqual(type(square(3.0)), type(9.0))
1015 self.assertEqual(square(x=3), 9)
1016 self.assertEqual(type(square(x=3)), type(9))
1017 self.assertEqual(square(x=3.0), 9.0)
1018 self.assertEqual(type(square(x=3.0)), type(9.0))
1019 self.assertEqual(square.cache_info().hits, 4)
1020 self.assertEqual(square.cache_info().misses, 4)
1021
Antoine Pitroub5b37142012-11-13 21:35:40 +01001022 def test_lru_with_keyword_args(self):
1023 @functools.lru_cache()
1024 def fib(n):
1025 if n < 2:
1026 return n
1027 return fib(n=n-1) + fib(n=n-2)
1028 self.assertEqual(
1029 [fib(n=number) for number in range(16)],
1030 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1031 )
1032 self.assertEqual(fib.cache_info(),
1033 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1034 fib.cache_clear()
1035 self.assertEqual(fib.cache_info(),
1036 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1037
1038 def test_lru_with_keyword_args_maxsize_none(self):
1039 @functools.lru_cache(maxsize=None)
1040 def fib(n):
1041 if n < 2:
1042 return n
1043 return fib(n=n-1) + fib(n=n-2)
1044 self.assertEqual([fib(n=number) for number in range(16)],
1045 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1046 self.assertEqual(fib.cache_info(),
1047 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1048 fib.cache_clear()
1049 self.assertEqual(fib.cache_info(),
1050 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1051
Raymond Hettinger03923422013-03-04 02:52:50 -05001052 def test_need_for_rlock(self):
1053 # This will deadlock on an LRU cache that uses a regular lock
1054
1055 @functools.lru_cache(maxsize=10)
1056 def test_func(x):
1057 'Used to demonstrate a reentrant lru_cache call within a single thread'
1058 return x
1059
1060 class DoubleEq:
1061 'Demonstrate a reentrant lru_cache call within a single thread'
1062 def __init__(self, x):
1063 self.x = x
1064 def __hash__(self):
1065 return self.x
1066 def __eq__(self, other):
1067 if self.x == 2:
1068 test_func(DoubleEq(1))
1069 return self.x == other.x
1070
1071 test_func(DoubleEq(1)) # Load the cache
1072 test_func(DoubleEq(2)) # Load the cache
1073 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1074 DoubleEq(2)) # Verify the correct return value
1075
Raymond Hettinger4d588972014-08-12 12:44:52 -07001076 def test_early_detection_of_bad_call(self):
1077 # Issue #22184
1078 with self.assertRaises(TypeError):
1079 @functools.lru_cache
1080 def f():
1081 pass
1082
Raymond Hettinger03923422013-03-04 02:52:50 -05001083
Łukasz Langa6f692512013-06-05 12:20:24 +02001084class TestSingleDispatch(unittest.TestCase):
1085 def test_simple_overloads(self):
1086 @functools.singledispatch
1087 def g(obj):
1088 return "base"
1089 def g_int(i):
1090 return "integer"
1091 g.register(int, g_int)
1092 self.assertEqual(g("str"), "base")
1093 self.assertEqual(g(1), "integer")
1094 self.assertEqual(g([1,2,3]), "base")
1095
1096 def test_mro(self):
1097 @functools.singledispatch
1098 def g(obj):
1099 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001100 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001101 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001102 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001103 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001104 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001105 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001106 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001107 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001108 def g_A(a):
1109 return "A"
1110 def g_B(b):
1111 return "B"
1112 g.register(A, g_A)
1113 g.register(B, g_B)
1114 self.assertEqual(g(A()), "A")
1115 self.assertEqual(g(B()), "B")
1116 self.assertEqual(g(C()), "A")
1117 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001118
1119 def test_register_decorator(self):
1120 @functools.singledispatch
1121 def g(obj):
1122 return "base"
1123 @g.register(int)
1124 def g_int(i):
1125 return "int %s" % (i,)
1126 self.assertEqual(g(""), "base")
1127 self.assertEqual(g(12), "int 12")
1128 self.assertIs(g.dispatch(int), g_int)
1129 self.assertIs(g.dispatch(object), g.dispatch(str))
1130 # Note: in the assert above this is not g.
1131 # @singledispatch returns the wrapper.
1132
1133 def test_wrapping_attributes(self):
1134 @functools.singledispatch
1135 def g(obj):
1136 "Simple test"
1137 return "Test"
1138 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001139 if sys.flags.optimize < 2:
1140 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001141
1142 @unittest.skipUnless(decimal, 'requires _decimal')
1143 @support.cpython_only
1144 def test_c_classes(self):
1145 @functools.singledispatch
1146 def g(obj):
1147 return "base"
1148 @g.register(decimal.DecimalException)
1149 def _(obj):
1150 return obj.args
1151 subn = decimal.Subnormal("Exponent < Emin")
1152 rnd = decimal.Rounded("Number got rounded")
1153 self.assertEqual(g(subn), ("Exponent < Emin",))
1154 self.assertEqual(g(rnd), ("Number got rounded",))
1155 @g.register(decimal.Subnormal)
1156 def _(obj):
1157 return "Too small to care."
1158 self.assertEqual(g(subn), "Too small to care.")
1159 self.assertEqual(g(rnd), ("Number got rounded",))
1160
1161 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001162 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001163 c = collections
1164 mro = functools._compose_mro
1165 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1166 for haystack in permutations(bases):
1167 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001168 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1169 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001170 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1171 for haystack in permutations(bases):
1172 m = mro(c.ChainMap, haystack)
1173 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1174 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001175
1176 # If there's a generic function with implementations registered for
1177 # both Sized and Container, passing a defaultdict to it results in an
1178 # ambiguous dispatch which will cause a RuntimeError (see
1179 # test_mro_conflicts).
1180 bases = [c.Container, c.Sized, str]
1181 for haystack in permutations(bases):
1182 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1183 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1184 object])
1185
1186 # MutableSequence below is registered directly on D. In other words, it
1187 # preceeds MutableMapping which means single dispatch will always
1188 # choose MutableSequence here.
1189 class D(c.defaultdict):
1190 pass
1191 c.MutableSequence.register(D)
1192 bases = [c.MutableSequence, c.MutableMapping]
1193 for haystack in permutations(bases):
1194 m = mro(D, bases)
1195 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1196 c.defaultdict, dict, c.MutableMapping,
1197 c.Mapping, c.Sized, c.Iterable, c.Container,
1198 object])
1199
1200 # Container and Callable are registered on different base classes and
1201 # a generic function supporting both should always pick the Callable
1202 # implementation if a C instance is passed.
1203 class C(c.defaultdict):
1204 def __call__(self):
1205 pass
1206 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1207 for haystack in permutations(bases):
1208 m = mro(C, haystack)
1209 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1210 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001211
1212 def test_register_abc(self):
1213 c = collections
1214 d = {"a": "b"}
1215 l = [1, 2, 3]
1216 s = {object(), None}
1217 f = frozenset(s)
1218 t = (1, 2, 3)
1219 @functools.singledispatch
1220 def g(obj):
1221 return "base"
1222 self.assertEqual(g(d), "base")
1223 self.assertEqual(g(l), "base")
1224 self.assertEqual(g(s), "base")
1225 self.assertEqual(g(f), "base")
1226 self.assertEqual(g(t), "base")
1227 g.register(c.Sized, lambda obj: "sized")
1228 self.assertEqual(g(d), "sized")
1229 self.assertEqual(g(l), "sized")
1230 self.assertEqual(g(s), "sized")
1231 self.assertEqual(g(f), "sized")
1232 self.assertEqual(g(t), "sized")
1233 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1234 self.assertEqual(g(d), "mutablemapping")
1235 self.assertEqual(g(l), "sized")
1236 self.assertEqual(g(s), "sized")
1237 self.assertEqual(g(f), "sized")
1238 self.assertEqual(g(t), "sized")
1239 g.register(c.ChainMap, lambda obj: "chainmap")
1240 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1241 self.assertEqual(g(l), "sized")
1242 self.assertEqual(g(s), "sized")
1243 self.assertEqual(g(f), "sized")
1244 self.assertEqual(g(t), "sized")
1245 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1246 self.assertEqual(g(d), "mutablemapping")
1247 self.assertEqual(g(l), "mutablesequence")
1248 self.assertEqual(g(s), "sized")
1249 self.assertEqual(g(f), "sized")
1250 self.assertEqual(g(t), "sized")
1251 g.register(c.MutableSet, lambda obj: "mutableset")
1252 self.assertEqual(g(d), "mutablemapping")
1253 self.assertEqual(g(l), "mutablesequence")
1254 self.assertEqual(g(s), "mutableset")
1255 self.assertEqual(g(f), "sized")
1256 self.assertEqual(g(t), "sized")
1257 g.register(c.Mapping, lambda obj: "mapping")
1258 self.assertEqual(g(d), "mutablemapping") # not specific enough
1259 self.assertEqual(g(l), "mutablesequence")
1260 self.assertEqual(g(s), "mutableset")
1261 self.assertEqual(g(f), "sized")
1262 self.assertEqual(g(t), "sized")
1263 g.register(c.Sequence, lambda obj: "sequence")
1264 self.assertEqual(g(d), "mutablemapping")
1265 self.assertEqual(g(l), "mutablesequence")
1266 self.assertEqual(g(s), "mutableset")
1267 self.assertEqual(g(f), "sized")
1268 self.assertEqual(g(t), "sequence")
1269 g.register(c.Set, lambda obj: "set")
1270 self.assertEqual(g(d), "mutablemapping")
1271 self.assertEqual(g(l), "mutablesequence")
1272 self.assertEqual(g(s), "mutableset")
1273 self.assertEqual(g(f), "set")
1274 self.assertEqual(g(t), "sequence")
1275 g.register(dict, lambda obj: "dict")
1276 self.assertEqual(g(d), "dict")
1277 self.assertEqual(g(l), "mutablesequence")
1278 self.assertEqual(g(s), "mutableset")
1279 self.assertEqual(g(f), "set")
1280 self.assertEqual(g(t), "sequence")
1281 g.register(list, lambda obj: "list")
1282 self.assertEqual(g(d), "dict")
1283 self.assertEqual(g(l), "list")
1284 self.assertEqual(g(s), "mutableset")
1285 self.assertEqual(g(f), "set")
1286 self.assertEqual(g(t), "sequence")
1287 g.register(set, lambda obj: "concrete-set")
1288 self.assertEqual(g(d), "dict")
1289 self.assertEqual(g(l), "list")
1290 self.assertEqual(g(s), "concrete-set")
1291 self.assertEqual(g(f), "set")
1292 self.assertEqual(g(t), "sequence")
1293 g.register(frozenset, lambda obj: "frozen-set")
1294 self.assertEqual(g(d), "dict")
1295 self.assertEqual(g(l), "list")
1296 self.assertEqual(g(s), "concrete-set")
1297 self.assertEqual(g(f), "frozen-set")
1298 self.assertEqual(g(t), "sequence")
1299 g.register(tuple, lambda obj: "tuple")
1300 self.assertEqual(g(d), "dict")
1301 self.assertEqual(g(l), "list")
1302 self.assertEqual(g(s), "concrete-set")
1303 self.assertEqual(g(f), "frozen-set")
1304 self.assertEqual(g(t), "tuple")
1305
Łukasz Langa3720c772013-07-01 16:00:38 +02001306 def test_c3_abc(self):
1307 c = collections
1308 mro = functools._c3_mro
1309 class A(object):
1310 pass
1311 class B(A):
1312 def __len__(self):
1313 return 0 # implies Sized
1314 @c.Container.register
1315 class C(object):
1316 pass
1317 class D(object):
1318 pass # unrelated
1319 class X(D, C, B):
1320 def __call__(self):
1321 pass # implies Callable
1322 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1323 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1324 self.assertEqual(mro(X, abcs=abcs), expected)
1325 # unrelated ABCs don't appear in the resulting MRO
1326 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1327 self.assertEqual(mro(X, abcs=many_abcs), expected)
1328
Łukasz Langa6f692512013-06-05 12:20:24 +02001329 def test_mro_conflicts(self):
1330 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001331 @functools.singledispatch
1332 def g(arg):
1333 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001334 class O(c.Sized):
1335 def __len__(self):
1336 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001337 o = O()
1338 self.assertEqual(g(o), "base")
1339 g.register(c.Iterable, lambda arg: "iterable")
1340 g.register(c.Container, lambda arg: "container")
1341 g.register(c.Sized, lambda arg: "sized")
1342 g.register(c.Set, lambda arg: "set")
1343 self.assertEqual(g(o), "sized")
1344 c.Iterable.register(O)
1345 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1346 c.Container.register(O)
1347 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001348 c.Set.register(O)
1349 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1350 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001351 class P:
1352 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001353 p = P()
1354 self.assertEqual(g(p), "base")
1355 c.Iterable.register(P)
1356 self.assertEqual(g(p), "iterable")
1357 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001358 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001359 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001360 self.assertIn(
1361 str(re_one.exception),
1362 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1363 "or <class 'collections.abc.Iterable'>"),
1364 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1365 "or <class 'collections.abc.Container'>")),
1366 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001367 class Q(c.Sized):
1368 def __len__(self):
1369 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001370 q = Q()
1371 self.assertEqual(g(q), "sized")
1372 c.Iterable.register(Q)
1373 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1374 c.Set.register(Q)
1375 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001376 # c.Sized and c.Iterable
1377 @functools.singledispatch
1378 def h(arg):
1379 return "base"
1380 @h.register(c.Sized)
1381 def _(arg):
1382 return "sized"
1383 @h.register(c.Container)
1384 def _(arg):
1385 return "container"
1386 # Even though Sized and Container are explicit bases of MutableMapping,
1387 # this ABC is implicitly registered on defaultdict which makes all of
1388 # MutableMapping's bases implicit as well from defaultdict's
1389 # perspective.
1390 with self.assertRaises(RuntimeError) as re_two:
1391 h(c.defaultdict(lambda: 0))
1392 self.assertIn(
1393 str(re_two.exception),
1394 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1395 "or <class 'collections.abc.Sized'>"),
1396 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1397 "or <class 'collections.abc.Container'>")),
1398 )
1399 class R(c.defaultdict):
1400 pass
1401 c.MutableSequence.register(R)
1402 @functools.singledispatch
1403 def i(arg):
1404 return "base"
1405 @i.register(c.MutableMapping)
1406 def _(arg):
1407 return "mapping"
1408 @i.register(c.MutableSequence)
1409 def _(arg):
1410 return "sequence"
1411 r = R()
1412 self.assertEqual(i(r), "sequence")
1413 class S:
1414 pass
1415 class T(S, c.Sized):
1416 def __len__(self):
1417 return 0
1418 t = T()
1419 self.assertEqual(h(t), "sized")
1420 c.Container.register(T)
1421 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1422 class U:
1423 def __len__(self):
1424 return 0
1425 u = U()
1426 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1427 # from the existence of __len__()
1428 c.Container.register(U)
1429 # There is no preference for registered versus inferred ABCs.
1430 with self.assertRaises(RuntimeError) as re_three:
1431 h(u)
1432 self.assertIn(
1433 str(re_three.exception),
1434 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1435 "or <class 'collections.abc.Sized'>"),
1436 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1437 "or <class 'collections.abc.Container'>")),
1438 )
1439 class V(c.Sized, S):
1440 def __len__(self):
1441 return 0
1442 @functools.singledispatch
1443 def j(arg):
1444 return "base"
1445 @j.register(S)
1446 def _(arg):
1447 return "s"
1448 @j.register(c.Container)
1449 def _(arg):
1450 return "container"
1451 v = V()
1452 self.assertEqual(j(v), "s")
1453 c.Container.register(V)
1454 self.assertEqual(j(v), "container") # because it ends up right after
1455 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001456
1457 def test_cache_invalidation(self):
1458 from collections import UserDict
1459 class TracingDict(UserDict):
1460 def __init__(self, *args, **kwargs):
1461 super(TracingDict, self).__init__(*args, **kwargs)
1462 self.set_ops = []
1463 self.get_ops = []
1464 def __getitem__(self, key):
1465 result = self.data[key]
1466 self.get_ops.append(key)
1467 return result
1468 def __setitem__(self, key, value):
1469 self.set_ops.append(key)
1470 self.data[key] = value
1471 def clear(self):
1472 self.data.clear()
1473 _orig_wkd = functools.WeakKeyDictionary
1474 td = TracingDict()
1475 functools.WeakKeyDictionary = lambda: td
1476 c = collections
1477 @functools.singledispatch
1478 def g(arg):
1479 return "base"
1480 d = {}
1481 l = []
1482 self.assertEqual(len(td), 0)
1483 self.assertEqual(g(d), "base")
1484 self.assertEqual(len(td), 1)
1485 self.assertEqual(td.get_ops, [])
1486 self.assertEqual(td.set_ops, [dict])
1487 self.assertEqual(td.data[dict], g.registry[object])
1488 self.assertEqual(g(l), "base")
1489 self.assertEqual(len(td), 2)
1490 self.assertEqual(td.get_ops, [])
1491 self.assertEqual(td.set_ops, [dict, list])
1492 self.assertEqual(td.data[dict], g.registry[object])
1493 self.assertEqual(td.data[list], g.registry[object])
1494 self.assertEqual(td.data[dict], td.data[list])
1495 self.assertEqual(g(l), "base")
1496 self.assertEqual(g(d), "base")
1497 self.assertEqual(td.get_ops, [list, dict])
1498 self.assertEqual(td.set_ops, [dict, list])
1499 g.register(list, lambda arg: "list")
1500 self.assertEqual(td.get_ops, [list, dict])
1501 self.assertEqual(len(td), 0)
1502 self.assertEqual(g(d), "base")
1503 self.assertEqual(len(td), 1)
1504 self.assertEqual(td.get_ops, [list, dict])
1505 self.assertEqual(td.set_ops, [dict, list, dict])
1506 self.assertEqual(td.data[dict],
1507 functools._find_impl(dict, g.registry))
1508 self.assertEqual(g(l), "list")
1509 self.assertEqual(len(td), 2)
1510 self.assertEqual(td.get_ops, [list, dict])
1511 self.assertEqual(td.set_ops, [dict, list, dict, list])
1512 self.assertEqual(td.data[list],
1513 functools._find_impl(list, g.registry))
1514 class X:
1515 pass
1516 c.MutableMapping.register(X) # Will not invalidate the cache,
1517 # not using ABCs yet.
1518 self.assertEqual(g(d), "base")
1519 self.assertEqual(g(l), "list")
1520 self.assertEqual(td.get_ops, [list, dict, dict, list])
1521 self.assertEqual(td.set_ops, [dict, list, dict, list])
1522 g.register(c.Sized, lambda arg: "sized")
1523 self.assertEqual(len(td), 0)
1524 self.assertEqual(g(d), "sized")
1525 self.assertEqual(len(td), 1)
1526 self.assertEqual(td.get_ops, [list, dict, dict, list])
1527 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1528 self.assertEqual(g(l), "list")
1529 self.assertEqual(len(td), 2)
1530 self.assertEqual(td.get_ops, [list, dict, dict, list])
1531 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1532 self.assertEqual(g(l), "list")
1533 self.assertEqual(g(d), "sized")
1534 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1535 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1536 g.dispatch(list)
1537 g.dispatch(dict)
1538 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1539 list, dict])
1540 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1541 c.MutableSet.register(X) # Will invalidate the cache.
1542 self.assertEqual(len(td), 2) # Stale cache.
1543 self.assertEqual(g(l), "list")
1544 self.assertEqual(len(td), 1)
1545 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1546 self.assertEqual(len(td), 0)
1547 self.assertEqual(g(d), "mutablemapping")
1548 self.assertEqual(len(td), 1)
1549 self.assertEqual(g(l), "list")
1550 self.assertEqual(len(td), 2)
1551 g.register(dict, lambda arg: "dict")
1552 self.assertEqual(g(d), "dict")
1553 self.assertEqual(g(l), "list")
1554 g._clear_cache()
1555 self.assertEqual(len(td), 0)
1556 functools.WeakKeyDictionary = _orig_wkd
1557
1558
def test_main(verbose=None):
    """Run this module's test classes through ``support.run_unittest``.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount`` (presumably a debug build of CPython), the
    classes are run five more times. The total reference count after each
    pass is collected and printed, so a steadily growing sequence hints at a
    reference leak.
    """
    suites = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*suites)

    # Reference-leak check: skipped unless requested and measurable.
    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return
    import gc
    totals = []
    for _ in range(5):
        support.run_unittest(*suites)
        gc.collect()  # reclaim cyclic garbage before sampling the count
        totals.append(sys.gettotalrefcount())
    print(totals)


if __name__ == '__main__':
    test_main(verbose=True)