blob: c4910a71c46955511f150d17344d13a269cb1f9b [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger003be522011-05-03 11:01:32 -07002import collections
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00004import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00005from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00006from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00007import pickle
Georg Brandl2e7346a2010-07-31 18:09:23 +00008from random import choice
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Returns a closure that calls *func* with the frozen positional *args*
    followed by the call-time positionals, and with the frozen *keywords*
    overridden by any call-time keywords.  The closure carries ``func``,
    ``args`` and ``keywords`` attributes, mirroring a real partial object.

    Wrapped in staticmethod so that storing it as a class attribute
    (``thetype = PythonPartial``) does not bind it as a method.
    """
    def call_through(*call_args, **call_keywords):
        # Work on a fresh dict so the frozen keywords are never mutated.
        merged = dict(keywords)
        merged.update(call_keywords)
        all_args = args + call_args
        return func(*all_args, **merged)
    call_through.func, call_through.args, call_through.keywords = (
        func, args, keywords)
    return call_through

def capture(*args, **kw):
    """Echo back the call's arguments as an ``(args, kwargs)`` pair.

    The positional arguments come back as a tuple and the keyword
    arguments as a dict, exactly as received.
    """
    received = (args, kw)
    return received

def signature(part):
    """Describe partial object *part* as ``(func, args, keywords, __dict__)``.

    Comparing these tuples is how the pickle round-trip test checks that a
    copy carries the same callable, frozen arguments and instance dict.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)

class TestPartial(unittest.TestCase):
    """Behavioural tests for partial-like callables.

    Subclasses rebind ``thetype`` to rerun these checks against other
    implementations (a partial subclass and a pure Python approximation).
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # The remaining checks only make sense for real types; a plain
        # function (the pure Python approximation) has writable attributes.
        if not isinstance(self.thetype, type):
            return
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        with self.assertRaises(TypeError,
                               msg='partial object allowed __dict__ to be deleted'):
            del p.__dict__

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        # A non-callable first argument must be rejected by the time the
        # partial is called; construction and call are both inside the check.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.thetype(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # The repr lists keywords in the partial's own dict order, which is
        # not guaranteed to match the iteration order of ``kwargs`` on older
        # Pythons (e.g. under hash randomization), so accept either order.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {})'.format(name, capture, kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

        f = self.thetype(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      ['{}({!r}, {}, {})'.format(name, capture, args_repr,
                                                 kwargs_repr)
                       for kwargs_repr in kwargs_reprs])

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # Round-trip through every available protocol, not only the default.
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy),
                             'pickle protocol %d' % proto)

class PartialSubclass(functools.partial):
    # Intentionally empty: a trivial subclass lets TestPartialSubclass rerun
    # the TestPartial checks against a subclass of functools.partial.
    ...

class TestPartialSubclass(TestPartial):
    # Rerun every inherited TestPartial check with a partial subclass.
    thetype = PartialSubclass

class TestPythonPartial(TestPartial):
    # Rerun the TestPartial checks against the pure Python approximation.

    thetype = PythonPartial

    # These inherited tests do not apply to the Python approximation.  Mark
    # them as skipped so test reports show that they were not run, instead
    # of counting no-op overrides as passes.
    @unittest.skip("the python version hasn't a nice repr")
    def test_repr(self):
        pass

    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass

class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper (TestWraps reuses them for wraps)."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Each attribute named in *assigned* must be the very same object on
        both functions, and every key of each dict attribute named in
        *updated* must map to the same object in the wrapper's dict.  The
        offending attribute name or key is used as the failure message.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name), name)
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key], key)

    def _default_update(self):
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})

class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        # Like update_wrapper() (asserted in the base class test as well),
        # wraps() must record the wrapped function on the wrapper.
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    # Same predicate as the base class: skip when -OO strips docstrings.
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

class TestReduce(unittest.TestCase):
    # The reduce implementation under test, reached through self.func.
    func = functools.reduce

    def test_reduce(self):
        class Squares:
            """Lazily computed sequence of the squares of 0 .. max-1."""
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                while len(self.sofar) <= i:
                    k = len(self.sofar)
                    self.sofar.append(k * k)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x * y

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError

        # (function, remaining positional arguments, expected result)
        successes = [
            (add, (['a', 'b', 'c'], ''), 'abc'),
            (add, ([['a', 'c'], [], ['d', 'w']], []), ['a', 'c', 'd', 'w']),
            (mul, (range(2, 8), 1), 5040),
            (mul, (range(2, 21), 1), 2432902008176640000),
            (add, (Squares(10),), 285),
            (add, (Squares(10), 0), 285),
            (add, (Squares(0), 0), 0),
            # func is never called with one item
            (42, ("1",), "1"),
            (42, ("", "1"), "1"),
            # an empty iterable simply yields the initial value
            (add, ([], None), None),
            (add, ([], 42), 42),
        ]
        for function, rest, expected in successes:
            self.assertEqual(self.func(function, *rest), expected)

        # (expected exception, positional arguments passed to reduce)
        failures = [
            (TypeError, ()),
            (TypeError, (42, 42)),
            (TypeError, (42, 42, 42)),
            (TypeError, (42, (42, 42))),
            # arg 2 must not be empty sequence with no initial value
            (TypeError, (add, [])),
            (TypeError, (add, "")),
            (TypeError, (add, ())),
            (TypeError, (add, object())),
            (RuntimeError, (add, TestFailingIter())),
            (ValueError, (42, BadSeq())),
        ]
        for exc_type, call_args in failures:
            self.assertRaises(exc_type, self.func, *call_args)

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        from operator import add

        class SequenceClass:
            """Sequence-protocol object yielding 0 .. n-1 via __getitem__."""
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        for size, expected in ((5, 10), (1, 0)):
            self.assertEqual(self.func(add, SequenceClass(size)), expected)
        for size, expected in ((5, 52), (0, 42), (1, 42)):
            self.assertEqual(self.func(add, SequenceClass(size), 42), expected)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d))

class TestCmpToKey(unittest.TestCase):

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        def cmp2(x, y):
            return int(x) - int(y)
        key = functools.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key(cmp1, None)   # too many args
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised inside the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* propagates too.  cmp_to_key keys compare that result with
        # 0 using the same operator, so the key must be rebuilt from the new
        # function and the comparison must be ``<`` for BadCmp.__lt__ to run.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp2(x, y):
            return BadCmp()
        key = functools.cmp_to_key(cmp2)
        with self.assertRaises(ZeroDivisionError):
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=functools.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        # The ABCs live in collections.abc; the old collections.Hashable
        # alias was removed in Python 3.10.
        try:
            from collections.abc import Hashable
        except ImportError:  # interpreters predating collections.abc
            from collections import Hashable
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        self.assertNotIsInstance(k, Hashable)

class TestTotalOrdering(unittest.TestCase):

    def _check_ordering(self, A):
        """Exercise all four ordering operators on instances of *A*.

        *A* is built from a single value and A(1) must order before A(2).
        Both true and false outcomes are asserted, so a derived method that
        always answers True (or ignores equality) is caught.
        """
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))
        self.assertFalse(A(1) > A(2))
        self.assertFalse(A(2) <= A(1))
        self.assertFalse(A(1) >= A(2))
        self.assertFalse(A(2) < A(2))
        self.assertFalse(A(2) > A(2))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_ordering(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_ordering(A)

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()

class TestLRU(unittest.TestCase):

    def _stats(self, cached):
        """Return ``(hits, misses, currsize)`` from *cached*.cache_info()."""
        info = cached.cache_info()
        return info.hits, info.misses, info.currsize

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        self.assertEqual(f.cache_info().maxsize, 20)
        self.assertEqual(self._stats(f), (0, 0, 0))

        domain = range(5)
        for _ in range(1000):
            x, y = choice(domain), choice(domain)
            self.assertEqual(f(x, y), orig(x, y))
        hits, misses, currsize = self._stats(f)
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        # Clearing resets the statistics as well as the stored entries.
        f.cache_clear()
        self.assertEqual(self._stats(f), (0, 0, 0))
        f(x, y)
        self.assertEqual(self._stats(f), (0, 1, 1))

        # Calling the undecorated function leaves the cache untouched.
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        self.assertEqual(self._stats(f), (0, 1, 1))

        # Size zero means "never cache": every call reaches the function.
        zero_calls = []
        @functools.lru_cache(0)
        def f():
            zero_calls.append(None)
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(len(zero_calls), 5)
        self.assertEqual(self._stats(f), (0, 5, 0))

        # Size one: only the first call reaches the function.
        one_calls = []
        @functools.lru_cache(1)
        def f():
            one_calls.append(None)
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        for _ in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(len(one_calls), 1)
        self.assertEqual(self._stats(f), (4, 1, 1))

        # Size two.  The misses are the first 7 and 9, the first 8 (which
        # evicts 7) and the final 7.
        two_calls = []
        @functools.lru_cache(2)
        def f(x):
            two_calls.append(x)
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            self.assertEqual(f(x), x*10)
        self.assertEqual(len(two_calls), 4)
        self.assertEqual(self._stats(f), (12, 4, 2))

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            return n if n < 2 else fib(n-1) + fib(n-2)
        expected_values = [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233,
                           377, 610]
        self.assertEqual([fib(n) for n in range(16)], expected_values)
        after_run = functools._CacheInfo(hits=28, misses=16, maxsize=None,
                                         currsize=16)
        self.assertEqual(fib.cache_info(), after_run)
        fib.cache_clear()
        after_clear = functools._CacheInfo(hits=0, misses=0, maxsize=None,
                                           currsize=0)
        self.assertEqual(fib.cache_info(), after_clear)

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize)
            def lookup(i):
                return 'abc'[i]
            self.assertEqual(lookup(0), 'a')
            with self.assertRaises(IndexError) as cm:
                lookup(15)
            self.assertIsNone(cm.exception.__context__)
            # A failed call must not leave a cached entry behind.
            with self.assertRaises(IndexError):
                lookup(15)

    def test_lru_with_types(self):
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            # Each spelling is called twice: the first call misses and the
            # second (used for the type check) hits, giving 4 hits/4 misses.
            calls = (((3,), {}, 9),
                     ((3.0,), {}, 9.0),
                     ((), {'x': 3}, 9),
                     ((), {'x': 3.0}, 9.0))
            for pos, kw, expected in calls:
                self.assertEqual(square(*pos, **kw), expected)
                self.assertEqual(type(square(*pos, **kw)), type(expected))
            info = square.cache_info()
            self.assertEqual((info.hits, info.misses), (4, 4))

def test_main(verbose=None):
    """Run every test case in this module through test.support.

    When *verbose* is true and sys.gettotalrefcount is available, the suite
    is run five more times.  The total reference count after each run is
    printed, so the counts can be compared across runs to spot leaks.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKey,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)

if __name__ == '__main__':
    # Run directly: execute the suite verbosely (including the refcount
    # loop when the interpreter supports it).
    test_main(True)