blob: 75ae7f3a15d5b2d0021be3b75d8ccb1be13f9450 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# Two independent copies of functools for the implementation-specific
# tests: py_functools is imported with the C module `_functools` blocked,
# so it exercises the pure-Python code; c_functools is imported with
# `_functools` fresh.  c_functools may be falsy when the C module is not
# available -- the C-specific test classes below are skipped in that case.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh copy of decimal with `_decimal` re-imported as well (presumably the
# C accelerator).  NOTE(review): not referenced in the visible part of this
# file; presumably used by tests further down -- confirm before removing.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
17
18
def capture(*args, **kw):
    """Return the positional and keyword arguments it was called with.

    Serves as the target callable for partial objects so a test can see
    exactly which arguments reached it: the result is ``(args, kw)``.
    """
    captured = (args, kw)
    return captured
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Return a comparable snapshot of a partial object's state.

    The snapshot is the tuple ``(func, args, keywords, __dict__)``; the
    pickle tests compare the snapshot of an object with that of its
    unpickled copy.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase: concrete subclasses also inherit from
    ``unittest.TestCase`` and set the class attribute ``partial`` to the
    implementation under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is constructed or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            with self.subTest(args=args):
                p = self.partial(capture, *args)
                got, empty = p('x')
                # Separate assertions so a failure shows which part is wrong.
                self.assertEqual(got, args + ('x',))
                self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            with self.subTest(a=a):
                p = self.partial(capture, a=a)
                empty, got = p(x=None)
                self.assertEqual(got, {'a': a, 'x': None})
                self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Collect explicitly so the check does not depend on the referent
        # being freed as soon as its last strong reference disappears.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000133
Łukasz Langa6f692512013-06-05 12:20:24 +0200134
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200135@unittest.skipUnless(c_functools, 'requires the C _functools module')
136class TestPartialC(TestPartial, unittest.TestCase):
137 if c_functools:
138 partial = c_functools.partial
139
Zachary Ware101d9e72013-12-08 00:44:27 -0600140 def test_attributes_unwritable(self):
141 # attributes should not be writable
142 p = self.partial(capture, 1, 2, a=10, b=20)
143 self.assertRaises(AttributeError, setattr, p, 'func', map)
144 self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
145 self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
146
147 p = self.partial(hex)
148 try:
149 del p.__dict__
150 except TypeError:
151 pass
152 else:
153 self.fail('partial object allowed __dict__ to be deleted')
154
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000155 def test_repr(self):
156 args = (object(), object())
157 args_repr = ', '.join(repr(a) for a in args)
Christian Heimesd0628922013-11-22 01:22:47 +0100158 #kwargs = {'a': object(), 'b': object()}
159 kwargs = {'a': object()}
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000160 kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200161 if self.partial is c_functools.partial:
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000162 name = 'functools.partial'
163 else:
Antoine Pitroub5b37142012-11-13 21:35:40 +0100164 name = self.partial.__name__
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000165
Antoine Pitroub5b37142012-11-13 21:35:40 +0100166 f = self.partial(capture)
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000167 self.assertEqual('{}({!r})'.format(name, capture),
168 repr(f))
169
Antoine Pitroub5b37142012-11-13 21:35:40 +0100170 f = self.partial(capture, *args)
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000171 self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
172 repr(f))
173
Antoine Pitroub5b37142012-11-13 21:35:40 +0100174 f = self.partial(capture, **kwargs)
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000175 self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
176 repr(f))
177
Antoine Pitroub5b37142012-11-13 21:35:40 +0100178 f = self.partial(capture, *args, **kwargs)
Alexander Belopolsky41e422a2010-12-01 20:05:49 +0000179 self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
180 repr(f))
181
Jack Diederiche0cbd692009-04-01 04:27:09 +0000182 def test_pickle(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100183 f = self.partial(signature, 'asdf', bar=True)
Jack Diederiche0cbd692009-04-01 04:27:09 +0000184 f.add_something_to__dict__ = True
185 f_copy = pickle.loads(pickle.dumps(f))
186 self.assertEqual(signature(f), signature(f_copy))
187
Serhiy Storchaka19c4e0d2013-02-04 12:47:24 +0200188 # Issue 6083: Reference counting bug
189 def test_setstate_refcount(self):
190 class BadSequence:
191 def __len__(self):
192 return 4
193 def __getitem__(self, key):
194 if key == 0:
195 return max
196 elif key == 1:
197 return tuple(range(1000000))
198 elif key in (2, 3):
199 return {}
200 raise IndexError
201
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200202 f = self.partial(object)
Serhiy Storchaka19c4e0d2013-02-04 12:47:24 +0200203 self.assertRaisesRegex(SystemError,
204 "new style getargs format but argument is not a tuple",
205 f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000206
Łukasz Langa6f692512013-06-05 12:20:24 +0200207
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200208class TestPartialPy(TestPartial, unittest.TestCase):
209 partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000210
Łukasz Langa6f692512013-06-05 12:20:24 +0200211
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200212if c_functools:
213 class PartialSubclass(c_functools.partial):
214 pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100215
Łukasz Langa6f692512013-06-05 12:20:24 +0200216
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200217@unittest.skipUnless(c_functools, 'requires the C _functools module')
Serhiy Storchakab6a53402013-02-04 12:57:16 +0200218class TestPartialCSubclass(TestPartialC):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200219 if c_functools:
220 partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000221
Łukasz Langa6f692512013-06-05 12:20:24 +0200222
Nick Coghlanf4cb48a2013-11-03 16:41:46 +1000223class TestPartialMethod(unittest.TestCase):
224
225 class A(object):
226 nothing = functools.partialmethod(capture)
227 positional = functools.partialmethod(capture, 1)
228 keywords = functools.partialmethod(capture, a=2)
229 both = functools.partialmethod(capture, 3, b=4)
230
231 nested = functools.partialmethod(positional, 5)
232
233 over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
234
235 static = functools.partialmethod(staticmethod(capture), 8)
236 cls = functools.partialmethod(classmethod(capture), d=9)
237
238 a = A()
239
240 def test_arg_combinations(self):
241 self.assertEqual(self.a.nothing(), ((self.a,), {}))
242 self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
243 self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
244 self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
245
246 self.assertEqual(self.a.positional(), ((self.a, 1), {}))
247 self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
248 self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
249 self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
250
251 self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
252 self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
253 self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
254 self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
255
256 self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
257 self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
258 self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
259 self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
260
261 self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
262
263 def test_nested(self):
264 self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
265 self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
266 self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
267 self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
268
269 self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
270
271 def test_over_partial(self):
272 self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
273 self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
274 self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
275 self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
276
277 self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
278
279 def test_bound_method_introspection(self):
280 obj = self.a
281 self.assertIs(obj.both.__self__, obj)
282 self.assertIs(obj.nested.__self__, obj)
283 self.assertIs(obj.over_partial.__self__, obj)
284 self.assertIs(obj.cls.__self__, self.A)
285 self.assertIs(self.A.cls.__self__, self.A)
286
287 def test_unbound_method_retrieval(self):
288 obj = self.A
289 self.assertFalse(hasattr(obj.both, "__self__"))
290 self.assertFalse(hasattr(obj.nested, "__self__"))
291 self.assertFalse(hasattr(obj.over_partial, "__self__"))
292 self.assertFalse(hasattr(obj.static, "__self__"))
293 self.assertFalse(hasattr(self.a.static, "__self__"))
294
295 def test_descriptors(self):
296 for obj in [self.A, self.a]:
297 with self.subTest(obj=obj):
298 self.assertEqual(obj.static(), ((8,), {}))
299 self.assertEqual(obj.static(5), ((8, 5), {}))
300 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
301 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
302
303 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
304 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
305 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
306 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
307
308 def test_overriding_keywords(self):
309 self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
310 self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
311
312 def test_invalid_args(self):
313 with self.assertRaises(TypeError):
314 class B(object):
315 method = functools.partialmethod(None, 1)
316
317 def test_repr(self):
318 self.assertEqual(repr(vars(self.A)['both']),
319 'functools.partialmethod({}, 3, b=4)'.format(capture))
320
321 def test_abstract(self):
322 class Abstract(abc.ABCMeta):
323
324 @abc.abstractmethod
325 def add(self, x, y):
326 pass
327
328 add5 = functools.partialmethod(add, 5)
329
330 self.assertTrue(Abstract.add.__isabstractmethod__)
331 self.assertTrue(Abstract.add5.__isabstractmethod__)
332
333 for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
334 self.assertFalse(getattr(func, '__isabstractmethod__', False))
335
336
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000337class TestUpdateWrapper(unittest.TestCase):
338
339 def check_wrapper(self, wrapper, wrapped,
340 assigned=functools.WRAPPER_ASSIGNMENTS,
341 updated=functools.WRAPPER_UPDATES):
342 # Check attributes were assigned
343 for name in assigned:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000344 self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000345 # Check attributes were updated
346 for name in updated:
347 wrapper_attr = getattr(wrapper, name)
348 wrapped_attr = getattr(wrapped, name)
349 for key in wrapped_attr:
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000350 if name == "__dict__" and key == "__wrapped__":
351 # __wrapped__ is overwritten by the update code
352 continue
353 self.assertIs(wrapped_attr[key], wrapper_attr[key])
354 # Check __wrapped__
355 self.assertIs(wrapper.__wrapped__, wrapped)
356
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000357
R. David Murray378c0cf2010-02-24 01:46:21 +0000358 def _default_update(self):
Antoine Pitrou560f7642010-08-04 18:28:02 +0000359 def f(a:'This is a new annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000360 """This is a test"""
361 pass
362 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000363 f.__wrapped__ = "This is a bald faced lie"
Antoine Pitrou560f7642010-08-04 18:28:02 +0000364 def wrapper(b:'This is the prior annotation'):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000365 pass
366 functools.update_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000367 return wrapper, f
368
369 def test_default_update(self):
370 wrapper, f = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000371 self.check_wrapper(wrapper, f)
Nick Coghlan98876832010-08-17 06:17:18 +0000372 self.assertIs(wrapper.__wrapped__, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000373 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600374 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000375 self.assertEqual(wrapper.attr, 'This is also a test')
Antoine Pitrou560f7642010-08-04 18:28:02 +0000376 self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
377 self.assertNotIn('b', wrapper.__annotations__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000378
R. David Murray378c0cf2010-02-24 01:46:21 +0000379 @unittest.skipIf(sys.flags.optimize >= 2,
380 "Docstrings are omitted with -O2 and above")
381 def test_default_update_doc(self):
382 wrapper, f = self._default_update()
383 self.assertEqual(wrapper.__doc__, 'This is a test')
384
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000385 def test_no_update(self):
386 def f():
387 """This is a test"""
388 pass
389 f.attr = 'This is also a test'
390 def wrapper():
391 pass
392 functools.update_wrapper(wrapper, f, (), ())
393 self.check_wrapper(wrapper, f, (), ())
394 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600395 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000396 self.assertEqual(wrapper.__doc__, None)
Antoine Pitrou560f7642010-08-04 18:28:02 +0000397 self.assertEqual(wrapper.__annotations__, {})
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000398 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000399
400 def test_selective_update(self):
401 def f():
402 pass
403 f.attr = 'This is a different test'
404 f.dict_attr = dict(a=1, b=2, c=3)
405 def wrapper():
406 pass
407 wrapper.dict_attr = {}
408 assign = ('attr',)
409 update = ('dict_attr',)
410 functools.update_wrapper(wrapper, f, assign, update)
411 self.check_wrapper(wrapper, f, assign, update)
412 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600413 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000414 self.assertEqual(wrapper.__doc__, None)
415 self.assertEqual(wrapper.attr, 'This is a different test')
416 self.assertEqual(wrapper.dict_attr, f.dict_attr)
417
Nick Coghlan98876832010-08-17 06:17:18 +0000418 def test_missing_attributes(self):
419 def f():
420 pass
421 def wrapper():
422 pass
423 wrapper.dict_attr = {}
424 assign = ('attr',)
425 update = ('dict_attr',)
426 # Missing attributes on wrapped object are ignored
427 functools.update_wrapper(wrapper, f, assign, update)
428 self.assertNotIn('attr', wrapper.__dict__)
429 self.assertEqual(wrapper.dict_attr, {})
430 # Wrapper must have expected attributes for updating
431 del wrapper.dict_attr
432 with self.assertRaises(AttributeError):
433 functools.update_wrapper(wrapper, f, assign, update)
434 wrapper.dict_attr = 1
435 with self.assertRaises(AttributeError):
436 functools.update_wrapper(wrapper, f, assign, update)
437
Serhiy Storchaka9d0add02013-01-27 19:47:45 +0200438 @support.requires_docstrings
Nick Coghlan98876832010-08-17 06:17:18 +0000439 @unittest.skipIf(sys.flags.optimize >= 2,
440 "Docstrings are omitted with -O2 and above")
Thomas Wouters89f507f2006-12-13 04:49:30 +0000441 def test_builtin_update(self):
442 # Test for bug #1576241
443 def wrapper():
444 pass
445 functools.update_wrapper(wrapper, max)
446 self.assertEqual(wrapper.__name__, 'max')
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000447 self.assertTrue(wrapper.__doc__.startswith('max('))
Antoine Pitrou560f7642010-08-04 18:28:02 +0000448 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000449
Łukasz Langa6f692512013-06-05 12:20:24 +0200450
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000451class TestWraps(TestUpdateWrapper):
452
R. David Murray378c0cf2010-02-24 01:46:21 +0000453 def _default_update(self):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000454 def f():
455 """This is a test"""
456 pass
457 f.attr = 'This is also a test'
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000458 f.__wrapped__ = "This is still a bald faced lie"
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000459 @functools.wraps(f)
460 def wrapper():
461 pass
Meador Ingeff7f64c2011-12-11 22:37:31 -0600462 return wrapper, f
R. David Murray378c0cf2010-02-24 01:46:21 +0000463
464 def test_default_update(self):
Meador Ingeff7f64c2011-12-11 22:37:31 -0600465 wrapper, f = self._default_update()
Nick Coghlan24c05bc2013-07-15 21:13:08 +1000466 self.check_wrapper(wrapper, f)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000467 self.assertEqual(wrapper.__name__, 'f')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600468 self.assertEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000469 self.assertEqual(wrapper.attr, 'This is also a test')
470
Antoine Pitroub5b37142012-11-13 21:35:40 +0100471 @unittest.skipIf(sys.flags.optimize >= 2,
R. David Murray378c0cf2010-02-24 01:46:21 +0000472 "Docstrings are omitted with -O2 and above")
473 def test_default_update_doc(self):
Meador Ingeff7f64c2011-12-11 22:37:31 -0600474 wrapper, _ = self._default_update()
R. David Murray378c0cf2010-02-24 01:46:21 +0000475 self.assertEqual(wrapper.__doc__, 'This is a test')
476
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000477 def test_no_update(self):
478 def f():
479 """This is a test"""
480 pass
481 f.attr = 'This is also a test'
482 @functools.wraps(f, (), ())
483 def wrapper():
484 pass
485 self.check_wrapper(wrapper, f, (), ())
486 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600487 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000488 self.assertEqual(wrapper.__doc__, None)
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000489 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000490
491 def test_selective_update(self):
492 def f():
493 pass
494 f.attr = 'This is a different test'
495 f.dict_attr = dict(a=1, b=2, c=3)
496 def add_dict_attr(f):
497 f.dict_attr = {}
498 return f
499 assign = ('attr',)
500 update = ('dict_attr',)
501 @functools.wraps(f, assign, update)
502 @add_dict_attr
503 def wrapper():
504 pass
505 self.check_wrapper(wrapper, f, assign, update)
506 self.assertEqual(wrapper.__name__, 'wrapper')
Meador Ingeff7f64c2011-12-11 22:37:31 -0600507 self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000508 self.assertEqual(wrapper.__doc__, None)
509 self.assertEqual(wrapper.attr, 'This is a different test')
510 self.assertEqual(wrapper.dict_attr, f.dict_attr)
511
Łukasz Langa6f692512013-06-05 12:20:24 +0200512
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000513class TestReduce(unittest.TestCase):
514 func = functools.reduce
515
516 def test_reduce(self):
517 class Squares:
518 def __init__(self, max):
519 self.max = max
520 self.sofar = []
521
522 def __len__(self):
523 return len(self.sofar)
524
525 def __getitem__(self, i):
526 if not 0 <= i < self.max: raise IndexError
527 n = len(self.sofar)
528 while n <= i:
529 self.sofar.append(n*n)
530 n += 1
531 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000532 def add(x, y):
533 return x + y
534 self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000535 self.assertEqual(
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000536 self.func(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000537 ['a','c','d','w']
538 )
539 self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
540 self.assertEqual(
Guido van Rossume2a383d2007-01-15 16:59:06 +0000541 self.func(lambda x, y: x*y, range(2,21), 1),
542 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000543 )
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000544 self.assertEqual(self.func(add, Squares(10)), 285)
545 self.assertEqual(self.func(add, Squares(10), 0), 285)
546 self.assertEqual(self.func(add, Squares(0), 0), 0)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000547 self.assertRaises(TypeError, self.func)
548 self.assertRaises(TypeError, self.func, 42, 42)
549 self.assertRaises(TypeError, self.func, 42, 42, 42)
550 self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
551 self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
552 self.assertRaises(TypeError, self.func, 42, (42, 42))
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000553 self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
554 self.assertRaises(TypeError, self.func, add, "")
555 self.assertRaises(TypeError, self.func, add, ())
556 self.assertRaises(TypeError, self.func, add, object())
557
558 class TestFailingIter:
559 def __iter__(self):
560 raise RuntimeError
561 self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
562
563 self.assertEqual(self.func(add, [], None), None)
564 self.assertEqual(self.func(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000565
566 class BadSeq:
567 def __getitem__(self, index):
568 raise ValueError
569 self.assertRaises(ValueError, self.func, 42, BadSeq())
570
571 # Test reduce()'s use of iterators.
572 def test_iterator_usage(self):
573 class SequenceClass:
574 def __init__(self, n):
575 self.n = n
576 def __getitem__(self, i):
577 if 0 <= i < self.n:
578 return i
579 else:
580 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000581
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000582 from operator import add
583 self.assertEqual(self.func(add, SequenceClass(5)), 10)
584 self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
585 self.assertRaises(TypeError, self.func, add, SequenceClass(0))
586 self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
587 self.assertEqual(self.func(add, SequenceClass(1)), 0)
588 self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
589
590 d = {"one": 1, "two": 2, "three": 3}
591 self.assertEqual(self.func(add, d), "".join(d.keys()))
592
Łukasz Langa6f692512013-06-05 12:20:24 +0200593
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200594class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700595
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000596 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700597 def cmp1(x, y):
598 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100599 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700600 self.assertEqual(key(3), key(3))
601 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100602 self.assertGreaterEqual(key(3), key(3))
603
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700604 def cmp2(x, y):
605 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100606 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700607 self.assertEqual(key(4.0), key('4'))
608 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100609 self.assertLessEqual(key(2), key('35'))
610 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700611
612 def test_cmp_to_key_arguments(self):
613 def cmp1(x, y):
614 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100615 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700616 self.assertEqual(key(obj=3), key(obj=3))
617 self.assertGreater(key(obj=3), key(obj=1))
618 with self.assertRaises((TypeError, AttributeError)):
619 key(3) > 1 # rhs is not a K object
620 with self.assertRaises((TypeError, AttributeError)):
621 1 < key(3) # lhs is not a K object
622 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100623 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700624 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200625 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100626 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700627 with self.assertRaises(TypeError):
628 key() # too few args
629 with self.assertRaises(TypeError):
630 key(None, None) # too many args
631
632 def test_bad_cmp(self):
633 def cmp1(x, y):
634 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100635 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700636 with self.assertRaises(ZeroDivisionError):
637 key(3) > key(1)
638
639 class BadCmp:
640 def __lt__(self, other):
641 raise ZeroDivisionError
642 def cmp1(x, y):
643 return BadCmp()
644 with self.assertRaises(ZeroDivisionError):
645 key(3) > key(1)
646
647 def test_obj_field(self):
648 def cmp1(x, y):
649 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100650 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700651 self.assertEqual(key(50).obj, 50)
652
653 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000654 def mycmp(x, y):
655 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100656 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000657 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000658
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700659 def test_sort_int_str(self):
660 def mycmp(x, y):
661 x, y = int(x), int(y)
662 return (x > y) - (x < y)
663 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100664 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700665 self.assertEqual([int(value) for value in values],
666 [0, 1, 1, 2, 3, 4, 5, 7, 10])
667
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000668 def test_hash(self):
669 def mycmp(x, y):
670 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100671 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000672 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700673 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700674 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000675
Łukasz Langa6f692512013-06-05 12:20:24 +0200676
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the C _functools module."""
    # The class body is executed even when the skip decorator will apply,
    # so the binding is guarded by the same truthiness test the decorator
    # uses (c_functools is presumably None when _functools cannot be
    # imported -- confirm against support.import_fresh_module).
    # NOTE(review): no staticmethod() wrapper here, on the assumption that
    # the C cmp_to_key is a builtin function and so does not bind to
    # instances when looked up through self.
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100681
Łukasz Langa6f692512013-06-05 12:20:24 +0200682
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared TestCmpToKey checks against the pure-Python functools."""
    # staticmethod() keeps the plain Python function from being bound as a
    # method, so self.cmp_to_key(f) calls cmp_to_key(f) rather than
    # cmp_to_key(self, f).
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
685
Łukasz Langa6f692512013-06-05 12:20:24 +0200686
class TestTotalOrdering(unittest.TestCase):
    # Tests for functools.total_ordering: derivation of the missing rich
    # comparisons from each possible root, preservation of methods a class
    # already has, and TypeError (not infinite recursion) on NotImplemented.

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        # The results above would be identical if total_ordering had
        # replaced int's methods with derived ones, so check identity too:
        # every rich comparison must still be the one inherited from int.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            with self.subTest(name=name):
                self.assertIs(getattr(A, name), getattr(int, name))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200881
class TestLRU(unittest.TestCase):
    # Tests for functools.lru_cache: hit/miss accounting, cache_clear(),
    # __wrapped__, the maxsize=0/1/2/None edge cases, typed keys, keyword
    # arguments, exception pass-through and re-entrant calls.

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = functools.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        # Only 25 distinct (x, y) pairs exist and 20 fit in the cache, so
        # hits should dominate.  assertGreater reports both values on
        # failure, unlike assertTrue(hits > misses).
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # Expected misses: the first 7 and 9, the first 8 (which evicts
        # the least recently used 7), and the final 7.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        for maxsize in (None, 128):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            # Each value is requested twice (a miss, then a hit), and with
            # typed=True the int and float arguments are cached separately.
            self.assertEqual(square(3), 9)
            self.assertIs(type(square(3)), int)
            self.assertEqual(square(3.0), 9.0)
            self.assertIs(type(square(3.0)), float)
            self.assertEqual(square(x=3), 9)
            self.assertIs(type(square(x=3)), int)
            self.assertEqual(square(x=3.0), 9.0)
            self.assertIs(type(square(x=3.0)), float)
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_lru_with_keyword_args(self):
        @functools.lru_cache()
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual(
            [fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        )
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))

    def test_lru_with_keyword_args_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n=n-1) + fib(n=n-2)
        self.assertEqual([fib(n=number) for number in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @functools.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
1072
1073
Łukasz Langa6f692512013-06-05 12:20:24 +02001074class TestSingleDispatch(unittest.TestCase):
1075 def test_simple_overloads(self):
1076 @functools.singledispatch
1077 def g(obj):
1078 return "base"
1079 def g_int(i):
1080 return "integer"
1081 g.register(int, g_int)
1082 self.assertEqual(g("str"), "base")
1083 self.assertEqual(g(1), "integer")
1084 self.assertEqual(g([1,2,3]), "base")
1085
1086 def test_mro(self):
1087 @functools.singledispatch
1088 def g(obj):
1089 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001090 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001091 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001092 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001093 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001094 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001095 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001096 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001097 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001098 def g_A(a):
1099 return "A"
1100 def g_B(b):
1101 return "B"
1102 g.register(A, g_A)
1103 g.register(B, g_B)
1104 self.assertEqual(g(A()), "A")
1105 self.assertEqual(g(B()), "B")
1106 self.assertEqual(g(C()), "A")
1107 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001108
1109 def test_register_decorator(self):
1110 @functools.singledispatch
1111 def g(obj):
1112 return "base"
1113 @g.register(int)
1114 def g_int(i):
1115 return "int %s" % (i,)
1116 self.assertEqual(g(""), "base")
1117 self.assertEqual(g(12), "int 12")
1118 self.assertIs(g.dispatch(int), g_int)
1119 self.assertIs(g.dispatch(object), g.dispatch(str))
1120 # Note: in the assert above this is not g.
1121 # @singledispatch returns the wrapper.
1122
1123 def test_wrapping_attributes(self):
1124 @functools.singledispatch
1125 def g(obj):
1126 "Simple test"
1127 return "Test"
1128 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001129 if sys.flags.optimize < 2:
1130 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001131
1132 @unittest.skipUnless(decimal, 'requires _decimal')
1133 @support.cpython_only
1134 def test_c_classes(self):
1135 @functools.singledispatch
1136 def g(obj):
1137 return "base"
1138 @g.register(decimal.DecimalException)
1139 def _(obj):
1140 return obj.args
1141 subn = decimal.Subnormal("Exponent < Emin")
1142 rnd = decimal.Rounded("Number got rounded")
1143 self.assertEqual(g(subn), ("Exponent < Emin",))
1144 self.assertEqual(g(rnd), ("Number got rounded",))
1145 @g.register(decimal.Subnormal)
1146 def _(obj):
1147 return "Too small to care."
1148 self.assertEqual(g(subn), "Too small to care.")
1149 self.assertEqual(g(rnd), ("Number got rounded",))
1150
1151 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001152 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001153 c = collections
1154 mro = functools._compose_mro
1155 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1156 for haystack in permutations(bases):
1157 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001158 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1159 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001160 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1161 for haystack in permutations(bases):
1162 m = mro(c.ChainMap, haystack)
1163 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1164 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001165
1166 # If there's a generic function with implementations registered for
1167 # both Sized and Container, passing a defaultdict to it results in an
1168 # ambiguous dispatch which will cause a RuntimeError (see
1169 # test_mro_conflicts).
1170 bases = [c.Container, c.Sized, str]
1171 for haystack in permutations(bases):
1172 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1173 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1174 object])
1175
1176 # MutableSequence below is registered directly on D. In other words, it
1177 # preceeds MutableMapping which means single dispatch will always
1178 # choose MutableSequence here.
1179 class D(c.defaultdict):
1180 pass
1181 c.MutableSequence.register(D)
1182 bases = [c.MutableSequence, c.MutableMapping]
1183 for haystack in permutations(bases):
1184 m = mro(D, bases)
1185 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1186 c.defaultdict, dict, c.MutableMapping,
1187 c.Mapping, c.Sized, c.Iterable, c.Container,
1188 object])
1189
1190 # Container and Callable are registered on different base classes and
1191 # a generic function supporting both should always pick the Callable
1192 # implementation if a C instance is passed.
1193 class C(c.defaultdict):
1194 def __call__(self):
1195 pass
1196 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1197 for haystack in permutations(bases):
1198 m = mro(C, haystack)
1199 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1200 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001201
1202 def test_register_abc(self):
1203 c = collections
1204 d = {"a": "b"}
1205 l = [1, 2, 3]
1206 s = {object(), None}
1207 f = frozenset(s)
1208 t = (1, 2, 3)
1209 @functools.singledispatch
1210 def g(obj):
1211 return "base"
1212 self.assertEqual(g(d), "base")
1213 self.assertEqual(g(l), "base")
1214 self.assertEqual(g(s), "base")
1215 self.assertEqual(g(f), "base")
1216 self.assertEqual(g(t), "base")
1217 g.register(c.Sized, lambda obj: "sized")
1218 self.assertEqual(g(d), "sized")
1219 self.assertEqual(g(l), "sized")
1220 self.assertEqual(g(s), "sized")
1221 self.assertEqual(g(f), "sized")
1222 self.assertEqual(g(t), "sized")
1223 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1224 self.assertEqual(g(d), "mutablemapping")
1225 self.assertEqual(g(l), "sized")
1226 self.assertEqual(g(s), "sized")
1227 self.assertEqual(g(f), "sized")
1228 self.assertEqual(g(t), "sized")
1229 g.register(c.ChainMap, lambda obj: "chainmap")
1230 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1231 self.assertEqual(g(l), "sized")
1232 self.assertEqual(g(s), "sized")
1233 self.assertEqual(g(f), "sized")
1234 self.assertEqual(g(t), "sized")
1235 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1236 self.assertEqual(g(d), "mutablemapping")
1237 self.assertEqual(g(l), "mutablesequence")
1238 self.assertEqual(g(s), "sized")
1239 self.assertEqual(g(f), "sized")
1240 self.assertEqual(g(t), "sized")
1241 g.register(c.MutableSet, lambda obj: "mutableset")
1242 self.assertEqual(g(d), "mutablemapping")
1243 self.assertEqual(g(l), "mutablesequence")
1244 self.assertEqual(g(s), "mutableset")
1245 self.assertEqual(g(f), "sized")
1246 self.assertEqual(g(t), "sized")
1247 g.register(c.Mapping, lambda obj: "mapping")
1248 self.assertEqual(g(d), "mutablemapping") # not specific enough
1249 self.assertEqual(g(l), "mutablesequence")
1250 self.assertEqual(g(s), "mutableset")
1251 self.assertEqual(g(f), "sized")
1252 self.assertEqual(g(t), "sized")
1253 g.register(c.Sequence, lambda obj: "sequence")
1254 self.assertEqual(g(d), "mutablemapping")
1255 self.assertEqual(g(l), "mutablesequence")
1256 self.assertEqual(g(s), "mutableset")
1257 self.assertEqual(g(f), "sized")
1258 self.assertEqual(g(t), "sequence")
1259 g.register(c.Set, lambda obj: "set")
1260 self.assertEqual(g(d), "mutablemapping")
1261 self.assertEqual(g(l), "mutablesequence")
1262 self.assertEqual(g(s), "mutableset")
1263 self.assertEqual(g(f), "set")
1264 self.assertEqual(g(t), "sequence")
1265 g.register(dict, lambda obj: "dict")
1266 self.assertEqual(g(d), "dict")
1267 self.assertEqual(g(l), "mutablesequence")
1268 self.assertEqual(g(s), "mutableset")
1269 self.assertEqual(g(f), "set")
1270 self.assertEqual(g(t), "sequence")
1271 g.register(list, lambda obj: "list")
1272 self.assertEqual(g(d), "dict")
1273 self.assertEqual(g(l), "list")
1274 self.assertEqual(g(s), "mutableset")
1275 self.assertEqual(g(f), "set")
1276 self.assertEqual(g(t), "sequence")
1277 g.register(set, lambda obj: "concrete-set")
1278 self.assertEqual(g(d), "dict")
1279 self.assertEqual(g(l), "list")
1280 self.assertEqual(g(s), "concrete-set")
1281 self.assertEqual(g(f), "set")
1282 self.assertEqual(g(t), "sequence")
1283 g.register(frozenset, lambda obj: "frozen-set")
1284 self.assertEqual(g(d), "dict")
1285 self.assertEqual(g(l), "list")
1286 self.assertEqual(g(s), "concrete-set")
1287 self.assertEqual(g(f), "frozen-set")
1288 self.assertEqual(g(t), "sequence")
1289 g.register(tuple, lambda obj: "tuple")
1290 self.assertEqual(g(d), "dict")
1291 self.assertEqual(g(l), "list")
1292 self.assertEqual(g(s), "concrete-set")
1293 self.assertEqual(g(f), "frozen-set")
1294 self.assertEqual(g(t), "tuple")
1295
Łukasz Langa3720c772013-07-01 16:00:38 +02001296 def test_c3_abc(self):
1297 c = collections
1298 mro = functools._c3_mro
1299 class A(object):
1300 pass
1301 class B(A):
1302 def __len__(self):
1303 return 0 # implies Sized
1304 @c.Container.register
1305 class C(object):
1306 pass
1307 class D(object):
1308 pass # unrelated
1309 class X(D, C, B):
1310 def __call__(self):
1311 pass # implies Callable
1312 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1313 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1314 self.assertEqual(mro(X, abcs=abcs), expected)
1315 # unrelated ABCs don't appear in the resulting MRO
1316 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1317 self.assertEqual(mro(X, abcs=many_abcs), expected)
1318
Łukasz Langa6f692512013-06-05 12:20:24 +02001319 def test_mro_conflicts(self):
1320 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001321 @functools.singledispatch
1322 def g(arg):
1323 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001324 class O(c.Sized):
1325 def __len__(self):
1326 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001327 o = O()
1328 self.assertEqual(g(o), "base")
1329 g.register(c.Iterable, lambda arg: "iterable")
1330 g.register(c.Container, lambda arg: "container")
1331 g.register(c.Sized, lambda arg: "sized")
1332 g.register(c.Set, lambda arg: "set")
1333 self.assertEqual(g(o), "sized")
1334 c.Iterable.register(O)
1335 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1336 c.Container.register(O)
1337 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001338 c.Set.register(O)
1339 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1340 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001341 class P:
1342 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001343 p = P()
1344 self.assertEqual(g(p), "base")
1345 c.Iterable.register(P)
1346 self.assertEqual(g(p), "iterable")
1347 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001348 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001349 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001350 self.assertIn(
1351 str(re_one.exception),
1352 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1353 "or <class 'collections.abc.Iterable'>"),
1354 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1355 "or <class 'collections.abc.Container'>")),
1356 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001357 class Q(c.Sized):
1358 def __len__(self):
1359 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001360 q = Q()
1361 self.assertEqual(g(q), "sized")
1362 c.Iterable.register(Q)
1363 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1364 c.Set.register(Q)
1365 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001366 # c.Sized and c.Iterable
1367 @functools.singledispatch
1368 def h(arg):
1369 return "base"
1370 @h.register(c.Sized)
1371 def _(arg):
1372 return "sized"
1373 @h.register(c.Container)
1374 def _(arg):
1375 return "container"
1376 # Even though Sized and Container are explicit bases of MutableMapping,
1377 # this ABC is implicitly registered on defaultdict which makes all of
1378 # MutableMapping's bases implicit as well from defaultdict's
1379 # perspective.
1380 with self.assertRaises(RuntimeError) as re_two:
1381 h(c.defaultdict(lambda: 0))
1382 self.assertIn(
1383 str(re_two.exception),
1384 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1385 "or <class 'collections.abc.Sized'>"),
1386 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1387 "or <class 'collections.abc.Container'>")),
1388 )
1389 class R(c.defaultdict):
1390 pass
1391 c.MutableSequence.register(R)
1392 @functools.singledispatch
1393 def i(arg):
1394 return "base"
1395 @i.register(c.MutableMapping)
1396 def _(arg):
1397 return "mapping"
1398 @i.register(c.MutableSequence)
1399 def _(arg):
1400 return "sequence"
1401 r = R()
1402 self.assertEqual(i(r), "sequence")
1403 class S:
1404 pass
1405 class T(S, c.Sized):
1406 def __len__(self):
1407 return 0
1408 t = T()
1409 self.assertEqual(h(t), "sized")
1410 c.Container.register(T)
1411 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1412 class U:
1413 def __len__(self):
1414 return 0
1415 u = U()
1416 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1417 # from the existence of __len__()
1418 c.Container.register(U)
1419 # There is no preference for registered versus inferred ABCs.
1420 with self.assertRaises(RuntimeError) as re_three:
1421 h(u)
1422 self.assertIn(
1423 str(re_three.exception),
1424 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1425 "or <class 'collections.abc.Sized'>"),
1426 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1427 "or <class 'collections.abc.Container'>")),
1428 )
1429 class V(c.Sized, S):
1430 def __len__(self):
1431 return 0
1432 @functools.singledispatch
1433 def j(arg):
1434 return "base"
1435 @j.register(S)
1436 def _(arg):
1437 return "s"
1438 @j.register(c.Container)
1439 def _(arg):
1440 return "container"
1441 v = V()
1442 self.assertEqual(j(v), "s")
1443 c.Container.register(V)
1444 self.assertEqual(j(v), "container") # because it ends up right after
1445 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001446
1447 def test_cache_invalidation(self):
1448 from collections import UserDict
1449 class TracingDict(UserDict):
1450 def __init__(self, *args, **kwargs):
1451 super(TracingDict, self).__init__(*args, **kwargs)
1452 self.set_ops = []
1453 self.get_ops = []
1454 def __getitem__(self, key):
1455 result = self.data[key]
1456 self.get_ops.append(key)
1457 return result
1458 def __setitem__(self, key, value):
1459 self.set_ops.append(key)
1460 self.data[key] = value
1461 def clear(self):
1462 self.data.clear()
1463 _orig_wkd = functools.WeakKeyDictionary
1464 td = TracingDict()
1465 functools.WeakKeyDictionary = lambda: td
1466 c = collections
1467 @functools.singledispatch
1468 def g(arg):
1469 return "base"
1470 d = {}
1471 l = []
1472 self.assertEqual(len(td), 0)
1473 self.assertEqual(g(d), "base")
1474 self.assertEqual(len(td), 1)
1475 self.assertEqual(td.get_ops, [])
1476 self.assertEqual(td.set_ops, [dict])
1477 self.assertEqual(td.data[dict], g.registry[object])
1478 self.assertEqual(g(l), "base")
1479 self.assertEqual(len(td), 2)
1480 self.assertEqual(td.get_ops, [])
1481 self.assertEqual(td.set_ops, [dict, list])
1482 self.assertEqual(td.data[dict], g.registry[object])
1483 self.assertEqual(td.data[list], g.registry[object])
1484 self.assertEqual(td.data[dict], td.data[list])
1485 self.assertEqual(g(l), "base")
1486 self.assertEqual(g(d), "base")
1487 self.assertEqual(td.get_ops, [list, dict])
1488 self.assertEqual(td.set_ops, [dict, list])
1489 g.register(list, lambda arg: "list")
1490 self.assertEqual(td.get_ops, [list, dict])
1491 self.assertEqual(len(td), 0)
1492 self.assertEqual(g(d), "base")
1493 self.assertEqual(len(td), 1)
1494 self.assertEqual(td.get_ops, [list, dict])
1495 self.assertEqual(td.set_ops, [dict, list, dict])
1496 self.assertEqual(td.data[dict],
1497 functools._find_impl(dict, g.registry))
1498 self.assertEqual(g(l), "list")
1499 self.assertEqual(len(td), 2)
1500 self.assertEqual(td.get_ops, [list, dict])
1501 self.assertEqual(td.set_ops, [dict, list, dict, list])
1502 self.assertEqual(td.data[list],
1503 functools._find_impl(list, g.registry))
1504 class X:
1505 pass
1506 c.MutableMapping.register(X) # Will not invalidate the cache,
1507 # not using ABCs yet.
1508 self.assertEqual(g(d), "base")
1509 self.assertEqual(g(l), "list")
1510 self.assertEqual(td.get_ops, [list, dict, dict, list])
1511 self.assertEqual(td.set_ops, [dict, list, dict, list])
1512 g.register(c.Sized, lambda arg: "sized")
1513 self.assertEqual(len(td), 0)
1514 self.assertEqual(g(d), "sized")
1515 self.assertEqual(len(td), 1)
1516 self.assertEqual(td.get_ops, [list, dict, dict, list])
1517 self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
1518 self.assertEqual(g(l), "list")
1519 self.assertEqual(len(td), 2)
1520 self.assertEqual(td.get_ops, [list, dict, dict, list])
1521 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1522 self.assertEqual(g(l), "list")
1523 self.assertEqual(g(d), "sized")
1524 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
1525 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1526 g.dispatch(list)
1527 g.dispatch(dict)
1528 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
1529 list, dict])
1530 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
1531 c.MutableSet.register(X) # Will invalidate the cache.
1532 self.assertEqual(len(td), 2) # Stale cache.
1533 self.assertEqual(g(l), "list")
1534 self.assertEqual(len(td), 1)
1535 g.register(c.MutableMapping, lambda arg: "mutablemapping")
1536 self.assertEqual(len(td), 0)
1537 self.assertEqual(g(d), "mutablemapping")
1538 self.assertEqual(len(td), 1)
1539 self.assertEqual(g(l), "list")
1540 self.assertEqual(len(td), 2)
1541 g.register(dict, lambda arg: "dict")
1542 self.assertEqual(g(d), "dict")
1543 self.assertEqual(g(l), "list")
1544 g._clear_cache()
1545 self.assertEqual(len(td), 0)
1546 functools.WeakKeyDictionary = _orig_wkd
1547
1548
def test_main(verbose=None):
    """Run every functools test class via support.run_unittest.

    When *verbose* is true and the interpreter provides
    sys.gettotalrefcount (debug builds), the suite is run five more times.
    The total reference count after each pass is printed as a list, so that
    growth across passes can be spotted.
    """
    suite = (
        TestPartialC,
        TestPartialPy,
        TestPartialCSubclass,
        TestPartialMethod,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKeyC,
        TestCmpToKeyPy,
        TestWraps,
        TestReduce,
        TestLRU,
        TestSingleDispatch,
    )
    support.run_unittest(*suite)

    if not verbose or not hasattr(sys, "gettotalrefcount"):
        return

    # Reference-count check: rerun the suite, collect garbage, and record
    # the interpreter-wide refcount after each pass.
    import gc
    refcounts = []
    for _ in range(5):
        support.run_unittest(*suite)
        gc.collect()
        refcounts.append(sys.gettotalrefcount())
    print(refcounts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +00001575
if __name__ == '__main__':
    # Run as a script: verbose mode, which in test_main also enables the
    # repeated-run reference-count report on debug builds.
    test_main(verbose=True)