blob: c0a9a3b8f53e7534205a84a7ef7ed2c18eec8592 [file] [log] [blame]
Nick Coghlanc649ec52006-05-29 12:43:05 +00001import functools
R. David Murrayf28fd242010-02-23 00:24:49 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
4from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederichd60c29e2009-03-31 23:46:48 +00006import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00007
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure-Python stand-in for functools.partial.

    Returns a closure that prepends *args* to each call's positional
    arguments and merges *keywords* with the call's keyword arguments
    (call-time keywords take precedence).  The closure exposes ``func``,
    ``args`` and ``keywords`` attributes, mirroring a real partial object.

    Wrapped in staticmethod so that assigning it as a class attribute
    (``thetype = PythonPartial``) does not turn it into a bound method.
    """
    def newfunc(*call_args, **call_keywords):
        # Copy so the stored keywords are never mutated by a call.
        merged = dict(keywords)
        merged.update(call_keywords)
        return func(*(args + call_args), **merged)
    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
19
def capture(*args, **kw):
    """Echo back the call's arguments as ``(args_tuple, kwargs_dict)``."""
    received = (args, kw)
    return received
23
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` for a partial object.

    Used to compare a partial's complete state, e.g. before and after a
    pickle round-trip.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return func, args, keywords, part.__dict__
27
class TestPartial(unittest.TestCase):
    """Behavioral checks for functools.partial.

    The implementation under test is looked up through ``self.thetype`` so
    subclasses can rerun every check against another partial implementation.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

        # the instance __dict__ must not be deletable
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        # The try covers both construction and the call, so a non-callable
        # may be rejected at either point.
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertEqual calls report the differing values on
            # failure instead of a bare "False is not true".
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # once the only strong reference is dropped the proxy must be dead
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = map(str, range(10))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            f_copy = pickle.loads(pickle.dumps(f, proto))
            self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.thetype(object)
        self.assertRaises(SystemError, f.__setstate__, BadSequence())
169
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass that adds no behavior of its own."""
172
class TestPartialSubclass(TestPartial):
    # Rerun every TestPartial check against a plain subclass of
    # functools.partial.
    thetype = PartialSubclass
176
class TestPythonPartial(TestPartial):
    # Rerun the TestPartial checks against the pure-Python PythonPartial.
    thetype = PythonPartial

    # Rebinding an inherited test to None removes it from this class: the
    # loader only collects callable test_* attributes.
    # - the python version isn't picklable (no pickling or __setstate__)
    test_pickle = None
    test_setstate_refcount = None
    # - the python version isn't a type, so its attributes are writable
    test_attributes = None
Jack Diederichd60c29e2009-03-31 23:46:48 +0000186
Nick Coghlan676725d2006-06-08 13:54:49 +0000187class TestUpdateWrapper(unittest.TestCase):
188
189 def check_wrapper(self, wrapper, wrapped,
190 assigned=functools.WRAPPER_ASSIGNMENTS,
191 updated=functools.WRAPPER_UPDATES):
192 # Check attributes were assigned
193 for name in assigned:
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000194 self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
Nick Coghlan676725d2006-06-08 13:54:49 +0000195 # Check attributes were updated
196 for name in updated:
197 wrapper_attr = getattr(wrapper, name)
198 wrapped_attr = getattr(wrapped, name)
199 for key in wrapped_attr:
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000200 self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
Nick Coghlan676725d2006-06-08 13:54:49 +0000201
R. David Murrayf28fd242010-02-23 00:24:49 +0000202 def _default_update(self):
Nick Coghlan676725d2006-06-08 13:54:49 +0000203 def f():
204 """This is a test"""
205 pass
206 f.attr = 'This is also a test'
207 def wrapper():
208 pass
209 functools.update_wrapper(wrapper, f)
R. David Murrayf28fd242010-02-23 00:24:49 +0000210 return wrapper, f
211
212 def test_default_update(self):
213 wrapper, f = self._default_update()
Nick Coghlan676725d2006-06-08 13:54:49 +0000214 self.check_wrapper(wrapper, f)
215 self.assertEqual(wrapper.__name__, 'f')
Nick Coghlan676725d2006-06-08 13:54:49 +0000216 self.assertEqual(wrapper.attr, 'This is also a test')
217
R. David Murrayf28fd242010-02-23 00:24:49 +0000218 @unittest.skipIf(sys.flags.optimize >= 2,
219 "Docstrings are omitted with -O2 and above")
220 def test_default_update_doc(self):
221 wrapper, f = self._default_update()
222 self.assertEqual(wrapper.__doc__, 'This is a test')
223
Nick Coghlan676725d2006-06-08 13:54:49 +0000224 def test_no_update(self):
225 def f():
226 """This is a test"""
227 pass
228 f.attr = 'This is also a test'
229 def wrapper():
230 pass
231 functools.update_wrapper(wrapper, f, (), ())
232 self.check_wrapper(wrapper, f, (), ())
233 self.assertEqual(wrapper.__name__, 'wrapper')
234 self.assertEqual(wrapper.__doc__, None)
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000235 self.assertFalse(hasattr(wrapper, 'attr'))
Nick Coghlan676725d2006-06-08 13:54:49 +0000236
237 def test_selective_update(self):
238 def f():
239 pass
240 f.attr = 'This is a different test'
241 f.dict_attr = dict(a=1, b=2, c=3)
242 def wrapper():
243 pass
244 wrapper.dict_attr = {}
245 assign = ('attr',)
246 update = ('dict_attr',)
247 functools.update_wrapper(wrapper, f, assign, update)
248 self.check_wrapper(wrapper, f, assign, update)
249 self.assertEqual(wrapper.__name__, 'wrapper')
250 self.assertEqual(wrapper.__doc__, None)
251 self.assertEqual(wrapper.attr, 'This is a different test')
252 self.assertEqual(wrapper.dict_attr, f.dict_attr)
253
Serhiy Storchaka72121c62013-01-27 19:45:49 +0200254 @test_support.requires_docstrings
Andrew M. Kuchling41eb7162006-10-27 16:39:10 +0000255 def test_builtin_update(self):
256 # Test for bug #1576241
257 def wrapper():
258 pass
259 functools.update_wrapper(wrapper, max)
260 self.assertEqual(wrapper.__name__, 'max')
Benjamin Peterson5c8da862009-06-30 22:57:08 +0000261 self.assertTrue(wrapper.__doc__.startswith('max('))
Nick Coghlan676725d2006-06-08 13:54:49 +0000262
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def _default_update(self):
        # Decorate a fresh wrapper with the default assignments/updates,
        # verify it, and hand it back to the caller.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f)
        return wrapper

    def test_default_update(self):
        result = self._default_update()
        self.assertEqual(result.__name__, 'f')
        self.assertEqual(result.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        result = self._default_update()
        self.assertEqual(result.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Empty assigned/updated tuples: nothing may be copied across.
        no_assign, no_update = (), ()

        @functools.wraps(f, no_assign, no_update)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, no_assign, no_update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def with_empty_dict_attr(func):
            # wraps() updates an existing mapping on the wrapper, so give
            # the wrapper one to start from.
            func.dict_attr = {}
            return func

        assign = ('attr',)
        update = ('dict_attr',)

        @functools.wraps(f, assign, update)
        @with_empty_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
319
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000320
Brett Cannon83e81842008-08-09 23:30:55 +0000321class TestReduce(unittest.TestCase):
322
323 def test_reduce(self):
324 class Squares:
325
326 def __init__(self, max):
327 self.max = max
328 self.sofar = []
329
330 def __len__(self): return len(self.sofar)
331
332 def __getitem__(self, i):
333 if not 0 <= i < self.max: raise IndexError
334 n = len(self.sofar)
335 while n <= i:
336 self.sofar.append(n*n)
337 n += 1
338 return self.sofar[i]
339
340 reduce = functools.reduce
341 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
342 self.assertEqual(
343 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
344 ['a','c','d','w']
345 )
346 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
347 self.assertEqual(
348 reduce(lambda x, y: x*y, range(2,21), 1L),
349 2432902008176640000L
350 )
351 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
352 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
353 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
354 self.assertRaises(TypeError, reduce)
355 self.assertRaises(TypeError, reduce, 42, 42)
356 self.assertRaises(TypeError, reduce, 42, 42, 42)
357 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
358 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
359 self.assertRaises(TypeError, reduce, 42, (42, 42))
360
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key."""

    def test_cmp_to_key(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        # Key wrapper objects must refuse to be hashed.  ``hash`` and its
        # argument are passed separately so assertRaises performs the call
        # itself; writing hash(k) here would raise before assertRaises ran
        # and the test would error instead of passing.
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
374
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def _check_total_order(self, cls, low, high):
        # Every comparison operator must agree that cls(low) < cls(high);
        # fresh instances are built for each comparison.
        self.assertTrue(cls(low) < cls(high))
        self.assertTrue(cls(high) > cls(low))
        self.assertTrue(cls(low) <= cls(high))
        self.assertTrue(cls(high) >= cls(low))
        self.assertTrue(cls(high) <= cls(high))
        self.assertTrue(cls(high) >= cls(high))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A, 1, 2)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A, 1, 2)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A, 1, 2)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_total_order(A, 1, 2)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(str):
            pass
        self._check_total_order(A, "a", "b")

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                return isinstance(other, TestTO) and self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, TestTO):
                    raise TypeError
                return self.value < other.value
        # the TypeError from __lt__ must propagate through the derived __le__
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
474
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000475def test_main(verbose=None):
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000476 test_classes = (
477 TestPartial,
478 TestPartialSubclass,
479 TestPythonPartial,
Nick Coghlan676725d2006-06-08 13:54:49 +0000480 TestUpdateWrapper,
Benjamin Peterson9d0eaac2010-08-23 17:45:31 +0000481 TestTotalOrdering,
Brett Cannon83e81842008-08-09 23:30:55 +0000482 TestWraps,
483 TestReduce,
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000484 )
485 test_support.run_unittest(*test_classes)
486
487 # verify reference counting
488 if verbose and hasattr(sys, "gettotalrefcount"):
489 import gc
490 counts = [None] * 5
491 for i in xrange(len(counts)):
492 test_support.run_unittest(*test_classes)
493 gc.collect()
494 counts[i] = sys.gettotalrefcount()
495 print counts
496
497if __name__ == '__main__':
498 test_main(verbose=True)