# blob: 35e39b528fd2a7da4e0849b76db52e6c4b3f296e [file] [log] [blame]
import collections
import collections.abc
import functools
import pickle
import sys
import unittest
from random import choice
from test import support
from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Emulate functools.partial() with a plain closure.

    The returned function calls *func* with the frozen positional
    arguments followed by the call-time ones; call-time keyword
    arguments override the frozen ones.  The frozen pieces are exposed
    through the ``func``, ``args`` and ``keywords`` attributes.

    The ``staticmethod`` wrapper keeps the function from being bound as
    a method when it is stored as a class attribute and looked up
    through an instance.
    """
    def newfunc(*fargs, **fkeywords):
        # Merge into a fresh dict so the frozen keywords are never mutated.
        merged = dict(keywords)
        merged.update(fkeywords)
        positional = args + fargs
        return func(*positional, **merged)

    for attr, value in (('func', func),
                        ('args', args),
                        ('keywords', keywords)):
        setattr(newfunc, attr, value)
    return newfunc
21
def capture(*args, **kw):
    """Return the received arguments as an ``(args, kw)`` pair.

    *args* is the tuple of positional arguments and *kw* the dict of
    keyword arguments, exactly as they were passed in.
    """
    received = (args, kw)
    return received
25
def signature(part):
    """Summarise a partial object as a comparable tuple.

    The tuple holds, in order, the object's ``func``, ``args``,
    ``keywords`` and instance ``__dict__``.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000029
class TestPartial(unittest.TestCase):
    """Behavioural tests for ``functools.partial``.

    Every test builds its objects through ``self.thetype`` so that
    subclasses can rebind that attribute and run the same checks against
    another partial-like callable.
    """

    # Factory under test; always invoked as self.thetype(func, ...).
    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; this only applies when the
        # factory is a real type, so plain-function factories stop here
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords',
                          dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)  # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure names the mismatching value.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Dropping the last reference only frees the object immediately
        # under reference counting; force a collection for other GCs.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        # functools.partial reports itself with its module prefix; any
        # other factory is expected to use its bare __name__.
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual(
            '{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
            repr(f))

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        # Instance attributes must survive the round trip as well.
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
181
class PartialSubclass(functools.partial):
    """Subclass of functools.partial that adds no behaviour of its own."""
184
class TestPartialSubclass(TestPartial):
    """Run the TestPartial checks against a subclass of functools.partial."""

    # Only the factory changes; every test method is inherited.
    thetype = PartialSubclass
188
class TestPythonPartial(TestPartial):
    """Run the TestPartial checks against the pure Python PythonPartial."""

    thetype = PythonPartial

    def test_repr(self):
        # the python version hasn't a nice repr
        pass

    def test_pickle(self):
        # the python version isn't picklable
        pass
198
class TestUpdateWrapper(unittest.TestCase):
    """Tests for ``functools.update_wrapper``."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Assigned attributes must be the very same object on both sides.
        for attr in assigned:
            self.assertIs(getattr(wrapper, attr), getattr(wrapped, attr))
        # Each entry of an updated mapping must be present on the wrapper.
        for attr in updated:
            target = getattr(wrapper, attr)
            source = getattr(wrapped, attr)
            for key in source:
                self.assertIs(source[key], target[key])

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using the default arguments of
        # update_wrapper(); shared by the default-update tests.
        def f(a:'This is a new annotation'):
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper(b:'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        def wrapper():
            pass

        # Empty "assigned" and "updated" sequences: nothing is copied.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assigned = ('attr',)
        updated = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assigned = ('attr',)
        updated = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assigned, updated)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assigned, updated)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000304
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the ``functools.wraps``
    decorator."""

    def _default_update(self):
        # Build and verify a (wrapper, wrapped) pair made with wraps()
        # using its default arguments.
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        @functools.wraps(f)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        # Empty "assigned" and "updated" sequences: nothing is copied.
        @functools.wraps(f, (), ())
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the wrapper an empty dict_attr before wraps() updates it.
            func.dict_attr = {}
            return func

        assigned = ('attr',)
        updated = ('dict_attr',)

        @functools.wraps(f, assigned, updated)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assigned, updated)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
364
class TestReduce(unittest.TestCase):
    """Tests for ``functools.reduce``, reached through ``self.func``."""

    func = functools.reduce

    def test_reduce(self):
        reduce = self.func

        class Squares:
            # Lazily computed squares, indexable for 0 <= i < max.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                # Extend the cache up to and including index i.
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x*y

        self.assertEqual(reduce(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(add, Squares(10)), 285)
        self.assertEqual(reduce(add, Squares(10), 0), 285)
        self.assertEqual(reduce(add, Squares(0), 0), 0)

        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce, add, empty)
        self.assertRaises(TypeError, reduce, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, reduce, add, TestFailingIter())

        self.assertEqual(reduce(add, [], None), None)
        self.assertEqual(reduce(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        reduce = self.func

        class SequenceClass:
            # Old-style sequence: item i is i itself for 0 <= i < n.
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        # (arguments following the function, expected result)
        cases = [
            ((SequenceClass(5),), 10),
            ((SequenceClass(5), 42), 52),
            ((SequenceClass(0), 42), 42),
            ((SequenceClass(1),), 0),
            ((SequenceClass(1), 42), 42),
        ]
        for extra, expected in cases:
            self.assertEqual(reduce(add, *extra), expected)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))

        # Reducing a dict walks over its keys.
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
444
class TestCmpToKey(unittest.TestCase):
    """Tests for ``functools.cmp_to_key``."""

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))

        def cmp2(x, y):
            return int(x) - int(y)
        key = functools.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key(cmp1, None)   # too many args
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function itself propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the function's *result*
        # against zero propagates too.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError

        def cmp2(x, y):
            return BadCmp()
        # A fresh key object is required here: the one built above still
        # wraps cmp1.  "<" is used because BadCmp only defines __lt__, so
        # that is the operator whose result-versus-zero check reaches it.
        key = functools.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=functools.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        # The ABC lives in collections.abc; the collections.Hashable alias
        # no longer exists on current Python versions.
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000522
class TestTotalOrdering(unittest.TestCase):
    """Tests for the ``functools.total_ordering`` class decorator."""

    def _check_ordering(self, cls):
        # Shared assertions: every rich comparison between cls(1) and
        # cls(2) must give the natural answer.
        self.assertTrue(cls(1) < cls(2))
        self.assertTrue(cls(2) > cls(1))
        self.assertTrue(cls(1) <= cls(2))
        self.assertTrue(cls(2) >= cls(1))
        self.assertTrue(cls(2) <= cls(2))
        self.assertTrue(cls(2) >= cls(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __lt__(self, other):
                return self.value < other.value

            def __eq__(self, other):
                return self.value == other.value

        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __le__(self, other):
                return self.value <= other.value

            def __eq__(self, other):
                return self.value == other.value

        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __gt__(self, other):
                return self.value > other.value

            def __eq__(self, other):
                return self.value == other.value

        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value

            def __ge__(self, other):
                return self.value >= other.value

            def __eq__(self, other):
                return self.value == other.value

        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass

        self._check_ordering(A)

    def test_no_operations_defined(self):
        # A class with no ordering method at all is rejected.
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value

            def __eq__(self, other):
                if not isinstance(other, TestTO):
                    return False
                return self.value == other.value

            def __lt__(self, other):
                if not isinstance(other, TestTO):
                    raise TypeError
                return self.value < other.value

        # The TypeError raised by __lt__ must surface through the derived <=.
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
622
class TestLRU(unittest.TestCase):
    """Tests for the ``functools.lru_cache`` decorator."""

    def test_lru(self):
        def usage(cached):
            # Unpack all four fields so a change in arity is noticed, then
            # report the ones that vary from call to call.
            hits, misses, maxsize, currsize = cached.cache_info()
            return hits, misses, currsize

        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(f.cache_info().maxsize, 20)
        self.assertEqual(usage(f), (0, 0, 0))

        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, currsize = usage(f)
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        self.assertEqual(usage(f), (0, 0, 0))
        f(x, y)
        self.assertEqual(usage(f), (0, 1, 1))

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        self.assertEqual(usage(f), (0, 1, 1))

        # test size zero (which means "never-cache")
        calls = 0

        @functools.lru_cache(0)
        def f():
            nonlocal calls
            calls += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(calls, 5)
        self.assertEqual(usage(f), (0, 5, 0))

        # test size one
        calls = 0

        @functools.lru_cache(1)
        def f():
            nonlocal calls
            calls += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(calls, 1)
        self.assertEqual(usage(f), (4, 1, 1))

        # test size two
        calls = 0

        @functools.lru_cache(2)
        def f(x):
            nonlocal calls
            calls += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        # Expected misses: the first 7, the first 9, the first 8 and the
        # final 7; every other call is a hit.
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(calls, 4)
        self.assertEqual(usage(f), (12, 4, 2))

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)

        expected = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377,
                    610]
        self.assertEqual([fib(n) for n in range(16)], expected)
        filled = functools._CacheInfo(hits=28, misses=16, maxsize=None,
                                      currsize=16)
        self.assertEqual(fib.cache_info(), filled)
        fib.cache_clear()
        cleared = functools._CacheInfo(hits=0, misses=0, maxsize=None,
                                       currsize=0)
        self.assertEqual(fib.cache_info(), cleared)

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached
            # entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        samples = ((3, 9), (3.0, 9.0))
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            # Positional calls first, then keyword calls; in each pair the
            # first call is a miss and the second a hit.
            for arg, expected in samples:
                self.assertEqual(square(arg), expected)
                self.assertEqual(type(square(arg)), type(expected))
            for arg, expected in samples:
                self.assertEqual(square(x=arg), expected)
                self.assertEqual(type(square(x=arg)), type(expected))
            info = square.cache_info()
            self.assertEqual(info.hits, 4)
            self.assertEqual(info.misses, 4)
759
def test_main(verbose=None):
    """Run every test case class of this module.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the suite is run five more times and the
    total reference count observed after each run is printed.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKey,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    rounds = 5
    # The result list is allocated up front, before any measurement.
    snapshots = [None] * rounds
    for slot in range(rounds):
        support.run_unittest(*test_classes)
        gc.collect()
        snapshots[slot] = sys.gettotalrefcount()
    print(snapshots)


if __name__ == '__main__':
    test_main(verbose=True)