blob: 11e6e844204e2b79f486722b1bc8925ad39cd75e [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00004import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00005from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00006from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Returns a closure that prepends *args* to the call's positional
    arguments and merges *keywords* with the call's keyword arguments
    (call-time keywords win).  Exposes ``func``, ``args`` and ``keywords``
    attributes like the C implementation.
    """
    def newfunc(*fargs, **fkeywords):
        # Copy so the stored keywords are never mutated by a call.
        merged = dict(keywords)
        merged.update(fkeywords)
        all_args = args + fargs
        return func(*all_args, **merged)

    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc
21
def capture(*args, **kw):
    """Return the positional and keyword arguments exactly as received."""
    captured = (args, kw)
    return captured
25
def signature(part):
    """Describe a partial object as a ``(func, args, keywords, __dict__)`` tuple."""
    fields = (part.func, part.args, part.keywords)
    return fields + (part.__dict__,)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000029
class TestPartial(unittest.TestCase):
    """Behavioural tests for partial() implementations.

    The implementation under test is ``thetype``; subclasses rerun every
    check against another implementation by overriding that attribute.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # The read-only and __dict__ checks below only apply when thetype is
        # a real type; a function-based approximation stores these as plain
        # (writable) function attributes.
        if not isinstance(self.thetype, type):
            return
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        # A non-callable first argument must be rejected, either at
        # construction time or when the partial is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.thetype(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the only strong reference must invalidate the proxy.
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # Round-trip must preserve func, args, keywords and instance __dict__.
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.thetype(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
200
class PartialSubclass(functools.partial):
    """Trivial functools.partial subtype used to rerun the partial tests."""
203
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial check against a subclass of functools.partial."""

    thetype = PartialSubclass
207
class TestPythonPartial(TestPartial):
    """Rerun the TestPartial checks against the pure Python PythonPartial.

    Checks for features the approximation does not implement are reported
    as skipped (with a reason) rather than silently counted as passing.
    """

    thetype = PythonPartial

    @unittest.skip("the python version hasn't a nice repr")
    def test_repr(self):
        pass

    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass

    @unittest.skip("the python version doesn't implement __setstate__")
    def test_setstate_refcount(self):
        pass
Jack Diederiche0cbd692009-04-01 04:27:09 +0000218
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries the metadata of *wrapped*.

        Every attribute named in *assigned* must be the identical object on
        both callables.  For every mapping attribute named in *updated*,
        each key present on *wrapped* must map to the identical object on
        *wrapper*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Apply update_wrapper() with default settings.

        Returns the ``(wrapper, wrapped)`` pair.  The two functions carry
        different annotations so the test can see which set survives.
        """
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated tuples: nothing should be copied over.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000321
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the functools.wraps() decorator."""

    def _default_update(self):
        """Apply functools.wraps() with default settings.

        Returns the ``(wrapper, wrapped)`` pair, matching the contract of
        the base class helper it overrides.
        """
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # Give the wrapper an empty mapping for wraps() to update.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
378
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce() (exposed as ``func``)."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of n*n for 0 <= n < limit, produced lazily on demand;
            # __len__ reports only how many squares have been produced.
            def __init__(self, limit):
                self.max = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                for k in range(len(self.sofar), i + 1):
                    self.sofar.append(k * k)
                return self.sofar[i]

        def plus(left, right):
            return left + right

        def times(left, right):
            return left * right

        self.assertEqual(self.func(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(plus, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w']
        )
        self.assertEqual(self.func(times, range(2, 8), 1), 5040)
        self.assertEqual(
            self.func(times, range(2, 21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(plus, Squares(10)), 285)
        self.assertEqual(self.func(plus, Squares(10), 0), 285)
        self.assertEqual(self.func(plus, Squares(0), 0), 0)

        # wrong number of arguments, or a non-iterable second argument
        for bad_call in [(), (42, 42), (42, 42, 42)]:
            self.assertRaises(TypeError, self.func, *bad_call)
        # func is never called with one item
        self.assertEqual(self.func(42, "1"), "1")
        self.assertEqual(self.func(42, "", "1"), "1")
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", (), object()):
            self.assertRaises(TypeError, self.func, plus, empty)

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, plus, TestFailingIter())

        self.assertEqual(self.func(plus, [], None), None)
        self.assertEqual(self.func(plus, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence protocol: yields 0..n-1 via __getitem__.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        cases = [
            ((SequenceClass(5),), 10),
            ((SequenceClass(5), 42), 52),
            ((SequenceClass(0), 42), 42),
            ((SequenceClass(1),), 0),
            ((SequenceClass(1), 42), 42),
        ]
        for extra, expected in cases:
            self.assertEqual(self.func(add, *extra), expected)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d))
458
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key()."""

    def test_cmp_to_key(self):
        def mycmp(x, y):
            return y - x
        # A reversed comparison must produce a descending sort.
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        # The ABCs live in collections.abc; the old aliases in the top-level
        # collections namespace were removed in Python 3.10.
        from collections.abc import Hashable

        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        # Key wrapper objects must be unhashable, both in practice and
        # according to the Hashable ABC.
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000473
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering() class decorator."""

    def _check_ordering(self, cls):
        """Assert all four rich comparisons on *cls* follow value order.

        *cls* is built from the integers 1 and 2.  The check covers both
        comparisons that must hold and those that must not, so a derived
        method that unconditionally returns True is caught.
        """
        small, big, big_too = cls(1), cls(2), cls(2)
        # comparisons that must be true
        self.assertTrue(small < big)
        self.assertTrue(big > small)
        self.assertTrue(small <= big)
        self.assertTrue(big >= small)
        self.assertTrue(big <= big_too)
        self.assertTrue(big >= big_too)
        # comparisons that must be false
        self.assertFalse(big < small)
        self.assertFalse(big < big_too)
        self.assertFalse(small > big)
        self.assertFalse(big > big_too)
        self.assertFalse(big <= small)
        self.assertFalse(small >= big)

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordering(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        # A TypeError raised by the user-defined comparison must propagate
        # out of the derived methods instead of being swallowed.
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
573
class TestLRU(unittest.TestCase):
    """Tests for the functools.lru_cache() decorator."""

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 distinct argument pairs compete for 20 cache slots, so across
        # 1000 random calls hits should dominate and the cache should fill.
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)
694
def test_main(verbose=None):
    """Run every test class in this module.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the suite is rerun five more times and the
    total reference count after each run is printed.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKey,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)


if __name__ == '__main__':
    test_main(verbose=True)