# Source blob: 73a77d63f27b5f1484c22a2f617349ed795a9fcb
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
R. David Murray378c0cf2010-02-24 01:46:21 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00007from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00008
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Returns a plain function that calls *func* with the frozen positional
    arguments followed by the call-time ones, and with the frozen keyword
    arguments overridden by any call-time keywords.  The frozen pieces are
    exposed on the result as ``func``, ``args`` and ``keywords``.
    """
    # NOTE(review): presumably wrapped in staticmethod so that storing it as
    # a class attribute (``thetype = PythonPartial``) does not bind ``self``.
    def newfunc(*fargs, **fkeywords):
        # Build a fresh dict per call so the frozen keywords are never mutated.
        merged = dict(keywords, **fkeywords)
        return func(*(args + fargs), **merged)

    frozen = (('func', func), ('args', args), ('keywords', keywords))
    for attr_name, value in frozen:
        setattr(newfunc, attr_name, value)
    return newfunc
20
def capture(*args, **kw):
    """Return the received arguments unchanged.

    The result is the pair ``(args, kw)``: the positional arguments as a
    tuple and the keyword arguments as a dict.
    """
    received = (args, kw)
    return received
24
def signature(part):
    """Return the signature of a partial object.

    The signature is the tuple ``(func, args, keywords, __dict__)`` read
    from *part*; two partial objects with equal signatures are treated as
    equivalent by the pickle round-trip test.
    """
    field_names = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in field_names)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000028
class TestPartial(unittest.TestCase):
    """Behavioural tests for partial(), parameterised through ``thetype``.

    Subclasses re-run the same tests against another implementation by
    overriding the ``thetype`` class attribute.
    """

    # Callable under test; every test builds its objects via self.thetype.
    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable
        # (only checked when the implementation under test is a real type)
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the mismatching values.
            self.assertEqual(expected, got)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(expected, got)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        # an exception raised by the wrapped function must reach the caller,
        # whichever side supplied the offending argument
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # dropping the only strong reference must invalidate the proxy
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        # a pickle round trip must preserve func, args, keywords and __dict__
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
180
class PartialSubclass(functools.partial):
    """Do-nothing subclass of functools.partial used as a test subject."""
183
class TestPartialSubclass(TestPartial):
    """Re-run the TestPartial tests against a subclass of partial."""

    # Implementation under test for every inherited test method.
    thetype = PartialSubclass
187
class TestPythonPartial(TestPartial):
    """Re-run the TestPartial tests against the pure Python PythonPartial."""

    # Implementation under test for every inherited test method.
    thetype = PythonPartial

    # The Python version does not have a nice repr, so the inherited
    # check is replaced by a no-op.
    def test_repr(self):
        pass

    # The Python version is not picklable, so the inherited check is
    # replaced by a no-op.
    def test_pickle(self):
        pass
197
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Verify that *wrapper* received the expected attributes of *wrapped*:
        # every name in *assigned* must be the very same object on both, and
        # for every name in *updated* each entry of the wrapped mapping must
        # be present, by identity, in the wrapper's mapping.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        # Build a (wrapper, wrapped) pair updated with the default arguments.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')
        # the wrapped function's annotations replace the wrapper's own
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # only the explicitly named attributes are copied or merged
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000299
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks through the wraps() decorator."""

    def _default_update(self):
        # Build a wrapper decorated with wraps() using the default arguments.
        # Unlike the base class helper, only the wrapper is returned.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f)
        return wrapper

    def test_default_update(self):
        decorated = self._default_update()
        self.assertEqual(decorated.__name__, 'f')
        self.assertEqual(decorated.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        decorated = self._default_update()
        self.assertEqual(decorated.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # only the explicitly named attributes are copied or merged
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # give the wrapper an empty mapping for wraps() to update
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
356
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce()."""

    # reduce() implementation under test.
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the first `max` squares, computed on demand and
            # reachable only through __getitem__.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                # extend the cache up to and including index i
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x*y

        reduce = self.func
        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a','c','d','w'])
        self.assertEqual(reduce(mul, range(2,8), 1), 5040)
        self.assertEqual(reduce(mul, range(2,21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence yielding 0..n-1 through __getitem__ only.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        # reducing a dict folds over its keys
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
436
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key()."""

    def test_cmp_to_key(self):
        # a reversed comparison function must produce a descending sort
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        # key objects are expected to be unhashable
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        # Hand assertRaises the callable and its argument separately so that
        # it performs the call itself; evaluating hash(k) inline would raise
        # before assertRaises ever ran and turn the test into an error.
        self.assertRaises(TypeError, hash, k)
450
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering() class decorator."""

    def _check_ordering(self, cls):
        # The six comparisons every decorated class in these tests must
        # answer correctly, whichever single ordering method it defines.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordering(A)

    def test_no_operations_defined(self):
        # a class with no ordering method at all must be rejected
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return isinstance(other, TestTO) and self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, TestTO):
                    raise TypeError
                return self.value < other.value
        # the TypeError raised by __lt__ must surface through the derived <=
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
550
class TestLRU(unittest.TestCase):
    """Tests for the functools.lru_cache() decorator."""

    def _check_info(self, cached, hits, misses, currsize):
        # Compare the counters reported by cached.cache_info() with the
        # expected values.  Unpacking into exactly four names also checks
        # that the result has four fields.
        got_hits, got_misses, _maxsize, got_currsize = cached.cache_info()
        self.assertEqual(got_hits, hits)
        self.assertEqual(got_misses, misses)
        self.assertEqual(got_currsize, currsize)

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(f.cache_info().maxsize, 20)
        self._check_info(f, hits=0, misses=0, currsize=0)

        # 25 possible argument pairs are drawn at random against a cache
        # limited to 20 entries
        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            self.assertEqual(f(x, y), orig(x, y))
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertTrue(hits > misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        self._check_info(f, hits=0, misses=0, currsize=0)
        f(x, y)
        self._check_info(f, hits=0, misses=1, currsize=1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        self._check_info(f, hits=0, misses=1, currsize=1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def uncached():
            nonlocal calls
            calls += 1
            return 20
        self.assertEqual(uncached.cache_info().maxsize, 0)
        calls = 0
        for _ in range(5):
            self.assertEqual(uncached(), 20)
        self.assertEqual(calls, 5)
        self._check_info(uncached, hits=0, misses=5, currsize=0)

        # test size one
        @functools.lru_cache(1)
        def single():
            nonlocal calls
            calls += 1
            return 20
        self.assertEqual(single.cache_info().maxsize, 1)
        calls = 0
        for _ in range(5):
            self.assertEqual(single(), 20)
        self.assertEqual(calls, 1)
        self._check_info(single, hits=4, misses=1, currsize=1)

        # test size two
        @functools.lru_cache(2)
        def pair(x):
            nonlocal calls
            calls += 1
            return x*10
        self.assertEqual(pair.cache_info().maxsize, 2)
        calls = 0
        # calls marked * are the ones expected to miss the cache
        for arg in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #      *  *              *                          *
            self.assertEqual(pair(arg), arg*10)
        self.assertEqual(calls, 4)
        self._check_info(pair, hits=12, misses=4, currsize=2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)

        expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
        self.assertEqual([fib(n) for n in range(16)], expected)
        after_run = functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)
        self.assertEqual(fib.cache_info(), after_run)
        fib.cache_clear()
        after_clear = functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
        self.assertEqual(fib.cache_info(), after_clear)
655
def test_main(verbose=None):
    """Run the registered test classes and optionally report refcounts.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the classes are run five more times and the
    total reference count after each run is printed as a list.
    """
    # NOTE(review): TestCmpToKey is defined in this file but is not listed
    # here, so this entry point never runs it -- confirm whether intended.
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = [None] * 5
    for slot in range(len(counts)):
        support.run_unittest(*test_classes)
        gc.collect()
        counts[slot] = sys.gettotalrefcount()
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000678
# When executed as a script, run the suite and request the refcount report.
if __name__ == '__main__':
    test_main(verbose=True)