blob: 31c093b1f08cdded0fe6ec2d56f9f80e0a6ee12b [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# import_fresh_module lives directly on test.support in older releases and
# was moved to test.support.import_helper in newer ones (3.10+); look it up
# in either place so this module can still be imported.
try:
    _import_fresh_module = support.import_fresh_module
except AttributeError:
    from test.support.import_helper import (
        import_fresh_module as _import_fresh_module)

# Pure-Python functools: the _functools C accelerator is blocked during import.
py_functools = _import_fresh_module('functools', blocked=['_functools'])
# functools re-imported together with a fresh _functools accelerator.
# import_fresh_module returns None when a *fresh* module cannot be imported;
# the `if c_functools:` / skipUnless guards in this file rely on that.
c_functools = _import_fresh_module('functools', fresh=['_functools'])

# decimal re-imported together with a fresh _decimal C accelerator.
decimal = _import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Echo back the arguments of the call.

    Returns a 2-tuple: the tuple of positional arguments and the dict of
    keyword arguments that this function received.
    """
    captured = (args, kw)
    return captured
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Describe a partial object as a tuple.

    Returns ``(part.func, part.args, part.keywords, part.__dict__)``, which is
    convenient for comparing two partial objects field by field.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Implementation-agnostic tests for ``partial``.

    Mixin: concrete subclasses combine it with ``unittest.TestCase`` and set
    ``partial`` to the implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; this is only checked when
        # self.partial is a type (other partial factories are exempt).
        if not isinstance(self.partial, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        with self.assertRaises(
                TypeError, msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable func must be rejected, either at construction or
        # at call time.
        with self.assertRaises(
                TypeError, msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # Separate assertions so a failure reports the actual values.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Non-refcounting collectors do not free f immediately; force it.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000147
Łukasz Langa6f692512013-06-05 12:20:24 +0200148
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific ones, run on _functools.partial."""
    if c_functools:
        partial = c_functools.partial

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.partial is c_functools.partial:
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.partial(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.partial(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # Round-trip through every pickle protocol, not only the default one.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(proto=proto):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000204
Łukasz Langa6f692512013-06-05 12:20:24 +0200205
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests, run against the pure-Python implementation."""
    # Wrapped in staticmethod so that looking up self.partial does not bind
    # the test case as the first argument.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000208
Łukasz Langa6f692512013-06-05 12:20:24 +0200209
if c_functools:
    # An empty subclass of the C partial type. It is only defined when the
    # C module is available; TestPartialCSubclass reruns the C tests on it.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100213
Łukasz Langa6f692512013-06-05 12:20:24 +0200214
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun every TestPartialC test against a subclass of the C partial."""
    # PartialSubclass exists only when c_functools imported successfully.
    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000219
Łukasz Langa6f692512013-06-05 12:20:24 +0200220
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta must be the *metaclass* (not a base class) so that the ABC
        # machinery actually collects abstract methods from the namespace.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # partialmethod must propagate abstractness to ABCMeta's bookkeeping.
        self.assertEqual(Abstract.__abstractmethods__,
                         frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
333
334
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* was populated from *wrapped*.

        Each attribute named in *assigned* must be the identical object on
        both callables. Each mapping attribute named in *updated* must contain
        the wrapped callable's entries as identical objects, except
        ``__dict__['__wrapped__']``, which the update replaces. Finally,
        ``wrapper.__wrapped__`` must be *wrapped* itself.
        """
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        for attr_name in updated:
            merged = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                # __wrapped__ is overwritten by the update code
                if attr_name == "__dict__" and key == "__wrapped__":
                    continue
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Wrap an annotated, documented function with update_wrapper.

        Returns ``(wrapper, f)``.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'
        # A stale __wrapped__ on the wrapped function must be replaced.
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        # Metadata copied from f.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        # The wrapper's own annotations are replaced, not merged.
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Nothing is assigned or updated, but __wrapped__ is still set.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned_names, updated_names = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned_names, updated_names)
        self.check_wrapper(wrapper, f, assigned_names, updated_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assigned_names, updated_names = ('attr',), ('dict_attr',)
        # Attributes missing from the wrapped function are silently skipped...
        functools.update_wrapper(wrapper, f, assigned_names, updated_names)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # ...but the wrapper itself must provide an updatable mapping.
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned_names, updated_names)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned_names, updated_names)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241: wrapping a builtin function
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000447
Łukasz Langa6f692512013-06-05 12:20:24 +0200448
class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a wrapper of a documented function with functools.wraps.

        Returns ``(wrapper, f)``.
        """
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        # Metadata copied from f.
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'
        # Empty assigned/updated tuples: nothing is copied from f.
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        def add_dict_attr(func):
            # wraps() updates dict_attr in place, so the wrapper needs one.
            func.dict_attr = {}
            return func
        assigned_names, updated_names = ('attr',), ('dict_attr',)
        @functools.wraps(f, assigned_names, updated_names)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assigned_names, updated_names)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
509
Łukasz Langa6f692512013-06-05 12:20:24 +0200510
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""
    # staticmethod keeps self.func from binding the test case as its first
    # argument should reduce ever be a plain Python function (e.g. a
    # pure-Python fallback), matching how the *Py test classes store callables.
    func = staticmethod(functools.reduce)

    def test_reduce(self):
        class Squares:
            """Lazily grows a list of squares; indexable up to *max*."""
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2, 21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertIsNone(self.func(add, [], None))
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            """Old-style sequence protocol: items 0..n-1, then IndexError."""
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
590
Łukasz Langa6f692512013-06-05 12:20:24 +0200591
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200592class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700593
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000594 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700595 def cmp1(x, y):
596 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100597 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700598 self.assertEqual(key(3), key(3))
599 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100600 self.assertGreaterEqual(key(3), key(3))
601
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700602 def cmp2(x, y):
603 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100604 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700605 self.assertEqual(key(4.0), key('4'))
606 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100607 self.assertLessEqual(key(2), key('35'))
608 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700609
610 def test_cmp_to_key_arguments(self):
611 def cmp1(x, y):
612 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100613 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700614 self.assertEqual(key(obj=3), key(obj=3))
615 self.assertGreater(key(obj=3), key(obj=1))
616 with self.assertRaises((TypeError, AttributeError)):
617 key(3) > 1 # rhs is not a K object
618 with self.assertRaises((TypeError, AttributeError)):
619 1 < key(3) # lhs is not a K object
620 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100621 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700622 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200623 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100624 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700625 with self.assertRaises(TypeError):
626 key() # too few args
627 with self.assertRaises(TypeError):
628 key(None, None) # too many args
629
630 def test_bad_cmp(self):
631 def cmp1(x, y):
632 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100633 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700634 with self.assertRaises(ZeroDivisionError):
635 key(3) > key(1)
636
637 class BadCmp:
638 def __lt__(self, other):
639 raise ZeroDivisionError
640 def cmp1(x, y):
641 return BadCmp()
642 with self.assertRaises(ZeroDivisionError):
643 key(3) > key(1)
644
645 def test_obj_field(self):
646 def cmp1(x, y):
647 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100648 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700649 self.assertEqual(key(50).obj, 50)
650
651 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000652 def mycmp(x, y):
653 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100654 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000655 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000656
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700657 def test_sort_int_str(self):
658 def mycmp(x, y):
659 x, y = int(x), int(y)
660 return (x > y) - (x < y)
661 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100662 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700663 self.assertEqual([int(value) for value in values],
664 [0, 1, 1, 2, 3, 4, 5, 7, 10])
665
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000666 def test_hash(self):
667 def mycmp(x, y):
668 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100669 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000670 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700671 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700672 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000673
Łukasz Langa6f692512013-06-05 12:20:24 +0200674
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key tests against the C accelerator. The class
    # body executes even when the class is skipped, hence the guard.
    if c_functools:
        # staticmethod matches TestCmpToKeyPy; a builtin function does not
        # bind to instances, so wrapping it does not change attribute access.
        cmp_to_key = staticmethod(c_functools.cmp_to_key)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100679
Łukasz Langa6f692512013-06-05 12:20:24 +0200680
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Runs the shared cmp_to_key tests against the pure-Python fallback.
    # staticmethod stops the plain Python function from binding to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
683
Łukasz Langa6f692512013-06-05 12:20:24 +0200684
class TestTotalOrdering(unittest.TestCase):
    # Each test builds a class that supplies __eq__ plus a single rich
    # comparison and checks the relations @functools.total_ordering fills in.

    def _check_derived_ordering(self, cls):
        # Relations every correctly ordered class must satisfy when its
        # instances are built from the ints 1 and 2.
        low, high = cls(1), cls(2)
        self.assertTrue(low < high)
        self.assertTrue(high > low)
        self.assertTrue(low <= high)
        self.assertTrue(high >= low)
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, rhs):
                return self.value < rhs.value
            def __eq__(self, rhs):
                return self.value == rhs.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, rhs):
                return self.value <= rhs.value
            def __eq__(self, rhs):
                return self.value == rhs.value
        self._check_derived_ordering(A)
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, rhs):
                return self.value > rhs.value
            def __eq__(self, rhs):
                return self.value == rhs.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, rhs):
                return self.value >= rhs.value
            def __eq__(self, rhs):
                return self.value == rhs.value
        self._check_derived_ordering(A)
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # Comparison methods the class already has (here inherited from
        # int) must not be replaced by derived ones.
        @functools.total_ordering
        class A(int):
            pass
        self._check_derived_ordering(A)

    def test_no_operations_defined(self):
        # A class without any ordering method is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, rhs):
                if isinstance(rhs, ImplementsLessThan):
                    return self.value == rhs.value
                return False
            def __lt__(self, rhs):
                if isinstance(rhs, ImplementsLessThan):
                    return self.value < rhs.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, rhs):
                if isinstance(rhs, ImplementsGreaterThan):
                    return self.value == rhs.value
                return False
            def __gt__(self, rhs):
                if isinstance(rhs, ImplementsGreaterThan):
                    return self.value > rhs.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, rhs):
                if isinstance(rhs, ImplementsLessThanEqualTo):
                    return self.value == rhs.value
                return False
            def __le__(self, rhs):
                if isinstance(rhs, ImplementsLessThanEqualTo):
                    return self.value <= rhs.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, rhs):
                if isinstance(rhs, ImplementsGreaterThanEqualTo):
                    return self.value == rhs.value
                return False
            def __ge__(self, rhs):
                if isinstance(rhs, ImplementsGreaterThanEqualTo):
                    return self.value >= rhs.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, rhs):
                if isinstance(rhs, ComparatorNotImplemented):
                    return self.value == rhs.value
                return False
            def __lt__(self, rhs):
                return NotImplemented

        # Every comparison across unrelated types must end in TypeError,
        # whether the operator is native or derived by total_ordering.
        cross_type_cases = (
            ("LT < 1",
             lambda: ImplementsLessThan(-1) < 1),
            ("LT < LE",
             lambda: ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)),
            ("LT < GT",
             lambda: ImplementsLessThan(1) < ImplementsGreaterThan(1)),
            ("LE <= LT",
             lambda: ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)),
            ("LE <= GE",
             lambda: (ImplementsLessThanEqualTo(3) <=
                      ImplementsGreaterThanEqualTo(3))),
            ("GT > GE",
             lambda: ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)),
            ("GT > LT",
             lambda: ImplementsGreaterThan(5) > ImplementsLessThan(5)),
            ("GE >= GT",
             lambda: ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)),
            ("GE >= LE",
             lambda: (ImplementsGreaterThanEqualTo(7) >=
                      ImplementsLessThanEqualTo(7))),
        )
        for label, compare in cross_type_cases:
            with self.subTest(label), self.assertRaises(TypeError):
                compare()

        # Equal operands still cannot be ordered when the only root
        # comparison always returns NotImplemented.
        equal_cases = (
            ("GE when equal", 8, lambda a, b: a >= b),
            ("LE when equal", 9, lambda a, b: a <= b),
        )
        for label, value, compare in equal_cases:
            with self.subTest(label):
                a = ComparatorNotImplemented(value)
                b = ComparatorNotImplemented(value)
                self.assertEqual(a, b)
                with self.assertRaises(TypeError):
                    compare(a, b)
Łukasz Langa6f692512013-06-05 12:20:24 +0200879
class TestLRU(unittest.TestCase):
    # Behavioural tests for functools.lru_cache.

    # First sixteen Fibonacci numbers, shared by the memoized-fib tests.
    _FIB16 = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]

    def _assert_stats(self, func, hits, misses, currsize):
        # Compare the hit/miss/size counters reported by func.cache_info().
        info = func.cache_info()
        self.assertEqual((info.hits, info.misses, info.currsize),
                         (hits, misses, currsize))

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        cached = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(cached.cache_info().maxsize, 20)
        self._assert_stats(cached, hits=0, misses=0, currsize=0)

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            self.assertEqual(cached(x, y), orig(x, y))
        info = cached.cache_info()
        self.assertTrue(info.hits > info.misses)
        self.assertEqual(info.hits + info.misses, 1000)
        self.assertEqual(info.currsize, 20)

        cached.cache_clear()  # clearing resets every counter
        self._assert_stats(cached, hits=0, misses=0, currsize=0)
        cached(x, y)
        self._assert_stats(cached, hits=0, misses=1, currsize=1)

        # Calling __wrapped__ goes straight to orig, bypassing the cache.
        self.assertIs(cached.__wrapped__, orig)
        cached.__wrapped__(x, y)
        self._assert_stats(cached, hits=0, misses=1, currsize=1)

        # maxsize=0 means "never cache": every call reaches the function.
        @functools.lru_cache(0)
        def never_cached():
            nonlocal call_count
            call_count += 1
            return 20
        self.assertEqual(never_cached.cache_info().maxsize, 0)
        call_count = 0
        for _ in range(5):
            self.assertEqual(never_cached(), 20)
        self.assertEqual(call_count, 5)
        self._assert_stats(never_cached, hits=0, misses=5, currsize=0)

        # maxsize=1: repeated calls with no arguments hit the single slot.
        @functools.lru_cache(1)
        def single_slot():
            nonlocal call_count
            call_count += 1
            return 20
        self.assertEqual(single_slot.cache_info().maxsize, 1)
        call_count = 0
        for _ in range(5):
            self.assertEqual(single_slot(), 20)
        self.assertEqual(call_count, 1)
        self._assert_stats(single_slot, hits=4, misses=1, currsize=1)

        # maxsize=2: the starred positions below are the only misses.
        @functools.lru_cache(2)
        def times_ten(x):
            nonlocal call_count
            call_count += 1
            return x * 10
        self.assertEqual(times_ten.cache_info().maxsize, 2)
        call_count = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(times_ten(x), x * 10)
        self.assertEqual(call_count, 4)
        self._assert_stats(times_ten, hits=12, misses=4, currsize=2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            return n if n < 2 else fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)], self._FIB16)
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize)
            def pick(i):
                return 'abc'[i]
            self.assertEqual(pick(0), 'a')
            with self.assertRaises(IndexError) as ctx:
                pick(15)
            self.assertIsNone(ctx.exception.__context__)
            # A failed call must not leave a cached entry behind.
            with self.assertRaises(IndexError):
                pick(15)

    def test_lru_with_types(self):
        # typed=True keeps int and float arguments in separate entries,
        # for both positional and keyword calls.
        calls = (
            ((3,), {}, 9),
            ((3.0,), {}, 9.0),
            ((), {'x': 3}, 9),
            ((), {'x': 3.0}, 9.0),
        )
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            for args, kwargs, want in calls:
                self.assertEqual(square(*args, **kwargs), want)
                self.assertEqual(type(square(*args, **kwargs)), type(want))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def _check_keyword_fib(self, cache, maxsize):
        # Memoize a keyword-called fib with *cache* and verify its stats.
        @cache
        def fib(n):
            return n if n < 2 else fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)], self._FIB16)
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=maxsize, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=maxsize, currsize=0))

    def test_lru_with_keyword_args(self):
        self._check_keyword_fib(functools.lru_cache(), 128)

    def test_lru_with_keyword_args_maxsize_none(self):
        self._check_keyword_fib(functools.lru_cache(maxsize=None), None)

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock: the
        # __eq__ below calls back into the same cache during a lookup.

        @functools.lru_cache(maxsize=10)
        def identity(x):
            'Returns its argument; re-entered from ReentrantEq.__eq__'
            return x

        class ReentrantEq:
            'Equality check that calls the cached function while comparing'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    identity(ReentrantEq(1))
                return self.x == other.x

        identity(ReentrantEq(1))                       # Load the cache
        identity(ReentrantEq(2))                       # Load the cache
        self.assertEqual(identity(ReentrantEq(2)),     # Trigger a re-entrant __eq__ call
                         ReentrantEq(2))               # Verify the correct return value
1070
1071
Łukasz Langa6f692512013-06-05 12:20:24 +02001072class TestSingleDispatch(unittest.TestCase):
1073 def test_simple_overloads(self):
1074 @functools.singledispatch
1075 def g(obj):
1076 return "base"
1077 def g_int(i):
1078 return "integer"
1079 g.register(int, g_int)
1080 self.assertEqual(g("str"), "base")
1081 self.assertEqual(g(1), "integer")
1082 self.assertEqual(g([1,2,3]), "base")
1083
1084 def test_mro(self):
1085 @functools.singledispatch
1086 def g(obj):
1087 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001088 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001089 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001090 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001091 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001092 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001093 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001094 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001095 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001096 def g_A(a):
1097 return "A"
1098 def g_B(b):
1099 return "B"
1100 g.register(A, g_A)
1101 g.register(B, g_B)
1102 self.assertEqual(g(A()), "A")
1103 self.assertEqual(g(B()), "B")
1104 self.assertEqual(g(C()), "A")
1105 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001106
1107 def test_register_decorator(self):
1108 @functools.singledispatch
1109 def g(obj):
1110 return "base"
1111 @g.register(int)
1112 def g_int(i):
1113 return "int %s" % (i,)
1114 self.assertEqual(g(""), "base")
1115 self.assertEqual(g(12), "int 12")
1116 self.assertIs(g.dispatch(int), g_int)
1117 self.assertIs(g.dispatch(object), g.dispatch(str))
1118 # Note: in the assert above this is not g.
1119 # @singledispatch returns the wrapper.
1120
1121 def test_wrapping_attributes(self):
1122 @functools.singledispatch
1123 def g(obj):
1124 "Simple test"
1125 return "Test"
1126 self.assertEqual(g.__name__, "g")
1127 self.assertEqual(g.__doc__, "Simple test")
1128
1129 @unittest.skipUnless(decimal, 'requires _decimal')
1130 @support.cpython_only
1131 def test_c_classes(self):
1132 @functools.singledispatch
1133 def g(obj):
1134 return "base"
1135 @g.register(decimal.DecimalException)
1136 def _(obj):
1137 return obj.args
1138 subn = decimal.Subnormal("Exponent < Emin")
1139 rnd = decimal.Rounded("Number got rounded")
1140 self.assertEqual(g(subn), ("Exponent < Emin",))
1141 self.assertEqual(g(rnd), ("Number got rounded",))
1142 @g.register(decimal.Subnormal)
1143 def _(obj):
1144 return "Too small to care."
1145 self.assertEqual(g(subn), "Too small to care.")
1146 self.assertEqual(g(rnd), ("Number got rounded",))
1147
1148 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001149 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001150 c = collections
1151 mro = functools._compose_mro
1152 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1153 for haystack in permutations(bases):
1154 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001155 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1156 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001157 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1158 for haystack in permutations(bases):
1159 m = mro(c.ChainMap, haystack)
1160 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1161 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001162
1163 # If there's a generic function with implementations registered for
1164 # both Sized and Container, passing a defaultdict to it results in an
1165 # ambiguous dispatch which will cause a RuntimeError (see
1166 # test_mro_conflicts).
1167 bases = [c.Container, c.Sized, str]
1168 for haystack in permutations(bases):
1169 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1170 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1171 object])
1172
1173 # MutableSequence below is registered directly on D. In other words, it
1174 # preceeds MutableMapping which means single dispatch will always
1175 # choose MutableSequence here.
1176 class D(c.defaultdict):
1177 pass
1178 c.MutableSequence.register(D)
1179 bases = [c.MutableSequence, c.MutableMapping]
1180 for haystack in permutations(bases):
1181 m = mro(D, bases)
1182 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1183 c.defaultdict, dict, c.MutableMapping,
1184 c.Mapping, c.Sized, c.Iterable, c.Container,
1185 object])
1186
1187 # Container and Callable are registered on different base classes and
1188 # a generic function supporting both should always pick the Callable
1189 # implementation if a C instance is passed.
1190 class C(c.defaultdict):
1191 def __call__(self):
1192 pass
1193 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1194 for haystack in permutations(bases):
1195 m = mro(C, haystack)
1196 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1197 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001198
1199 def test_register_abc(self):
1200 c = collections
1201 d = {"a": "b"}
1202 l = [1, 2, 3]
1203 s = {object(), None}
1204 f = frozenset(s)
1205 t = (1, 2, 3)
1206 @functools.singledispatch
1207 def g(obj):
1208 return "base"
1209 self.assertEqual(g(d), "base")
1210 self.assertEqual(g(l), "base")
1211 self.assertEqual(g(s), "base")
1212 self.assertEqual(g(f), "base")
1213 self.assertEqual(g(t), "base")
1214 g.register(c.Sized, lambda obj: "sized")
1215 self.assertEqual(g(d), "sized")
1216 self.assertEqual(g(l), "sized")
1217 self.assertEqual(g(s), "sized")
1218 self.assertEqual(g(f), "sized")
1219 self.assertEqual(g(t), "sized")
1220 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1221 self.assertEqual(g(d), "mutablemapping")
1222 self.assertEqual(g(l), "sized")
1223 self.assertEqual(g(s), "sized")
1224 self.assertEqual(g(f), "sized")
1225 self.assertEqual(g(t), "sized")
1226 g.register(c.ChainMap, lambda obj: "chainmap")
1227 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1228 self.assertEqual(g(l), "sized")
1229 self.assertEqual(g(s), "sized")
1230 self.assertEqual(g(f), "sized")
1231 self.assertEqual(g(t), "sized")
1232 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1233 self.assertEqual(g(d), "mutablemapping")
1234 self.assertEqual(g(l), "mutablesequence")
1235 self.assertEqual(g(s), "sized")
1236 self.assertEqual(g(f), "sized")
1237 self.assertEqual(g(t), "sized")
1238 g.register(c.MutableSet, lambda obj: "mutableset")
1239 self.assertEqual(g(d), "mutablemapping")
1240 self.assertEqual(g(l), "mutablesequence")
1241 self.assertEqual(g(s), "mutableset")
1242 self.assertEqual(g(f), "sized")
1243 self.assertEqual(g(t), "sized")
1244 g.register(c.Mapping, lambda obj: "mapping")
1245 self.assertEqual(g(d), "mutablemapping") # not specific enough
1246 self.assertEqual(g(l), "mutablesequence")
1247 self.assertEqual(g(s), "mutableset")
1248 self.assertEqual(g(f), "sized")
1249 self.assertEqual(g(t), "sized")
1250 g.register(c.Sequence, lambda obj: "sequence")
1251 self.assertEqual(g(d), "mutablemapping")
1252 self.assertEqual(g(l), "mutablesequence")
1253 self.assertEqual(g(s), "mutableset")
1254 self.assertEqual(g(f), "sized")
1255 self.assertEqual(g(t), "sequence")
1256 g.register(c.Set, lambda obj: "set")
1257 self.assertEqual(g(d), "mutablemapping")
1258 self.assertEqual(g(l), "mutablesequence")
1259 self.assertEqual(g(s), "mutableset")
1260 self.assertEqual(g(f), "set")
1261 self.assertEqual(g(t), "sequence")
1262 g.register(dict, lambda obj: "dict")
1263 self.assertEqual(g(d), "dict")
1264 self.assertEqual(g(l), "mutablesequence")
1265 self.assertEqual(g(s), "mutableset")
1266 self.assertEqual(g(f), "set")
1267 self.assertEqual(g(t), "sequence")
1268 g.register(list, lambda obj: "list")
1269 self.assertEqual(g(d), "dict")
1270 self.assertEqual(g(l), "list")
1271 self.assertEqual(g(s), "mutableset")
1272 self.assertEqual(g(f), "set")
1273 self.assertEqual(g(t), "sequence")
1274 g.register(set, lambda obj: "concrete-set")
1275 self.assertEqual(g(d), "dict")
1276 self.assertEqual(g(l), "list")
1277 self.assertEqual(g(s), "concrete-set")
1278 self.assertEqual(g(f), "set")
1279 self.assertEqual(g(t), "sequence")
1280 g.register(frozenset, lambda obj: "frozen-set")
1281 self.assertEqual(g(d), "dict")
1282 self.assertEqual(g(l), "list")
1283 self.assertEqual(g(s), "concrete-set")
1284 self.assertEqual(g(f), "frozen-set")
1285 self.assertEqual(g(t), "sequence")
1286 g.register(tuple, lambda obj: "tuple")
1287 self.assertEqual(g(d), "dict")
1288 self.assertEqual(g(l), "list")
1289 self.assertEqual(g(s), "concrete-set")
1290 self.assertEqual(g(f), "frozen-set")
1291 self.assertEqual(g(t), "tuple")
1292
Łukasz Langa3720c772013-07-01 16:00:38 +02001293 def test_c3_abc(self):
1294 c = collections
1295 mro = functools._c3_mro
1296 class A(object):
1297 pass
1298 class B(A):
1299 def __len__(self):
1300 return 0 # implies Sized
1301 @c.Container.register
1302 class C(object):
1303 pass
1304 class D(object):
1305 pass # unrelated
1306 class X(D, C, B):
1307 def __call__(self):
1308 pass # implies Callable
1309 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1310 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1311 self.assertEqual(mro(X, abcs=abcs), expected)
1312 # unrelated ABCs don't appear in the resulting MRO
1313 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1314 self.assertEqual(mro(X, abcs=many_abcs), expected)
1315
    def test_mro_conflicts(self):
        # Walk through a sequence of ABC registrations and check which
        # implementation singledispatch picks after each one, including the
        # cases where the choice is ambiguous and RuntimeError is raised.
        # NOTE: the ABC .register() calls below mutate the ABCs' registries
        # for the locally defined classes, so every assertion depends on the
        # exact order of the statements before it.
        c = collections
        @functools.singledispatch
        def g(arg):
            return "base"
        # O inherits Sized explicitly, so Sized is part of its real __mro__.
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P has no ABC in its real __mro__; every ABC it gets is registered.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        # With both Iterable and Container registered on P, dispatch cannot
        # choose between their implementations.
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two candidates may be reported in either order.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        # Q mirrors O, but registers Iterable instead of Container.
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        # h has competing implementations for Sized and Container only.
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(c.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # An ABC registered directly on R wins over MutableMapping, which R
        # only picks up through its defaultdict base.
        class R(c.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        # S is a plain, non-ABC base used to see how registered ABCs are
        # placed relative to ordinary bases.
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        # U never names Sized; it qualifies only by defining __len__().
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists Sized before S, so an ABC registered on V is placed ahead
        # of S (see the final assertion).
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001443
1444 def test_cache_invalidation(self):
1445 from collections import UserDict
1446 class TracingDict(UserDict):
1447 def __init__(self, *args, **kwargs):
1448 super(TracingDict, self).__init__(*args, **kwargs)
1449 self.set_ops = []
1450 self.get_ops = []
1451 def __getitem__(self, key):
1452 result = self.data[key]
1453 self.get_ops.append(key)
1454 return result
1455 def __setitem__(self, key, value):
1456 self.set_ops.append(key)
1457 self.data[key] = value
1458 def clear(self):
1459 self.data.clear()
1460 _orig_wkd = functools.WeakKeyDictionary
1461 td = TracingDict()
1462 functools.WeakKeyDictionary = lambda: td
1463 c = collections
1464 @functools.singledispatch
1465 def g(arg):
1466 return "base"
1467 d = {}
1468 l = []
1469 self.assertEqual(len(td), 0)
1470 self.assertEqual(g(d), "base")
1471 self.assertEqual(len(td), 1)
1472 self.assertEqual(td.get_ops, [])
1473 self.assertEqual(td.set_ops, [dict])
1474 self.assertEqual(td.data[dict], g.registry[object])
1475 self.assertEqual(g(l), "base")
1476 self.assertEqual(len(td), 2)
1477 self.assertEqual(td.get_ops, [])
1478 self.assertEqual(td.set_ops, [dict, list])
1479 self.assertEqual(td.data[dict], g.registry[object])
1480 self.assertEqual(td.data[list], g.registry[object])
1481 self.assertEqual(td.data[dict], td.data[list])
1482 self.assertEqual(g(l), "base")
1483 self.assertEqual(g(d), "base")
1484 self.assertEqual(td.get_ops, [list, dict])
1485 self.assertEqual(td.set_ops, [dict, list])
1486 g.register(list, lambda arg: "list")
1487 self.assertEqual(td.get_ops, [list, dict])
1488 self.assertEqual(len(td), 0)
1489 self.assertEqual(g(d), "base")
1490 self.assertEqual(len(td), 1)
1491 self.assertEqual(td.get_ops, [list, dict])
1492 self.assertEqual(td.set_ops, [dict, list, dict])
1493 self.assertEqual(td.data[dict],
1494 functools._find_impl(dict, g.registry))
1495 self.assertEqual(g(l), "list")
1496 self.assertEqual(len(td), 2)
1497 self.assertEqual(td.get_ops, [list, dict])
1498 self.assertEqual(td.set_ops, [dict, list, dict, list])
1499 self.assertEqual(td.data[list],
1500 functools._find_impl(list, g.registry))
1501 class X:
1502 pass
1503 c.MutableMapping.register(X) # Will not invalidate the cache,
1504 # not using ABCs yet.
1505 self.assertEqual(g(d), "base")
1506 self.assertEqual(g(l), "list")
1507 self.assertEqual(td.get_ops, [list, dict, dict, list])
1508 self.assertEqual(td.set_ops, [dict, list, dict, list])
1509 g.register(c.Sized, lambda arg: "sized")
1510 self.assertEqual(len(td), 0)
1511 self.assertEqual(g(d), "sized")
1512 self.assertEqual(len(td), 1)
1513 self.assertEqual(td.get_ops, [list, dict, dict, list])
1514 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1515 self.assertEqual(g(l), "list")
1516 self.assertEqual(len(td), 2)
1517 self.assertEqual(td.get_ops, [list, dict, dict, list])
1518 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1519 self.assertEqual(g(l), "list")
1520 self.assertEqual(g(d), "sized")
1521 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1522 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1523 g.dispatch(list)
1524 g.dispatch(dict)
1525 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1526 list, dict])
1527 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1528 c.MutableSet.register(X) # Will invalidate the cache.
1529 self.assertEqual(len(td), 2) # Stale cache.
1530 self.assertEqual(g(l), "list")
1531 self.assertEqual(len(td), 1)
1532 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1533 self.assertEqual(len(td), 0)
1534 self.assertEqual(g(d), "mutablemapping")
1535 self.assertEqual(len(td), 1)
1536 self.assertEqual(g(l), "list")
1537 self.assertEqual(len(td), 2)
1538 g.register(dict, lambda arg: "dict")
1539 self.assertEqual(g(d), "dict")
1540 self.assertEqual(g(l), "list")
1541 g._clear_cache()
1542 self.assertEqual(len(td), 0)
1543 functools.WeakKeyDictionary = _orig_wkd
1544
1545
def test_main(verbose=None):
    """Run every test class of this module via support.run_unittest().

    When *verbose* is true and the interpreter exposes
    sys.gettotalrefcount(), the whole suite is then run five more times.
    Garbage is collected after each run and the collected total reference
    counts are printed as a list, so that growth between runs can be
    spotted.
    """
    case_classes = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*case_classes)

    # Reference-count verification needs both the request and an
    # interpreter that can report total reference counts.
    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return

    import gc
    totals = []
    for _ in range(5):
        support.run_unittest(*case_classes)
        gc.collect()
        totals.append(sys.gettotalrefcount())
    print(totals)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001572
if __name__ == '__main__':
    # Run as a script: turn on verbose mode, which also enables the
    # reference-count report where the interpreter supports it.
    test_main(True)