blob: f41a144a1415558e6df9624974c03b67461cbaf0 [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
R. David Murray378c0cf2010-02-24 01:46:21 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00008
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of functools.partial().

    Returns a closure that prepends *args* to the call's positional
    arguments and merges *keywords* underneath the call's own keyword
    arguments (call-site keywords win).  The closure exposes ``func``,
    ``args`` and ``keywords`` attributes like a real partial object.

    Wrapped in staticmethod so that storing it as a class attribute
    (``thetype = PythonPartial``) does not turn it into a bound method.
    """
    def newfunc(*fargs, **fkeywords):
        call_kwargs = dict(keywords, **fkeywords)
        return func(*(args + fargs), **call_kwargs)
    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
20
def capture(*args, **kw):
    """Return the positional and keyword arguments exactly as received.

    The result is an ``(args_tuple, kwargs_dict)`` pair, letting tests see
    precisely what a partial object forwarded to its target.
    """
    received = (args, kw)
    return received
24
def signature(part):
    """Return a tuple describing everything that defines a partial object.

    The tuple is ``(func, args, keywords, __dict__)``.  Comparing two such
    tuples checks that a copy (e.g. one produced by pickling) reproduces
    the original.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, name) for name in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000028
class TestPartial(unittest.TestCase):
    """Behavioural tests for functools.partial.

    Subclasses re-run every check against another implementation by
    overriding the ``thetype`` class attribute.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable -- only enforced when thetype is
        # a class; a function-based stand-in exposes ordinary attributes
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype) # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertions so a failure reports the differing value
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # round-trip through every pickle protocol, not just the default one
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))
180
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass; TestPartialSubclass runs the partial tests on it."""
183
class TestPartialSubclass(TestPartial):
    """Run every TestPartial check against a subclass of functools.partial."""

    thetype = PartialSubclass
187
class TestPythonPartial(TestPartial):
    """Run the TestPartial checks against the pure-Python PythonPartial."""

    thetype = PythonPartial

    # These inherited checks do not apply to the closure-based stand-in.
    # Report them as skipped rather than silently passing, so the test
    # output shows they were not exercised for this implementation.
    @unittest.skip("the python version hasn't a nice repr")
    def test_repr(self):
        pass

    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass
197
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both functions.  For every mapping attribute named in *updated*, each
        key present on *wrapped* must map to the identical object on
        *wrapper*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Apply update_wrapper with default settings to a fresh wrapper.

        Returns ``(wrapper, f)``.  f carries a docstring, an annotation on
        ``a`` and an extra attribute; wrapper starts with its own annotation
        on ``b`` that should be replaced.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Antoine Pitrou560f7642010-08-04 18:28:02 +0000298 self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000299
300class TestWraps(TestUpdateWrapper):
301
R. David Murray378c0cf2010-02-24 01:46:21 +0000302 def _default_update(self):
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000303 def f():
304 """This is a test"""
305 pass
306 f.attr = 'This is also a test'
307 @functools.wraps(f)
308 def wrapper():
309 pass
310 self.check_wrapper(wrapper, f)
R. David Murray378c0cf2010-02-24 01:46:21 +0000311 return wrapper
312
313 def test_default_update(self):
314 wrapper = self._default_update()
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000315 self.assertEqual(wrapper.__name__, 'f')
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000316 self.assertEqual(wrapper.attr, 'This is also a test')
317
R. David Murray378c0cf2010-02-24 01:46:21 +0000318 @unittest.skipIf(not sys.flags.optimize <= 1,
319 "Docstrings are omitted with -O2 and above")
320 def test_default_update_doc(self):
321 wrapper = self._default_update()
322 self.assertEqual(wrapper.__doc__, 'This is a test')
323
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000324 def test_no_update(self):
325 def f():
326 """This is a test"""
327 pass
328 f.attr = 'This is also a test'
329 @functools.wraps(f, (), ())
330 def wrapper():
331 pass
332 self.check_wrapper(wrapper, f, (), ())
333 self.assertEqual(wrapper.__name__, 'wrapper')
334 self.assertEqual(wrapper.__doc__, None)
Benjamin Petersonc9c0f202009-06-30 23:06:06 +0000335 self.assertFalse(hasattr(wrapper, 'attr'))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000336
337 def test_selective_update(self):
338 def f():
339 pass
340 f.attr = 'This is a different test'
341 f.dict_attr = dict(a=1, b=2, c=3)
342 def add_dict_attr(f):
343 f.dict_attr = {}
344 return f
345 assign = ('attr',)
346 update = ('dict_attr',)
347 @functools.wraps(f, assign, update)
348 @add_dict_attr
349 def wrapper():
350 pass
351 self.check_wrapper(wrapper, f, assign, update)
352 self.assertEqual(wrapper.__name__, 'wrapper')
353 self.assertEqual(wrapper.__doc__, None)
354 self.assertEqual(wrapper.attr, 'This is a different test')
355 self.assertEqual(wrapper.dict_attr, f.dict_attr)
356
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000357class TestReduce(unittest.TestCase):
358 func = functools.reduce
359
360 def test_reduce(self):
361 class Squares:
362 def __init__(self, max):
363 self.max = max
364 self.sofar = []
365
366 def __len__(self):
367 return len(self.sofar)
368
369 def __getitem__(self, i):
370 if not 0 <= i < self.max: raise IndexError
371 n = len(self.sofar)
372 while n <= i:
373 self.sofar.append(n*n)
374 n += 1
375 return self.sofar[i]
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000376 def add(x, y):
377 return x + y
378 self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000379 self.assertEqual(
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000380 self.func(add, [['a', 'c'], [], ['d', 'w']], []),
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000381 ['a','c','d','w']
382 )
383 self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
384 self.assertEqual(
Guido van Rossume2a383d2007-01-15 16:59:06 +0000385 self.func(lambda x, y: x*y, range(2,21), 1),
386 2432902008176640000
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000387 )
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000388 self.assertEqual(self.func(add, Squares(10)), 285)
389 self.assertEqual(self.func(add, Squares(10), 0), 285)
390 self.assertEqual(self.func(add, Squares(0), 0), 0)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000391 self.assertRaises(TypeError, self.func)
392 self.assertRaises(TypeError, self.func, 42, 42)
393 self.assertRaises(TypeError, self.func, 42, 42, 42)
394 self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
395 self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
396 self.assertRaises(TypeError, self.func, 42, (42, 42))
Alexander Belopolskye29e6bf2010-08-16 18:55:46 +0000397 self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
398 self.assertRaises(TypeError, self.func, add, "")
399 self.assertRaises(TypeError, self.func, add, ())
400 self.assertRaises(TypeError, self.func, add, object())
401
402 class TestFailingIter:
403 def __iter__(self):
404 raise RuntimeError
405 self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
406
407 self.assertEqual(self.func(add, [], None), None)
408 self.assertEqual(self.func(add, [], 42), 42)
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000409
410 class BadSeq:
411 def __getitem__(self, index):
412 raise ValueError
413 self.assertRaises(ValueError, self.func, 42, BadSeq())
414
415 # Test reduce()'s use of iterators.
416 def test_iterator_usage(self):
417 class SequenceClass:
418 def __init__(self, n):
419 self.n = n
420 def __getitem__(self, i):
421 if 0 <= i < self.n:
422 return i
423 else:
424 raise IndexError
Guido van Rossumd8faa362007-04-27 19:54:29 +0000425
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000426 from operator import add
427 self.assertEqual(self.func(add, SequenceClass(5)), 10)
428 self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
429 self.assertRaises(TypeError, self.func, add, SequenceClass(0))
430 self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
431 self.assertEqual(self.func(add, SequenceClass(1)), 0)
432 self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
433
434 d = {"one": 1, "two": 2, "three": 3}
435 self.assertEqual(self.func(add, d), "".join(d.keys()))
436
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000437class TestCmpToKey(unittest.TestCase):
438 def test_cmp_to_key(self):
439 def mycmp(x, y):
440 return y - x
441 self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
442 [4, 3, 2, 1, 0])
Guido van Rossumd8faa362007-04-27 19:54:29 +0000443
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000444 def test_hash(self):
445 def mycmp(x, y):
446 return y - x
447 key = functools.cmp_to_key(mycmp)
448 k = key(10)
449 self.assertRaises(TypeError, hash(k))
450
451class TestTotalOrdering(unittest.TestCase):
452
453 def test_total_ordering_lt(self):
454 @functools.total_ordering
455 class A:
456 def __init__(self, value):
457 self.value = value
458 def __lt__(self, other):
459 return self.value < other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000460 self.assertTrue(A(1) < A(2))
461 self.assertTrue(A(2) > A(1))
462 self.assertTrue(A(1) <= A(2))
463 self.assertTrue(A(2) >= A(1))
464 self.assertTrue(A(2) <= A(2))
465 self.assertTrue(A(2) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000466
467 def test_total_ordering_le(self):
468 @functools.total_ordering
469 class A:
470 def __init__(self, value):
471 self.value = value
472 def __le__(self, other):
473 return self.value <= other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000474 self.assertTrue(A(1) < A(2))
475 self.assertTrue(A(2) > A(1))
476 self.assertTrue(A(1) <= A(2))
477 self.assertTrue(A(2) >= A(1))
478 self.assertTrue(A(2) <= A(2))
479 self.assertTrue(A(2) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000480
481 def test_total_ordering_gt(self):
482 @functools.total_ordering
483 class A:
484 def __init__(self, value):
485 self.value = value
486 def __gt__(self, other):
487 return self.value > other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000488 self.assertTrue(A(1) < A(2))
489 self.assertTrue(A(2) > A(1))
490 self.assertTrue(A(1) <= A(2))
491 self.assertTrue(A(2) >= A(1))
492 self.assertTrue(A(2) <= A(2))
493 self.assertTrue(A(2) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000494
495 def test_total_ordering_ge(self):
496 @functools.total_ordering
497 class A:
498 def __init__(self, value):
499 self.value = value
500 def __ge__(self, other):
501 return self.value >= other.value
Ezio Melottib3aedd42010-11-20 19:04:17 +0000502 self.assertTrue(A(1) < A(2))
503 self.assertTrue(A(2) > A(1))
504 self.assertTrue(A(1) <= A(2))
505 self.assertTrue(A(2) >= A(1))
506 self.assertTrue(A(2) <= A(2))
507 self.assertTrue(A(2) >= A(2))
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000508
509 def test_total_ordering_no_overwrite(self):
510 # new methods should not overwrite existing
511 @functools.total_ordering
512 class A(int):
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000513 pass
Ezio Melottib3aedd42010-11-20 19:04:17 +0000514 self.assertTrue(A(1) < A(2))
515 self.assertTrue(A(2) > A(1))
516 self.assertTrue(A(1) <= A(2))
517 self.assertTrue(A(2) >= A(1))
518 self.assertTrue(A(2) <= A(2))
519 self.assertTrue(A(2) >= A(2))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000520
Benjamin Peterson42ebee32010-04-11 01:43:16 +0000521 def test_no_operations_defined(self):
522 with self.assertRaises(ValueError):
523 @functools.total_ordering
524 class A:
525 pass
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000526
Georg Brandl2e7346a2010-07-31 18:09:23 +0000527class TestLRU(unittest.TestCase):
528
529 def test_lru(self):
530 def orig(x, y):
531 return 3*x+y
532 f = functools.lru_cache(maxsize=20)(orig)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000533 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000534 self.assertEqual(maxsize, 20)
535 self.assertEqual(currsize, 0)
536 self.assertEqual(hits, 0)
537 self.assertEqual(misses, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000538
539 domain = range(5)
540 for i in range(1000):
541 x, y = choice(domain), choice(domain)
542 actual = f(x, y)
543 expected = orig(x, y)
Ezio Melottib3aedd42010-11-20 19:04:17 +0000544 self.assertEqual(actual, expected)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000545 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000546 self.assertTrue(hits > misses)
547 self.assertEqual(hits + misses, 1000)
548 self.assertEqual(currsize, 20)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000549
Raymond Hettinger02566ec2010-09-04 22:46:06 +0000550 f.cache_clear() # test clearing
Raymond Hettinger7496b412010-11-30 19:15:45 +0000551 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000552 self.assertEqual(hits, 0)
553 self.assertEqual(misses, 0)
554 self.assertEqual(currsize, 0)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000555 f(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000556 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000557 self.assertEqual(hits, 0)
558 self.assertEqual(misses, 1)
559 self.assertEqual(currsize, 1)
Georg Brandl2e7346a2010-07-31 18:09:23 +0000560
Nick Coghlan98876832010-08-17 06:17:18 +0000561 # Test bypassing the cache
562 self.assertIs(f.__wrapped__, orig)
563 f.__wrapped__(x, y)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000564 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000565 self.assertEqual(hits, 0)
566 self.assertEqual(misses, 1)
567 self.assertEqual(currsize, 1)
Nick Coghlan98876832010-08-17 06:17:18 +0000568
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000569 # test size zero (which means "never-cache")
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000570 @functools.lru_cache(0)
571 def f():
572 nonlocal f_cnt
573 f_cnt += 1
574 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000575 self.assertEqual(f.cache_info().maxsize, 0)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000576 f_cnt = 0
577 for i in range(5):
578 self.assertEqual(f(), 20)
579 self.assertEqual(f_cnt, 5)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000580 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000581 self.assertEqual(hits, 0)
582 self.assertEqual(misses, 5)
583 self.assertEqual(currsize, 0)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000584
585 # test size one
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000586 @functools.lru_cache(1)
587 def f():
588 nonlocal f_cnt
589 f_cnt += 1
590 return 20
Nick Coghlan234515a2010-11-30 06:19:46 +0000591 self.assertEqual(f.cache_info().maxsize, 1)
Raymond Hettingerf3098282010-08-15 03:30:45 +0000592 f_cnt = 0
593 for i in range(5):
594 self.assertEqual(f(), 20)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000595 self.assertEqual(f_cnt, 1)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000596 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000597 self.assertEqual(hits, 4)
598 self.assertEqual(misses, 1)
599 self.assertEqual(currsize, 1)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000600
Raymond Hettingerf3098282010-08-15 03:30:45 +0000601 # test size two
602 @functools.lru_cache(2)
603 def f(x):
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000604 nonlocal f_cnt
605 f_cnt += 1
Raymond Hettingerf3098282010-08-15 03:30:45 +0000606 return x*10
Nick Coghlan234515a2010-11-30 06:19:46 +0000607 self.assertEqual(f.cache_info().maxsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000608 f_cnt = 0
Raymond Hettingerf3098282010-08-15 03:30:45 +0000609 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
610 # * * * *
611 self.assertEqual(f(x), x*10)
612 self.assertEqual(f_cnt, 4)
Raymond Hettinger7496b412010-11-30 19:15:45 +0000613 hits, misses, maxsize, currsize = f.cache_info()
Nick Coghlan234515a2010-11-30 06:19:46 +0000614 self.assertEqual(hits, 12)
615 self.assertEqual(misses, 4)
616 self.assertEqual(currsize, 2)
Raymond Hettinger0f56e902010-08-14 23:52:08 +0000617
Raymond Hettingerc79fb0e2010-12-01 03:45:41 +0000618 def test_lru_with_maxsize_none(self):
619 @functools.lru_cache(maxsize=None)
620 def fib(n):
621 if n < 2:
622 return n
623 return fib(n-1) + fib(n-2)
624 self.assertEqual([fib(n) for n in range(16)],
625 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
626 self.assertEqual(fib.cache_info(),
627 functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
628 fib.cache_clear()
629 self.assertEqual(fib.cache_info(),
630 functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
631
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000632def test_main(verbose=None):
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000633 test_classes = (
634 TestPartial,
635 TestPartialSubclass,
636 TestPythonPartial,
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000637 TestUpdateWrapper,
Benjamin Peterson9c2930e2010-08-23 17:40:33 +0000638 TestTotalOrdering,
Guido van Rossum0919a1a2006-08-26 20:49:04 +0000639 TestWraps,
Georg Brandl2e7346a2010-07-31 18:09:23 +0000640 TestReduce,
641 TestLRU,
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000642 )
Benjamin Petersonee8712c2008-05-20 21:35:26 +0000643 support.run_unittest(*test_classes)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000644
645 # verify reference counting
646 if verbose and hasattr(sys, "gettotalrefcount"):
647 import gc
648 counts = [None] * 5
Guido van Rossum805365e2007-05-07 22:24:25 +0000649 for i in range(len(counts)):
Benjamin Petersonee8712c2008-05-20 21:35:26 +0000650 support.run_unittest(*test_classes)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000651 gc.collect()
652 counts[i] = sys.gettotalrefcount()
Guido van Rossumbe19ed72007-02-09 05:37:30 +0000653 print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000654
655if __name__ == '__main__':
656 test_main(verbose=True)