blob: 03dd5457368aa0cdf7967a65f0aea3168e46b0e7 [file] [log] [blame]
Nick Coghlanf4cb48a2013-11-03 16:41:46 +10001import abc
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
Łukasz Langa6f692512013-06-05 12:20:24 +02003from itertools import permutations
Jack Diederiche0cbd692009-04-01 04:27:09 +00004import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00005from random import choice
Łukasz Langa6f692512013-06-05 12:20:24 +02006import sys
7from test import support
8import unittest
9from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +000010
Antoine Pitroub5b37142012-11-13 21:35:40 +010011import functools
12
# The same functools module imported under controlled conditions:
# py_functools with the _functools accelerator blocked (pure Python), and
# c_functools re-imported together with a fresh _functools module.
py_functools = support.import_fresh_module(
    'functools',
    blocked=['_functools'],
)
c_functools = support.import_fresh_module(
    'functools',
    fresh=['_functools'],
)

# decimal re-imported together with a fresh _decimal module.
decimal = support.import_fresh_module(
    'decimal',
    fresh=['_decimal'],
)
17
18
def capture(*args, **kw):
    """Echo a call back to the caller.

    Returns the 2-tuple ``(args, kw)``: the positional arguments as a tuple
    and the keyword arguments as a dict, exactly as received.
    """
    received = (args, kw)
    return received
22
Łukasz Langa6f692512013-06-05 12:20:24 +020023
def signature(part):
    """Summarize a partial-like object as a comparable 4-tuple.

    The tuple is ``(func, args, keywords, __dict__)``, read directly off
    *part*, so two partials with equal state produce equal tuples.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
Łukasz Langa6f692512013-06-05 12:20:24 +020028
class TestPartial:
    """Tests shared by every ``partial`` implementation.

    This is a mixin, not a TestCase: concrete subclasses also inherit from
    ``unittest.TestCase`` and set ``partial`` to the callable under test.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)  # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or at the latest when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure shows which half differs.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        # Only "true" partial is optimized
        if partial.__name__ != 'partial':
            # Report the test as skipped instead of silently passing it.
            self.skipTest('nested partial flattening only applies to partial')
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))
145
Łukasz Langa6f692512013-06-05 12:20:24 +0200146
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests plus C-specific ones, run on c_functools.partial."""

    if c_functools:
        partial = c_functools.partial

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        for attr, value in [('func', map),
                            ('args', (1, 2)),
                            ('keywords', dict(a=1, b=2))]:
            self.assertRaises(AttributeError, setattr, p, attr, value)

        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = ('functools.partial' if self.partial is c_functools.partial
                else self.partial.__name__)

        f = self.partial(capture)
        self.assertEqual('{}({!r})'.format(name, capture), repr(f))

        f = self.partial(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.partial(capture, **kwargs)
        allowed = ['{}({!r}, {})'.format(name, capture, kw_repr)
                   for kw_repr in kwargs_reprs]
        self.assertIn(repr(f), allowed)

        f = self.partial(capture, *args, **kwargs)
        allowed = ['{}({!r}, {}, {})'.format(name, capture, args_repr, kw_repr)
                   for kw_repr in kwargs_reprs]
        self.assertIn(repr(f), allowed)

    def test_pickle(self):
        original = self.partial(signature, 'asdf', bar=True)
        original.add_something_to__dict__ = True
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(original, protocol))
            self.assertEqual(signature(original), signature(restored))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            # Claims four items: a callable, a huge tuple, then two fresh
            # dicts -- but it is not a real tuple.
            def __len__(self):
                return 4

            def __getitem__(self, key):
                if key == 0:
                    return max
                if key == 1:
                    return tuple(range(1000000))
                if key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000221
Łukasz Langa6f692512013-06-05 12:20:24 +0200222
class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the shared partial tests against the pure-Python functools."""

    # Wrapped in staticmethod so that, should py_functools.partial be a
    # plain Python function, it is not bound to the test instance.
    partial = staticmethod(py_functools.partial)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000225
Łukasz Langa6f692512013-06-05 12:20:24 +0200226
if c_functools:
    # A do-nothing subclass of the C partial type; TestPartialCSubclass
    # re-runs the C partial tests against it.
    class PartialSubclass(c_functools.partial):
        pass
Antoine Pitrou33543272012-11-13 21:36:21 +0100230
Łukasz Langa6f692512013-06-05 12:20:24 +0200231
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run every TestPartialC test against a trivial subclass of C partial."""

    if c_functools:
        partial = PartialSubclass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000236
Łukasz Langa6f692512013-06-05 12:20:24 +0200237
class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # ABCMeta has to be used as the *metaclass*.  Subclassing it
        # (``class Abstract(abc.ABCMeta)``) would merely define another
        # metaclass, so no abstract-method bookkeeping would ever happen.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # The partialmethod must be recognised as abstract by ABCMeta too.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial,
                     self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
350
351
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata of *wrapped*.

        Each attribute named in *assigned* must be the very same object on
        both; each key of every mapping named in *updated* must map to the
        same object on the wrapper; ``wrapper.__wrapped__`` must be
        *wrapped*.
        """
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        for attr in updated:
            merged = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            # __wrapped__ is overwritten by the update code, so the wrapped
            # object's own __wrapped__ entry is not expected to survive.
            keys = (key for key in source
                    if not (attr == "__dict__" and key == "__wrapped__"))
            for key in keys:
                self.assertIs(source[key], merged[key])
        self.assertIs(wrapper.__wrapped__, wrapped)

    def _default_update(self):
        """Apply update_wrapper with its defaults; return (wrapper, f)."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000464
Łukasz Langa6f692512013-06-05 12:20:24 +0200465
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a fresh function with wraps(f); return (wrapper, f)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # give the wrapper an (empty) mapping for wraps() to update
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
526
Łukasz Langa6f692512013-06-05 12:20:24 +0200527
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""

    func = functools.reduce

    def test_reduce(self):
        from operator import add

        class Squares:
            # The squares 0, 1, 4, ... for indices below ``limit``, computed
            # on demand; __len__ reports how many have been computed so far.
            def __init__(self, limit):
                self.max = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        reduce = self.func
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        self.assertEqual(reduce(42, "1"), "1")  # func is never called with one item
        self.assertEqual(reduce(42, "", "1"), "1")  # func is never called with one item
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be an empty sequence when there is no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence: only __getitem__, yielding 0 .. n-1.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
607
Łukasz Langa6f692512013-06-05 12:20:24 +0200608
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200609class TestCmpToKey:
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700610
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000611 def test_cmp_to_key(self):
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700612 def cmp1(x, y):
613 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100614 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700615 self.assertEqual(key(3), key(3))
616 self.assertGreater(key(3), key(1))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100617 self.assertGreaterEqual(key(3), key(3))
618
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700619 def cmp2(x, y):
620 return int(x) - int(y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100621 key = self.cmp_to_key(cmp2)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700622 self.assertEqual(key(4.0), key('4'))
623 self.assertLess(key(2), key('35'))
Antoine Pitroub5b37142012-11-13 21:35:40 +0100624 self.assertLessEqual(key(2), key('35'))
625 self.assertNotEqual(key(2), key('35'))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700626
627 def test_cmp_to_key_arguments(self):
628 def cmp1(x, y):
629 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100630 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700631 self.assertEqual(key(obj=3), key(obj=3))
632 self.assertGreater(key(obj=3), key(obj=1))
633 with self.assertRaises((TypeError, AttributeError)):
634 key(3) > 1 # rhs is not a K object
635 with self.assertRaises((TypeError, AttributeError)):
636 1 < key(3) # lhs is not a K object
637 with self.assertRaises(TypeError):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100638 key = self.cmp_to_key() # too few args
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700639 with self.assertRaises(TypeError):
Serhiy Storchakaca4220b2013-02-05 22:12:59 +0200640 key = self.cmp_to_key(cmp1, None) # too many args
Antoine Pitroub5b37142012-11-13 21:35:40 +0100641 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700642 with self.assertRaises(TypeError):
643 key() # too few args
644 with self.assertRaises(TypeError):
645 key(None, None) # too many args
646
647 def test_bad_cmp(self):
648 def cmp1(x, y):
649 raise ZeroDivisionError
Antoine Pitroub5b37142012-11-13 21:35:40 +0100650 key = self.cmp_to_key(cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700651 with self.assertRaises(ZeroDivisionError):
652 key(3) > key(1)
653
654 class BadCmp:
655 def __lt__(self, other):
656 raise ZeroDivisionError
657 def cmp1(x, y):
658 return BadCmp()
659 with self.assertRaises(ZeroDivisionError):
660 key(3) > key(1)
661
662 def test_obj_field(self):
663 def cmp1(x, y):
664 return (x > y) - (x < y)
Antoine Pitroub5b37142012-11-13 21:35:40 +0100665 key = self.cmp_to_key(mycmp=cmp1)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700666 self.assertEqual(key(50).obj, 50)
667
668 def test_sort_int(self):
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000669 def mycmp(x, y):
670 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100671 self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000672 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000673
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700674 def test_sort_int_str(self):
675 def mycmp(x, y):
676 x, y = int(x), int(y)
677 return (x > y) - (x < y)
678 values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
Antoine Pitroub5b37142012-11-13 21:35:40 +0100679 values = sorted(values, key=self.cmp_to_key(mycmp))
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700680 self.assertEqual([int(value) for value in values],
681 [0, 1, 1, 2, 3, 4, 5, 7, 10])
682
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000683 def test_hash(self):
684 def mycmp(x, y):
685 return y - x
Antoine Pitroub5b37142012-11-13 21:35:40 +0100686 key = self.cmp_to_key(mycmp)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000687 k = key(10)
Raymond Hettinger7ab9e222011-04-05 02:33:54 -0700688 self.assertRaises(TypeError, hash, k)
Raymond Hettingere7a24302011-05-03 11:16:36 -0700689 self.assertNotIsInstance(k, collections.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000690
Łukasz Langa6f692512013-06-05 12:20:24 +0200691
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the C accelerator."""

    # The class body runs before the skip decorator is applied, so the
    # attribute is only bound when the C module could be imported.  The C
    # function is stored directly, without the staticmethod wrapper used for
    # the pure-Python variant.
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
Antoine Pitroub5b37142012-11-13 21:35:40 +0100696
Łukasz Langa6f692512013-06-05 12:20:24 +0200697
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the shared cmp_to_key tests against the pure-Python version."""

    # staticmethod keeps the plain Python function from being bound to the
    # test case (which would otherwise be passed in as ``mycmp``).
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
700
Łukasz Langa6f692512013-06-05 12:20:24 +0200701
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000702class TestTotalOrdering(unittest.TestCase):
703
704 def test_total_ordering_lt(self):
705 @functools.total_ordering
706 class A:
707 def __init__(self, value):
708 self.value = value
709 def __lt__(self, other):
710 return self.value < other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000711 def __eq__(self, other):
712 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000713 self.assertTrue(A(1) < A(2))
714 self.assertTrue(A(2) > A(1))
715 self.assertTrue(A(1) <= A(2))
716 self.assertTrue(A(2) >= A(1))
717 self.assertTrue(A(2) <= A(2))
718 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000719 self.assertFalse(A(1) > A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000720
721 def test_total_ordering_le(self):
722 @functools.total_ordering
723 class A:
724 def __init__(self, value):
725 self.value = value
726 def __le__(self, other):
727 return self.value <= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000728 def __eq__(self, other):
729 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000730 self.assertTrue(A(1) < A(2))
731 self.assertTrue(A(2) > A(1))
732 self.assertTrue(A(1) <= A(2))
733 self.assertTrue(A(2) >= A(1))
734 self.assertTrue(A(2) <= A(2))
735 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000736 self.assertFalse(A(1) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000737
738 def test_total_ordering_gt(self):
739 @functools.total_ordering
740 class A:
741 def __init__(self, value):
742 self.value = value
743 def __gt__(self, other):
744 return self.value > other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000745 def __eq__(self, other):
746 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000747 self.assertTrue(A(1) < A(2))
748 self.assertTrue(A(2) > A(1))
749 self.assertTrue(A(1) <= A(2))
750 self.assertTrue(A(2) >= A(1))
751 self.assertTrue(A(2) <= A(2))
752 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000753 self.assertFalse(A(2) < A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000754
755 def test_total_ordering_ge(self):
756 @functools.total_ordering
757 class A:
758 def __init__(self, value):
759 self.value = value
760 def __ge__(self, other):
761 return self.value >= other.value
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000762 def __eq__(self, other):
763 return self.value == other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000764 self.assertTrue(A(1) < A(2))
765 self.assertTrue(A(2) > A(1))
766 self.assertTrue(A(1) <= A(2))
767 self.assertTrue(A(2) >= A(1))
768 self.assertTrue(A(2) <= A(2))
769 self.assertTrue(A(2) >= A(2))
Nick Coghlanf05d9812013-10-02 00:02:03 +1000770 self.assertFalse(A(2) <= A(1))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000771
772 def test_total_ordering_no_overwrite(self):
773 # new methods should not overwrite existing
774 @functools.total_ordering
775 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000776 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000777 self.assertTrue(A(1) < A(2))
778 self.assertTrue(A(2) > A(1))
779 self.assertTrue(A(1) <= A(2))
780 self.assertTrue(A(2) >= A(1))
781 self.assertTrue(A(2) <= A(2))
782 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000783
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000784 def test_no_operations_defined(self):
785 with self.assertRaises(ValueError):
786 @functools.total_ordering
787 class A:
788 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000789
Nick Coghlanf05d9812013-10-02 00:02:03 +1000790 def test_type_error_when_not_implemented(self):
791 # bug 10042; ensure stack overflow does not occur
792 # when decorated types return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000793 @functools.total_ordering
Nick Coghlanf05d9812013-10-02 00:02:03 +1000794 class ImplementsLessThan:
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000795 def __init__(self, value):
796 self.value = value
797 def __eq__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000798 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000799 return self.value == other.value
800 return False
801 def __lt__(self, other):
Nick Coghlanf05d9812013-10-02 00:02:03 +1000802 if isinstance(other, ImplementsLessThan):
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000803 return self.value < other.value
Nick Coghlanf05d9812013-10-02 00:02:03 +1000804 return NotImplemented
Raymond Hettinger23f9fc32011-01-08 07:01:56 +0000805
Nick Coghlanf05d9812013-10-02 00:02:03 +1000806 @functools.total_ordering
807 class ImplementsGreaterThan:
808 def __init__(self, value):
809 self.value = value
810 def __eq__(self, other):
811 if isinstance(other, ImplementsGreaterThan):
812 return self.value == other.value
813 return False
814 def __gt__(self, other):
815 if isinstance(other, ImplementsGreaterThan):
816 return self.value > other.value
817 return NotImplemented
818
819 @functools.total_ordering
820 class ImplementsLessThanEqualTo:
821 def __init__(self, value):
822 self.value = value
823 def __eq__(self, other):
824 if isinstance(other, ImplementsLessThanEqualTo):
825 return self.value == other.value
826 return False
827 def __le__(self, other):
828 if isinstance(other, ImplementsLessThanEqualTo):
829 return self.value <= other.value
830 return NotImplemented
831
832 @functools.total_ordering
833 class ImplementsGreaterThanEqualTo:
834 def __init__(self, value):
835 self.value = value
836 def __eq__(self, other):
837 if isinstance(other, ImplementsGreaterThanEqualTo):
838 return self.value == other.value
839 return False
840 def __ge__(self, other):
841 if isinstance(other, ImplementsGreaterThanEqualTo):
842 return self.value >= other.value
843 return NotImplemented
844
845 @functools.total_ordering
846 class ComparatorNotImplemented:
847 def __init__(self, value):
848 self.value = value
849 def __eq__(self, other):
850 if isinstance(other, ComparatorNotImplemented):
851 return self.value == other.value
852 return False
853 def __lt__(self, other):
854 return NotImplemented
855
856 with self.subTest("LT < 1"), self.assertRaises(TypeError):
857 ImplementsLessThan(-1) < 1
858
859 with self.subTest("LT < LE"), self.assertRaises(TypeError):
860 ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
861
862 with self.subTest("LT < GT"), self.assertRaises(TypeError):
863 ImplementsLessThan(1) < ImplementsGreaterThan(1)
864
865 with self.subTest("LE <= LT"), self.assertRaises(TypeError):
866 ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
867
868 with self.subTest("LE <= GE"), self.assertRaises(TypeError):
869 ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
870
871 with self.subTest("GT > GE"), self.assertRaises(TypeError):
872 ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
873
874 with self.subTest("GT > LT"), self.assertRaises(TypeError):
875 ImplementsGreaterThan(5) > ImplementsLessThan(5)
876
877 with self.subTest("GE >= GT"), self.assertRaises(TypeError):
878 ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
879
880 with self.subTest("GE >= LE"), self.assertRaises(TypeError):
881 ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
882
883 with self.subTest("GE when equal"):
884 a = ComparatorNotImplemented(8)
885 b = ComparatorNotImplemented(8)
886 self.assertEqual(a, b)
887 with self.assertRaises(TypeError):
888 a >= b
889
890 with self.subTest("LE when equal"):
891 a = ComparatorNotImplemented(9)
892 b = ComparatorNotImplemented(9)
893 self.assertEqual(a, b)
894 with self.assertRaises(TypeError):
895 a <= b
Łukasz Langa6f692512013-06-05 12:20:24 +0200896
Serhiy Storchaka697a5262015-01-01 15:23:12 +0200897 def test_pickle(self):
898 for proto in range(4, pickle.HIGHEST_PROTOCOL + 1):
899 for name in '__lt__', '__gt__', '__le__', '__ge__':
900 with self.subTest(method=name, proto=proto):
901 method = getattr(Orderable_LT, name)
902 method_copy = pickle.loads(pickle.dumps(method, proto))
903 self.assertIs(method_copy, method)
904
@functools.total_ordering
class Orderable_LT:
    """Minimal total_ordering class defined at module level.

    TestTotalOrdering.test_pickle pickles this class's comparison methods,
    which presumably requires them to be reachable by name from this module.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        return self.value == other.value

    def __lt__(self, other):
        # total_ordering fills in the remaining rich comparisons from this.
        return self.value < other.value
913
914
Georg Brandl2e7346a2010-07-31 18:09:23 +0000915class TestLRU(unittest.TestCase):
916
917 def test_lru(self):
918 def orig(x, y):
Antoine Pitroub5b37142012-11-13 21:35:40 +0100919 return 3 * x + y
Georg Brandl2e7346a2010-07-31 18:09:23 +0000920 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000921 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000922 self.assertEqual(maxsize, 20)
923 self.assertEqual(currsize, 0)
924 self.assertEqual(hits, 0)
925 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000926
927 domain = range(5)
928 for i in range(1000):
929 x, y = choice(domain), choice(domain)
930 actual = f(x, y)
931 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000932 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000933 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000934 self.assertTrue(hits > misses)
935 self.assertEqual(hits + misses, 1000)
936 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000937
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000938 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000939 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000940 self.assertEqual(hits, 0)
941 self.assertEqual(misses, 0)
942 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000943 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000944 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000945 self.assertEqual(hits, 0)
946 self.assertEqual(misses, 1)
947 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000948
Nick Coghlan98876832010-08-17 06:17:18 +0000949 # Test bypassing the cache
950 self.assertIs(f.__wrapped__, orig)
951 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000952 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000953 self.assertEqual(hits, 0)
954 self.assertEqual(misses, 1)
955 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000956
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000957 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000958 @functools.lru_cache(0)
959 def f():
960 nonlocal f_cnt
961 f_cnt += 1
962 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000963 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000964 f_cnt = 0
965 for i in range(5):
966 self.assertEqual(f(), 20)
967 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000968 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000969 self.assertEqual(hits, 0)
970 self.assertEqual(misses, 5)
971 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000972
973 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000974 @functools.lru_cache(1)
975 def f():
976 nonlocal f_cnt
977 f_cnt += 1
978 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000979 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000980 f_cnt = 0
981 for i in range(5):
982 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000983 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000984 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000985 self.assertEqual(hits, 4)
986 self.assertEqual(misses, 1)
987 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000988
Raymond Hettingerf3098282010-08-15 03:30:45 +0000989 # test size two
990 @functools.lru_cache(2)
991 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000992 nonlocal f_cnt
993 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000994 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000995 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000996 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000997 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
998 # * * * *
999 self.assertEqual(f(x), x*10)
1000 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +00001001 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +00001002 self.assertEqual(hits, 12)
1003 self.assertEqual(misses, 4)
1004 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +00001005
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +00001006 def test_lru_with_maxsize_none(self):
1007 @functools.lru_cache(maxsize=None)
1008 def fib(n):
1009 if n < 2:
1010 return n
1011 return fib(n-1) + fib(n-2)
1012 self.assertEqual([fib(n) for n in range(16)],
1013 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1014 self.assertEqual(fib.cache_info(),
1015 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1016 fib.cache_clear()
1017 self.assertEqual(fib.cache_info(),
1018 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1019
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001020 def test_lru_with_exceptions(self):
1021 # Verify that user_function exceptions get passed through without
1022 # creating a hard-to-read chained exception.
1023 # http://bugs.python.org/issue13177
Antoine Pitroub5b37142012-11-13 21:35:40 +01001024 for maxsize in (None, 128):
Raymond Hettinger4b779b32011-10-15 23:50:42 -07001025 @functools.lru_cache(maxsize)
1026 def func(i):
1027 return 'abc'[i]
1028 self.assertEqual(func(0), 'a')
1029 with self.assertRaises(IndexError) as cm:
1030 func(15)
1031 self.assertIsNone(cm.exception.__context__)
1032 # Verify that the previous exception did not result in a cached entry
1033 with self.assertRaises(IndexError):
1034 func(15)
1035
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001036 def test_lru_with_types(self):
Antoine Pitroub5b37142012-11-13 21:35:40 +01001037 for maxsize in (None, 128):
Raymond Hettingercd9fdfd2011-10-20 08:57:45 -07001038 @functools.lru_cache(maxsize=maxsize, typed=True)
1039 def square(x):
1040 return x * x
1041 self.assertEqual(square(3), 9)
1042 self.assertEqual(type(square(3)), type(9))
1043 self.assertEqual(square(3.0), 9.0)
1044 self.assertEqual(type(square(3.0)), type(9.0))
1045 self.assertEqual(square(x=3), 9)
1046 self.assertEqual(type(square(x=3)), type(9))
1047 self.assertEqual(square(x=3.0), 9.0)
1048 self.assertEqual(type(square(x=3.0)), type(9.0))
1049 self.assertEqual(square.cache_info().hits, 4)
1050 self.assertEqual(square.cache_info().misses, 4)
1051
Antoine Pitroub5b37142012-11-13 21:35:40 +01001052 def test_lru_with_keyword_args(self):
1053 @functools.lru_cache()
1054 def fib(n):
1055 if n < 2:
1056 return n
1057 return fib(n=n-1) + fib(n=n-2)
1058 self.assertEqual(
1059 [fib(n=number) for number in range(16)],
1060 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
1061 )
1062 self.assertEqual(fib.cache_info(),
1063 functools._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
1064 fib.cache_clear()
1065 self.assertEqual(fib.cache_info(),
1066 functools._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
1067
1068 def test_lru_with_keyword_args_maxsize_none(self):
1069 @functools.lru_cache(maxsize=None)
1070 def fib(n):
1071 if n < 2:
1072 return n
1073 return fib(n=n-1) + fib(n=n-2)
1074 self.assertEqual([fib(n=number) for number in range(16)],
1075 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
1076 self.assertEqual(fib.cache_info(),
1077 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
1078 fib.cache_clear()
1079 self.assertEqual(fib.cache_info(),
1080 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
1081
Raymond Hettinger03923422013-03-04 02:52:50 -05001082 def test_need_for_rlock(self):
1083 # This will deadlock on an LRU cache that uses a regular lock
1084
1085 @functools.lru_cache(maxsize=10)
1086 def test_func(x):
1087 'Used to demonstrate a reentrant lru_cache call within a single thread'
1088 return x
1089
1090 class DoubleEq:
1091 'Demonstrate a reentrant lru_cache call within a single thread'
1092 def __init__(self, x):
1093 self.x = x
1094 def __hash__(self):
1095 return self.x
1096 def __eq__(self, other):
1097 if self.x == 2:
1098 test_func(DoubleEq(1))
1099 return self.x == other.x
1100
1101 test_func(DoubleEq(1)) # Load the cache
1102 test_func(DoubleEq(2)) # Load the cache
1103 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call
1104 DoubleEq(2)) # Verify the correct return value
1105
Raymond Hettinger4d588972014-08-12 12:44:52 -07001106 def test_early_detection_of_bad_call(self):
1107 # Issue #22184
1108 with self.assertRaises(TypeError):
1109 @functools.lru_cache
1110 def f():
1111 pass
1112
Raymond Hettinger03923422013-03-04 02:52:50 -05001113
Łukasz Langa6f692512013-06-05 12:20:24 +02001114class TestSingleDispatch(unittest.TestCase):
1115 def test_simple_overloads(self):
1116 @functools.singledispatch
1117 def g(obj):
1118 return "base"
1119 def g_int(i):
1120 return "integer"
1121 g.register(int, g_int)
1122 self.assertEqual(g("str"), "base")
1123 self.assertEqual(g(1), "integer")
1124 self.assertEqual(g([1,2,3]), "base")
1125
1126 def test_mro(self):
1127 @functools.singledispatch
1128 def g(obj):
1129 return "base"
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001130 class A:
Łukasz Langa6f692512013-06-05 12:20:24 +02001131 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001132 class C(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001133 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001134 class B(A):
Łukasz Langa6f692512013-06-05 12:20:24 +02001135 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001136 class D(C, B):
Łukasz Langa6f692512013-06-05 12:20:24 +02001137 pass
Łukasz Langa7f7a67a2013-06-07 22:25:27 +02001138 def g_A(a):
1139 return "A"
1140 def g_B(b):
1141 return "B"
1142 g.register(A, g_A)
1143 g.register(B, g_B)
1144 self.assertEqual(g(A()), "A")
1145 self.assertEqual(g(B()), "B")
1146 self.assertEqual(g(C()), "A")
1147 self.assertEqual(g(D()), "B")
Łukasz Langa6f692512013-06-05 12:20:24 +02001148
1149 def test_register_decorator(self):
1150 @functools.singledispatch
1151 def g(obj):
1152 return "base"
1153 @g.register(int)
1154 def g_int(i):
1155 return "int %s" % (i,)
1156 self.assertEqual(g(""), "base")
1157 self.assertEqual(g(12), "int 12")
1158 self.assertIs(g.dispatch(int), g_int)
1159 self.assertIs(g.dispatch(object), g.dispatch(str))
1160 # Note: in the assert above this is not g.
1161 # @singledispatch returns the wrapper.
1162
1163 def test_wrapping_attributes(self):
1164 @functools.singledispatch
1165 def g(obj):
1166 "Simple test"
1167 return "Test"
1168 self.assertEqual(g.__name__, "g")
Serhiy Storchakab12cb6a2013-12-08 18:16:18 +02001169 if sys.flags.optimize < 2:
1170 self.assertEqual(g.__doc__, "Simple test")
Łukasz Langa6f692512013-06-05 12:20:24 +02001171
1172 @unittest.skipUnless(decimal, 'requires _decimal')
1173 @support.cpython_only
1174 def test_c_classes(self):
1175 @functools.singledispatch
1176 def g(obj):
1177 return "base"
1178 @g.register(decimal.DecimalException)
1179 def _(obj):
1180 return obj.args
1181 subn = decimal.Subnormal("Exponent < Emin")
1182 rnd = decimal.Rounded("Number got rounded")
1183 self.assertEqual(g(subn), ("Exponent < Emin",))
1184 self.assertEqual(g(rnd), ("Number got rounded",))
1185 @g.register(decimal.Subnormal)
1186 def _(obj):
1187 return "Too small to care."
1188 self.assertEqual(g(subn), "Too small to care.")
1189 self.assertEqual(g(rnd), ("Number got rounded",))
1190
1191 def test_compose_mro(self):
Łukasz Langa3720c772013-07-01 16:00:38 +02001192 # None of the examples in this test depend on haystack ordering.
Łukasz Langa6f692512013-06-05 12:20:24 +02001193 c = collections
1194 mro = functools._compose_mro
1195 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
1196 for haystack in permutations(bases):
1197 m = mro(dict, haystack)
Łukasz Langa3720c772013-07-01 16:00:38 +02001198 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, c.Sized,
1199 c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001200 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
1201 for haystack in permutations(bases):
1202 m = mro(c.ChainMap, haystack)
1203 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
1204 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa3720c772013-07-01 16:00:38 +02001205
1206 # If there's a generic function with implementations registered for
1207 # both Sized and Container, passing a defaultdict to it results in an
1208 # ambiguous dispatch which will cause a RuntimeError (see
1209 # test_mro_conflicts).
1210 bases = [c.Container, c.Sized, str]
1211 for haystack in permutations(bases):
1212 m = mro(c.defaultdict, [c.Sized, c.Container, str])
1213 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
1214 object])
1215
1216 # MutableSequence below is registered directly on D. In other words, it
1217 # preceeds MutableMapping which means single dispatch will always
1218 # choose MutableSequence here.
1219 class D(c.defaultdict):
1220 pass
1221 c.MutableSequence.register(D)
1222 bases = [c.MutableSequence, c.MutableMapping]
1223 for haystack in permutations(bases):
1224 m = mro(D, bases)
1225 self.assertEqual(m, [D, c.MutableSequence, c.Sequence,
1226 c.defaultdict, dict, c.MutableMapping,
1227 c.Mapping, c.Sized, c.Iterable, c.Container,
1228 object])
1229
1230 # Container and Callable are registered on different base classes and
1231 # a generic function supporting both should always pick the Callable
1232 # implementation if a C instance is passed.
1233 class C(c.defaultdict):
1234 def __call__(self):
1235 pass
1236 bases = [c.Sized, c.Callable, c.Container, c.Mapping]
1237 for haystack in permutations(bases):
1238 m = mro(C, haystack)
1239 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
1240 c.Sized, c.Iterable, c.Container, object])
Łukasz Langa6f692512013-06-05 12:20:24 +02001241
1242 def test_register_abc(self):
1243 c = collections
1244 d = {"a": "b"}
1245 l = [1, 2, 3]
1246 s = {object(), None}
1247 f = frozenset(s)
1248 t = (1, 2, 3)
1249 @functools.singledispatch
1250 def g(obj):
1251 return "base"
1252 self.assertEqual(g(d), "base")
1253 self.assertEqual(g(l), "base")
1254 self.assertEqual(g(s), "base")
1255 self.assertEqual(g(f), "base")
1256 self.assertEqual(g(t), "base")
1257 g.register(c.Sized, lambda obj: "sized")
1258 self.assertEqual(g(d), "sized")
1259 self.assertEqual(g(l), "sized")
1260 self.assertEqual(g(s), "sized")
1261 self.assertEqual(g(f), "sized")
1262 self.assertEqual(g(t), "sized")
1263 g.register(c.MutableMapping, lambda obj: "mutablemapping")
1264 self.assertEqual(g(d), "mutablemapping")
1265 self.assertEqual(g(l), "sized")
1266 self.assertEqual(g(s), "sized")
1267 self.assertEqual(g(f), "sized")
1268 self.assertEqual(g(t), "sized")
1269 g.register(c.ChainMap, lambda obj: "chainmap")
1270 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered
1271 self.assertEqual(g(l), "sized")
1272 self.assertEqual(g(s), "sized")
1273 self.assertEqual(g(f), "sized")
1274 self.assertEqual(g(t), "sized")
1275 g.register(c.MutableSequence, lambda obj: "mutablesequence")
1276 self.assertEqual(g(d), "mutablemapping")
1277 self.assertEqual(g(l), "mutablesequence")
1278 self.assertEqual(g(s), "sized")
1279 self.assertEqual(g(f), "sized")
1280 self.assertEqual(g(t), "sized")
1281 g.register(c.MutableSet, lambda obj: "mutableset")
1282 self.assertEqual(g(d), "mutablemapping")
1283 self.assertEqual(g(l), "mutablesequence")
1284 self.assertEqual(g(s), "mutableset")
1285 self.assertEqual(g(f), "sized")
1286 self.assertEqual(g(t), "sized")
1287 g.register(c.Mapping, lambda obj: "mapping")
1288 self.assertEqual(g(d), "mutablemapping") # not specific enough
1289 self.assertEqual(g(l), "mutablesequence")
1290 self.assertEqual(g(s), "mutableset")
1291 self.assertEqual(g(f), "sized")
1292 self.assertEqual(g(t), "sized")
1293 g.register(c.Sequence, lambda obj: "sequence")
1294 self.assertEqual(g(d), "mutablemapping")
1295 self.assertEqual(g(l), "mutablesequence")
1296 self.assertEqual(g(s), "mutableset")
1297 self.assertEqual(g(f), "sized")
1298 self.assertEqual(g(t), "sequence")
1299 g.register(c.Set, lambda obj: "set")
1300 self.assertEqual(g(d), "mutablemapping")
1301 self.assertEqual(g(l), "mutablesequence")
1302 self.assertEqual(g(s), "mutableset")
1303 self.assertEqual(g(f), "set")
1304 self.assertEqual(g(t), "sequence")
1305 g.register(dict, lambda obj: "dict")
1306 self.assertEqual(g(d), "dict")
1307 self.assertEqual(g(l), "mutablesequence")
1308 self.assertEqual(g(s), "mutableset")
1309 self.assertEqual(g(f), "set")
1310 self.assertEqual(g(t), "sequence")
1311 g.register(list, lambda obj: "list")
1312 self.assertEqual(g(d), "dict")
1313 self.assertEqual(g(l), "list")
1314 self.assertEqual(g(s), "mutableset")
1315 self.assertEqual(g(f), "set")
1316 self.assertEqual(g(t), "sequence")
1317 g.register(set, lambda obj: "concrete-set")
1318 self.assertEqual(g(d), "dict")
1319 self.assertEqual(g(l), "list")
1320 self.assertEqual(g(s), "concrete-set")
1321 self.assertEqual(g(f), "set")
1322 self.assertEqual(g(t), "sequence")
1323 g.register(frozenset, lambda obj: "frozen-set")
1324 self.assertEqual(g(d), "dict")
1325 self.assertEqual(g(l), "list")
1326 self.assertEqual(g(s), "concrete-set")
1327 self.assertEqual(g(f), "frozen-set")
1328 self.assertEqual(g(t), "sequence")
1329 g.register(tuple, lambda obj: "tuple")
1330 self.assertEqual(g(d), "dict")
1331 self.assertEqual(g(l), "list")
1332 self.assertEqual(g(s), "concrete-set")
1333 self.assertEqual(g(f), "frozen-set")
1334 self.assertEqual(g(t), "tuple")
1335
Łukasz Langa3720c772013-07-01 16:00:38 +02001336 def test_c3_abc(self):
1337 c = collections
1338 mro = functools._c3_mro
1339 class A(object):
1340 pass
1341 class B(A):
1342 def __len__(self):
1343 return 0 # implies Sized
1344 @c.Container.register
1345 class C(object):
1346 pass
1347 class D(object):
1348 pass # unrelated
1349 class X(D, C, B):
1350 def __call__(self):
1351 pass # implies Callable
1352 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
1353 for abcs in permutations([c.Sized, c.Callable, c.Container]):
1354 self.assertEqual(mro(X, abcs=abcs), expected)
1355 # unrelated ABCs don't appear in the resulting MRO
1356 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
1357 self.assertEqual(mro(X, abcs=many_abcs), expected)
1358
Łukasz Langa6f692512013-06-05 12:20:24 +02001359 def test_mro_conflicts(self):
1360 c = collections
Łukasz Langa6f692512013-06-05 12:20:24 +02001361 @functools.singledispatch
1362 def g(arg):
1363 return "base"
Łukasz Langa6f692512013-06-05 12:20:24 +02001364 class O(c.Sized):
1365 def __len__(self):
1366 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001367 o = O()
1368 self.assertEqual(g(o), "base")
1369 g.register(c.Iterable, lambda arg: "iterable")
1370 g.register(c.Container, lambda arg: "container")
1371 g.register(c.Sized, lambda arg: "sized")
1372 g.register(c.Set, lambda arg: "set")
1373 self.assertEqual(g(o), "sized")
1374 c.Iterable.register(O)
1375 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__
1376 c.Container.register(O)
1377 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__
Łukasz Langa3720c772013-07-01 16:00:38 +02001378 c.Set.register(O)
1379 self.assertEqual(g(o), "set") # because c.Set is a subclass of
1380 # c.Sized and c.Container
Łukasz Langa6f692512013-06-05 12:20:24 +02001381 class P:
1382 pass
Łukasz Langa6f692512013-06-05 12:20:24 +02001383 p = P()
1384 self.assertEqual(g(p), "base")
1385 c.Iterable.register(P)
1386 self.assertEqual(g(p), "iterable")
1387 c.Container.register(P)
Łukasz Langa3720c772013-07-01 16:00:38 +02001388 with self.assertRaises(RuntimeError) as re_one:
Łukasz Langa6f692512013-06-05 12:20:24 +02001389 g(p)
Łukasz Langa3720c772013-07-01 16:00:38 +02001390 self.assertIn(
1391 str(re_one.exception),
1392 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1393 "or <class 'collections.abc.Iterable'>"),
1394 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
1395 "or <class 'collections.abc.Container'>")),
1396 )
Łukasz Langa6f692512013-06-05 12:20:24 +02001397 class Q(c.Sized):
1398 def __len__(self):
1399 return 0
Łukasz Langa6f692512013-06-05 12:20:24 +02001400 q = Q()
1401 self.assertEqual(g(q), "sized")
1402 c.Iterable.register(Q)
1403 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__
1404 c.Set.register(Q)
1405 self.assertEqual(g(q), "set") # because c.Set is a subclass of
Łukasz Langa3720c772013-07-01 16:00:38 +02001406 # c.Sized and c.Iterable
1407 @functools.singledispatch
1408 def h(arg):
1409 return "base"
1410 @h.register(c.Sized)
1411 def _(arg):
1412 return "sized"
1413 @h.register(c.Container)
1414 def _(arg):
1415 return "container"
1416 # Even though Sized and Container are explicit bases of MutableMapping,
1417 # this ABC is implicitly registered on defaultdict which makes all of
1418 # MutableMapping's bases implicit as well from defaultdict's
1419 # perspective.
1420 with self.assertRaises(RuntimeError) as re_two:
1421 h(c.defaultdict(lambda: 0))
1422 self.assertIn(
1423 str(re_two.exception),
1424 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1425 "or <class 'collections.abc.Sized'>"),
1426 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1427 "or <class 'collections.abc.Container'>")),
1428 )
1429 class R(c.defaultdict):
1430 pass
1431 c.MutableSequence.register(R)
1432 @functools.singledispatch
1433 def i(arg):
1434 return "base"
1435 @i.register(c.MutableMapping)
1436 def _(arg):
1437 return "mapping"
1438 @i.register(c.MutableSequence)
1439 def _(arg):
1440 return "sequence"
1441 r = R()
1442 self.assertEqual(i(r), "sequence")
1443 class S:
1444 pass
1445 class T(S, c.Sized):
1446 def __len__(self):
1447 return 0
1448 t = T()
1449 self.assertEqual(h(t), "sized")
1450 c.Container.register(T)
1451 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO
1452 class U:
1453 def __len__(self):
1454 return 0
1455 u = U()
1456 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred
1457 # from the existence of __len__()
1458 c.Container.register(U)
1459 # There is no preference for registered versus inferred ABCs.
1460 with self.assertRaises(RuntimeError) as re_three:
1461 h(u)
1462 self.assertIn(
1463 str(re_three.exception),
1464 (("Ambiguous dispatch: <class 'collections.abc.Container'> "
1465 "or <class 'collections.abc.Sized'>"),
1466 ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
1467 "or <class 'collections.abc.Container'>")),
1468 )
1469 class V(c.Sized, S):
1470 def __len__(self):
1471 return 0
1472 @functools.singledispatch
1473 def j(arg):
1474 return "base"
1475 @j.register(S)
1476 def _(arg):
1477 return "s"
1478 @j.register(c.Container)
1479 def _(arg):
1480 return "container"
1481 v = V()
1482 self.assertEqual(j(v), "s")
1483 c.Container.register(V)
1484 self.assertEqual(j(v), "container") # because it ends up right after
1485 # Sized in the MRO
Łukasz Langa6f692512013-06-05 12:20:24 +02001486
def test_cache_invalidation(self):
    # Verify when singledispatch's dispatch cache is filled and emptied.
    # The cache is observed by temporarily replacing
    # functools.WeakKeyDictionary with a factory that returns a
    # TracingDict; the assertions below rely on singledispatch() creating
    # its cache through that name. TracingDict records every successful
    # lookup (get_ops) and every store (set_ops).
    from collections import UserDict
    from collections import abc as c  # the ABCs, not the deprecated aliases

    class TracingDict(UserDict):
        def __init__(self, *args, **kwargs):
            super(TracingDict, self).__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            # Look up before logging so that misses (KeyError) are not
            # recorded in get_ops.
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        def clear(self):
            self.data.clear()

    _orig_wkd = functools.WeakKeyDictionary
    td = TracingDict()
    functools.WeakKeyDictionary = lambda: td
    try:
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First calls populate the cache; nothing is read from it yet.
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated calls are served from the cache (recorded as gets).
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # register() empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # dispatch() goes through the same cache as calling g().
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        # The next dispatch notices the ABC change and clears the cache.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        g._clear_cache()
        self.assertEqual(len(td), 0)
    finally:
        # Restore unconditionally: if an assertion above fails, leaving the
        # patch in place would make every later singledispatch() share td.
        functools.WeakKeyDictionary = _orig_wkd
1587
1588
# When this module is executed directly (rather than imported), hand
# control to unittest's command-line runner for the tests defined here.
if '__main__' == __name__:
    unittest.main()