blob: c50336e3e0deb037f95f2b2a2ccc7193b5eef5be [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
R. David Murray378c0cf2010-02-24 01:46:21 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00008
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Build a closure that calls *func* with the frozen positional
    arguments placed before the call-time ones, and with the frozen
    keyword arguments overridden by any call-time keyword arguments.
    The closure carries ``func``, ``args`` and ``keywords`` attributes
    mirroring a real partial object.

    The ``staticmethod`` wrapper lets test classes store this as the
    ``thetype`` class attribute without it being bound to ``self``.
    """
    def newfunc(*fargs, **fkeywords):
        merged_keywords = dict(keywords)
        merged_keywords.update(fkeywords)
        all_args = args + fargs
        return func(*all_args, **merged_keywords)

    for attr_name, attr_value in (('func', func),
                                  ('args', args),
                                  ('keywords', keywords)):
        setattr(newfunc, attr_name, attr_value)
    return newfunc
20
def capture(*args, **kw):
    """Echo the call back as an ``(args, kwargs)`` pair."""
    positional, keyword = args, kw
    return (positional, keyword)
24
def signature(part):
    """Summarize a partial object as a tuple for equality comparisons.

    The tuple is ``(func, args, keywords, __dict__)``, so two partial
    objects summarize equally when they wrap the same callable with the
    same frozen arguments and carry the same instance attributes.
    """
    field_names = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in field_names)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000028
class TestPartial(unittest.TestCase):
    """Behavioral tests for ``functools.partial``.

    ``thetype`` names the partial-like callable under test; subclasses
    override it to re-run the whole suite against other implementations.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; this is only checked when
        # thetype is a class (the pure-Python version is a plain function)
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertions so a failure reports the offending value
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a':1, 'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
180
class PartialSubclass(functools.partial):
    """Trivial ``functools.partial`` subclass used to check that partial
    behaviour survives subclassing."""
183
class TestPartialSubclass(TestPartial):
    """Re-run every ``TestPartial`` check against a ``partial`` subclass."""

    thetype = PartialSubclass
187
class TestPythonPartial(TestPartial):
    """Run the ``TestPartial`` suite against the pure-Python approximation."""

    thetype = PythonPartial

    # The Python version has no nice repr.  Report the inherited test as
    # skipped instead of letting an empty override count as a pass.
    @unittest.skip("the pure-Python partial has no custom repr")
    def test_repr(self):
        pass

    # The Python version is a closure and therefore isn't picklable.
    @unittest.skip("the pure-Python partial is not picklable")
    def test_pickle(self):
        pass
197
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the very same object on
        both functions, and every key of each mapping named in *updated*
        must map to the same object in the wrapper's mapping.
        """
        # Check attributes were assigned (assertIs reports both objects)
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Wrap an annotated, documented function with default settings.

        Returns the ``(wrapper, wrapped)`` pair.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000299
class TestWraps(TestUpdateWrapper):
    """Re-run the update_wrapper checks through the ``functools.wraps``
    decorator form."""

    def _default_update(self):
        """Wrap a documented function via ``@functools.wraps``.

        Returns the ``(wrapper, wrapped)`` pair, matching the base class.
        """
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
356
class TestReduce(unittest.TestCase):
    """Tests for ``functools.reduce``."""

    # staticmethod keeps a pure-Python reduce (should one be substituted
    # here) from being bound as a method when looked up through ``self``;
    # for the builtin it is a no-op.
    func = staticmethod(functools.reduce)

    def test_reduce(self):
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
436
class TestCmpToKey(unittest.TestCase):
    """Tests for ``functools.cmp_to_key``."""

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        def cmp2(x, y):
            return int(x) - int(y)
        key = functools.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            functools.cmp_to_key()    # too few args
        with self.assertRaises(TypeError):
            functools.cmp_to_key(cmp1, None)    # too many args
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()    # too few args
        with self.assertRaises(TypeError):
            key(None, None)    # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # result against zero must propagate as well.  The key has to be
        # rebuilt from the new function for this path to run at all, and
        # ``<`` is used because BadCmp only defines ``__lt__``; this assumes
        # cmp_to_key evaluates ``result < 0`` for ``<`` (as the pure-Python
        # implementation in Lib/functools.py does).
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = functools.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=functools.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000513
class TestTotalOrdering(unittest.TestCase):
    """Tests for the ``functools.total_ordering`` class decorator."""

    def _check_ordering(self, cls):
        """Assert that instances of *cls* order consistently by value.

        *cls* is constructed from the integers 1 and 2.  Both the relations
        that must hold and those that must not are checked, so a derived
        method that always returned True would be caught.
        """
        small, big, big_again = cls(1), cls(2), cls(2)
        self.assertTrue(small < big)
        self.assertTrue(big > small)
        self.assertTrue(small <= big)
        self.assertTrue(big >= small)
        self.assertTrue(big <= big_again)
        self.assertTrue(big >= big_again)
        self.assertFalse(big < small)
        self.assertFalse(small > big)
        self.assertFalse(big <= small)
        self.assertFalse(small >= big)
        self.assertFalse(big < big_again)
        self.assertFalse(big > big_again)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordering(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
613
class TestLRU(unittest.TestCase):
    """Tests for the ``functools.lru_cache`` decorator."""

    def _check_info(self, func, hits, misses, currsize):
        """Assert the hits, misses and current size reported by *func*."""
        info = func.cache_info()
        self.assertEqual(info.hits, hits)
        self.assertEqual(info.misses, misses)
        self.assertEqual(info.currsize, currsize)

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(f.cache_info().maxsize, 20)
        self._check_info(f, hits=0, misses=0, currsize=0)

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        info = f.cache_info()
        # assertGreater reports both counts on failure
        self.assertGreater(info.hits, info.misses)
        self.assertEqual(info.hits + info.misses, 1000)
        self.assertEqual(info.currsize, 20)

        f.cache_clear()   # test clearing
        self._check_info(f, hits=0, misses=0, currsize=0)
        f(x, y)
        self._check_info(f, hits=0, misses=1, currsize=1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        self._check_info(f, hits=0, misses=1, currsize=1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        self._check_info(f, hits=0, misses=5, currsize=0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        self._check_info(f, hits=4, misses=1, currsize=1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        self._check_info(f, hits=12, misses=4, currsize=2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
718
def test_main(verbose=None):
    """Run every test class in this module through ``support.run_unittest``.

    When *verbose* is true and the interpreter exposes
    ``sys.gettotalrefcount``, the suite is run five more times, with a
    ``gc.collect()`` after each run, and the list of total reference
    counts is printed so leaks show up as a growing sequence.
    """
    test_classes = (
        TestCmpToKey,
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000742
if __name__ == '__main__':
    # Running the file directly also enables the verbose refcount check.
    test_main(verbose=True)