blob: 5cc2a50e3debeb1302a63140fd49fb265d5f835c [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
R. David Murray378c0cf2010-02-24 01:46:21 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00007
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure-Python stand-in for functools.partial.

    Returns a closure that prepends *args* to the call-time positional
    arguments and merges *keywords* with the call-time keywords (call-time
    values win) before calling *func*.  The closure carries ``func``,
    ``args`` and ``keywords`` attributes, mirroring a real partial object.
    """
    # staticmethod so that ``thetype = PythonPartial`` inside a TestCase body
    # is not turned into a bound method when looked up through ``self``.
    def newfunc(*fargs, **fkeywords):
        merged = dict(keywords)
        merged.update(fkeywords)
        positional = args + fargs
        return func(*positional, **merged)
    for attr, value in (('func', func), ('args', args), ('keywords', keywords)):
        setattr(newfunc, attr, value)
    return newfunc
19
def capture(*args, **kw):
    """Return the pair ``(args, kw)`` exactly as this call received them."""
    received = (args, kw)
    return received
23
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for the partial *part*.

    Two partial objects with equal signatures behave identically, which is
    what the pickling round-trip test relies on.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
class TestPartial(unittest.TestCase):
    """Behavioural tests for functools.partial.

    Subclasses rebind ``thetype`` to rerun every check against another
    partial implementation.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; this is only checked when
        # thetype is a real type (a plain-function stand-in has no such
        # protection).
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertEqual calls report which half differed
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # exercise every pickle protocol, not only the default one
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))
153
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass that adds no behaviour of its own."""
156
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial check against a subclass of partial."""

    thetype = PartialSubclass
160
class TestPythonPartial(TestPartial):
    """Rerun the TestPartial checks against the pure-Python PythonPartial."""

    thetype = PythonPartial

    # The python version returns a nested closure, which pickle cannot
    # serialize; report the test as skipped rather than silently passing.
    @unittest.skip("the python version of partial isn't picklable")
    def test_pickle(self):
        pass
167
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every name in *assigned* must refer to the very same object on both
        functions.  For every dict attribute named in *updated*, each key of
        the wrapped dict must map to the same object in the wrapper's dict.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Build and return ``(wrapper, f)`` using default update_wrapper."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000242
class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper checks through the functools.wraps decorator.

    test_default_update and test_default_update_doc are inherited and reach
    wraps() via this class's _default_update override.  test_builtin_update
    is inherited as well and still calls update_wrapper directly.
    """

    def _default_update(self):
        """Build and return ``(wrapper, f)`` using default functools.wraps."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        # Same (wrapper, wrapped) shape as TestUpdateWrapper._default_update,
        # so the inherited tests that unpack it work unchanged.
        return wrapper, f

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
298 self.assertEqual(wrapper.dict_attr, f.dict_attr)
299
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce over sequences and iterators."""

    # Wrapped in staticmethod so that ``self.func`` never binds ``self`` as
    # the first argument, even if ``func`` is a pure-Python function rather
    # than a builtin (builtins are not descriptors; Python functions are).
    func = staticmethod(functools.reduce)

    def test_reduce(self):
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]

        self.assertEqual(self.func(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(lambda x, y: x+y, Squares(10)), 285)
        self.assertEqual(self.func(lambda x, y: x+y, Squares(10), 0), 285)
        self.assertEqual(self.func(lambda x, y: x+y, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
366
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key."""

    def test_cmp_to_key(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        # Pass the callable and its argument separately.  Writing hash(k)
        # here would evaluate it before assertRaises could catch anything:
        # a TypeError would escape, and an int result would "pass" only
        # because calling an int raises TypeError.
        self.assertRaises(TypeError, hash, k)
380
class TestTotalOrdering(unittest.TestCase):
    """Tests for functools.total_ordering."""

    def _check_total_order(self, A):
        # Exercise all four rich comparisons, including the equal-value case
        # that the derived <= and >= must report as true.
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))

    # Each class below defines __eq__ by value: derived operators may fall
    # back to ``==``, and the default identity-based equality would make
    # A(2) <= A(2) depend on object identity rather than on value.

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self._check_total_order(A)
        # int already supplies every ordering method, so the decorator must
        # leave all of them as inherited from int.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            self.assertIs(getattr(A, name), getattr(int, name))
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000450
451
def test_main(verbose=None):
    """Run every TestCase class in this module through test.support.

    When *verbose* is true and the interpreter exposes
    ``sys.gettotalrefcount`` (debug builds), the suite is run five more
    times and the total reference count after each run is printed,
    presumably so a steadily growing sequence can be spotted as a leak.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce,
        # previously missing, so regrtest never ran them:
        TestCmpToKey,
        TestTotalOrdering,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = [None] * 5
        for i in range(len(counts)):
            support.run_unittest(*test_classes)
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000472
if __name__ == '__main__':
    # Running the file directly turns on verbose mode, which also enables
    # the reference-count loop in test_main on builds that support it.
    test_main(verbose=True)