blob: c4d9fe6ac27026b1d06406716c86a3e74caa2eff [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00004import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00005from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00006from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
@staticmethod
def PythonPartial(func, *args, **keywords):
    'Pure Python approximation of partial()'
    # Wrapped in staticmethod so it can be stored as a class attribute
    # (``thetype = PythonPartial``) without being bound as a method.
    def newfunc(*fargs, **fkeywords):
        # Call-time keywords override the stored ones; the stored dict
        # itself is never mutated.
        merged = dict(keywords)
        merged.update(fkeywords)
        all_args = args + fargs
        return func(*all_args, **merged)
    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
21
def capture(*args, **kw):
    """Return ``(args, kw)``: the positional tuple and keyword dict exactly as received."""
    received = (args, kw)
    return received
25
def signature(part):
    """Describe a partial object as ``(func, args, keywords, __dict__)``."""
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000029
class TestPartial(unittest.TestCase):
    """Behavioural tests for functools.partial.

    Subclasses rerun every test against another implementation by
    overriding ``thetype``.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; the remaining checks only
        # apply when thetype is a real class rather than a factory function
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        # the instance __dict__ must not be deletable
        p = self.thetype(hex)
        with self.assertRaises(TypeError):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        # the first arg must be checked for callability
        with self.assertRaises(TypeError):
            self.thetype(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # dropping the only strong reference must invalidate the proxy
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        # instance __dict__ contents must survive the round trip too
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
181
class PartialSubclass(functools.partial):
    """Trivial partial subclass; TestPartialSubclass reruns the partial tests on it."""
184
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial test against a trivial partial subclass."""

    thetype = PartialSubclass
188
class TestPythonPartial(TestPartial):
    """Rerun the TestPartial suite against the pure-Python PythonPartial.

    Tests that do not apply to that approximation are reported as skipped
    rather than silently passing.
    """

    thetype = PythonPartial

    @unittest.skip("the python version hasn't a nice repr")
    def test_repr(self):
        pass

    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass
198
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the very same object on
        both functions, and every key of each mapping attribute named in
        *updated* must map to the same object on both.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Apply update_wrapper with default arguments to an annotated,
        documented function and return ``(wrapper, wrapped)``."""
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')
        # the wrapper's own annotations are replaced, not merged
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        # ...and they must support .update()
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000300
class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        """Decorate a documented function with functools.wraps and return
        ``(wrapper, wrapped)``, the same shape as the base-class helper."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # wraps() needs an existing mapping on the wrapper to update
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
357
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, reached through ``self.func``."""

    func = functools.reduce

    def test_reduce(self):
        reduce = self.func

        class Squares:
            # Old-style sequence: item i is i*i, computed lazily on first
            # access; IndexError past the limit ends iteration.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        # Ordinary reductions, with and without an initial value
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 21), 1),
                         2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        # Wrong argument counts, or an uncallable function
        for bad_args in [(), (42, 42), (42, 42, 42), (42, (42, 42))]:
            with self.assertRaises(TypeError):
                reduce(*bad_args)

        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            with self.assertRaises(TypeError):
                reduce(add, empty)
        with self.assertRaises(TypeError):
            reduce(add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        with self.assertRaises(RuntimeError):
            reduce(add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    def test_iterator_usage(self):
        # Test reduce()'s use of iterators.
        from operator import add
        reduce = self.func

        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, words), "".join(words.keys()))
437
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key."""

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        def cmp2(x, y):
            return int(x) - int(y)
        key = functools.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            functools.cmp_to_key()               # too few args
        with self.assertRaises(TypeError):
            functools.cmp_to_key(cmp1, None)     # too many args
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                # too few args
        with self.assertRaises(TypeError):
            key(None, None)                      # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the function's *result*
        # against 0 also propagates.  The key must be rebuilt from the new
        # function (otherwise cmp1 above is tested again), and ``<`` is used
        # because BadCmp only defines __lt__.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = functools.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=functools.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # The ABCs live in collections.abc; the collections.Hashable alias
        # was removed in Python 3.10.
        from collections.abc import Hashable
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000515
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _check_ordering(self, cls):
        """Assert that all four rich comparisons order cls(1) and cls(2)
        correctly, including the equal-value cases for <= and >=."""
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordering(A)
        # int already supplies every comparison, so the decorator must have
        # left all of them in place (result checks alone cannot tell).
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            self.assertIs(getattr(A, name), getattr(int, name))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
615
class TestLRU(unittest.TestCase):
    """Tests for functools.lru_cache."""

    def _assert_info(self, f, hits, misses, currsize):
        """Assert the hits, misses and currsize fields of f.cache_info()."""
        info = f.cache_info()
        self.assertEqual(info.hits, hits)
        self.assertEqual(info.misses, misses)
        self.assertEqual(info.currsize, currsize)

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(f.cache_info().maxsize, 20)
        self._assert_info(f, hits=0, misses=0, currsize=0)

        # 25 possible keys against a 20-entry cache
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        info = f.cache_info()
        self.assertGreater(info.hits, info.misses)
        self.assertEqual(info.hits + info.misses, 1000)
        self.assertEqual(info.currsize, 20)

        f.cache_clear()   # test clearing
        self._assert_info(f, hits=0, misses=0, currsize=0)
        f(x, y)
        self._assert_info(f, hits=0, misses=1, currsize=1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        self._assert_info(f, hits=0, misses=1, currsize=1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        self._assert_info(f, hits=0, misses=5, currsize=0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        self._assert_info(f, hits=4, misses=1, currsize=1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        # misses: the first 7, the first 9, the first 8 (evicting 7),
        # and the final 7 (evicting 9)
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        self._assert_info(f, hits=12, misses=4, currsize=2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)
752
def test_main(verbose=None):
    """Run every test class in this module.

    When *verbose* is true and ``sys.gettotalrefcount`` exists, rerun the
    whole suite five times and print the total refcount after each run.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKey,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting (sys.gettotalrefcount is presumably only
    # present on debug builds -- confirm)
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000776
if __name__ == '__main__':
    # Run directly: verbose=True also lets test_main run its refcount
    # loop when sys.gettotalrefcount is available.
    test_main(verbose=True)