blob: af19ef7fec7528ac11a2a00e6945aae1c6c95cb4 [file] [log] [blame]
Nick Coghlanc649ec52006-05-29 12:43:05 +00001import functools
R. David Murrayf28fd242010-02-23 00:24:49 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
4from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederichd60c29e2009-03-31 23:46:48 +00006import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00007
# Wrapped in staticmethod because this callable is stored as the class
# attribute ``thetype`` of a TestCase below.
# NOTE(review): presumably this keeps ``self.thetype`` from being turned into
# a method on attribute lookup -- confirm before removing the decorator.
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure-Python stand-in for functools.partial().

    Returns a closure that calls ``func`` with the frozen positional
    arguments followed by the call-time ones, and with the frozen keyword
    arguments overridden by any call-time keywords.  The closure exposes
    ``func``, ``args`` and ``keywords`` attributes.
    """
    def newfunc(*fargs, **fkeywords):
        # Merge into a fresh dict so the frozen keywords are never mutated.
        merged = dict(keywords)
        merged.update(fkeywords)
        positional = args + fargs
        return func(*positional, **merged)
    newfunc.keywords = keywords
    newfunc.args = args
    newfunc.func = func
    return newfunc
19
def capture(*args, **kw):
    """Return the received arguments unchanged as an ``(args, kw)`` pair."""
    captured = (args, kw)
    return captured
23
def signature(part):
    """Return ``(func, args, keywords, __dict__)`` read from *part*."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
27
class TestPartial(unittest.TestCase):
    # Behavioural checks for partial objects.  The implementation under test
    # is always reached through ``self.thetype``.

    thetype = functools.partial

    def test_basic_examples(self):
        part = self.thetype(capture, 1, 2, a=10, b=20)
        result = part(3, 4, b=30, c=40)
        self.assertEqual(result, ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        times_ten = self.thetype(map, lambda x: x*10)
        self.assertEqual(times_ten([1,2,3,4]), [10, 20, 30, 40])

    def test_attributes(self):
        part = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(part.func, capture)
        self.assertEqual(part.args, (1, 2))
        self.assertEqual(part.keywords, dict(a=10, b=20))
        # attributes should not be writable; this part only runs when
        # thetype is an actual type
        if isinstance(self.thetype, type):
            read_only = (
                ('func', map),
                ('args', (1, 2)),
                ('keywords', dict(a=1, b=2)),
            )
            for attr_name, new_value in read_only:
                self.assertRaises(TypeError, setattr, part, attr_name, new_value)

            part = self.thetype(hex)
            deleted = True
            try:
                del part.__dict__
            except TypeError:
                deleted = False
            if deleted:
                self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        # need at least a func arg
        self.assertRaises(TypeError, self.thetype)
        rejected = False
        try:
            self.thetype(2)()
        except TypeError:
            rejected = True
        if not rejected:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        callers_kw = {'a':3}
        part = self.thetype(func, a=5)
        self.assertEqual(part(**callers_kw), 3)
        self.assertEqual(callers_kw, {'a':3})
        part(b=7)
        self.assertEqual(callers_kw, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        bare = self.thetype(capture)
        self.assertEqual(bare(), ((), {}))
        self.assertEqual(bare(1,2), ((1,2), {}))
        frozen = self.thetype(capture, 1, 2)
        self.assertEqual(frozen(), ((1,2), {}))
        self.assertEqual(frozen(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        bare = self.thetype(capture)
        self.assertEqual(bare(), ((), {}))
        self.assertEqual(bare(a=1), ((), {'a':1}))
        frozen = self.thetype(capture, a=1)
        self.assertEqual(frozen(), ((), {'a':1}))
        self.assertEqual(frozen(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(frozen(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for frozen_args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            part = self.thetype(capture, *frozen_args)
            expected = frozen_args + ('x',)
            got, empty = part('x')
            self.assertTrue(expected == got and empty == {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for value in ['a', 0, None, 3.5]:
            part = self.thetype(capture, a=value)
            expected = {'a':value,'x':None}
            empty, got = part(x=None)
            self.assertTrue(expected == got and empty == ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        part = self.thetype(capture, 0, a=1)
        first_args, first_kw = part(1, b=2)
        self.assertTrue(first_args == (0,1) and first_kw == {'a':1,'b':2})
        second_args, second_kw = part()
        self.assertTrue(second_args == (0,) and second_kw == {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        # Each pair is (partial object, extra positional args for the call);
        # every combination ends up dividing 1 by 0.
        cases = [
            (self.thetype(f, 1, 0), ()),
            (self.thetype(f, 1), (0,)),
            (self.thetype(f), (1, 0)),
            (self.thetype(f, y=0), (1,)),
        ]
        for part, call_args in cases:
            self.assertRaises(ZeroDivisionError, part, *call_args)

    def test_weakref(self):
        target = self.thetype(int, base=16)
        weak = proxy(target)
        self.assertEqual(target.func, weak.func)
        # rebinding the only name that refers to the object; the proxy is
        # expected to be dead afterwards
        target = None
        self.assertRaises(ReferenceError, getattr, weak, 'func')

    def test_with_bound_and_unbound_methods(self):
        digits = map(str, range(10))
        # unbound method with the separator frozen as first argument
        join = self.thetype(str.join, '')
        self.assertEqual(join(digits), '0123456789')
        # bound method, nothing frozen
        join = self.thetype(''.join)
        self.assertEqual(join(digits), '0123456789')

    def test_pickle(self):
        original = self.thetype(signature, 'asdf', bar=True)
        original.add_something_to__dict__ = True
        payload = pickle.dumps(original)
        restored = pickle.loads(payload)
        self.assertEqual(signature(original), signature(restored))
153
class PartialSubclass(functools.partial):
    """Subclass of functools.partial that adds nothing of its own."""
156
class TestPartialSubclass(TestPartial):
    # Re-runs every inherited TestPartial check with PartialSubclass as the
    # implementation under test.
    thetype = PartialSubclass
160
class TestPythonPartial(TestPartial):
    # Re-runs the inherited TestPartial checks with the pure-Python
    # PythonPartial as the implementation under test.
    thetype = PythonPartial

    def test_pickle(self):
        # the python version isn't picklable, so the inherited test is
        # replaced by a no-op
        pass
167
class TestUpdateWrapper(unittest.TestCase):
    # Tests for functools.update_wrapper().

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Check attributes were assigned: each one must be the very same
        # object on both functions.
        for attr_name in assigned:
            on_wrapper = getattr(wrapper, attr_name)
            on_wrapped = getattr(wrapped, attr_name)
            self.assertTrue(on_wrapper is on_wrapped)
        # Check attributes were updated: each entry of the wrapped function's
        # mapping must be found, by identity, in the wrapper's mapping.
        for attr_name in updated:
            wrapper_map = getattr(wrapper, attr_name)
            wrapped_map = getattr(wrapped, attr_name)
            for key in wrapped_map:
                self.assertTrue(wrapped_map[key] is wrapper_map[key])

    def _default_update(self):
        # Build a (wrapper, wrapped) pair with update_wrapper()'s defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        return (wrapper, f)

    def test_default_update(self):
        wrapper, wrapped = self._default_update()
        self.check_wrapper(wrapper, wrapped)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize > 1,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty tuples: nothing is assigned and nothing is updated.
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        # __name__ and __doc__ were not listed, so they stay the wrapper's own
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    @test_support.requires_docstrings
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        copied_doc = wrapper.__doc__
        self.assertTrue(copied_doc.startswith('max('))
Nick Coghlan676725d2006-06-08 13:54:49 +0000243
class TestWraps(TestUpdateWrapper):
    # Same scenarios as the base class, driven through the functools.wraps()
    # decorator.

    def _default_update(self):
        # Build a wrapper with wraps()'s defaults, check it, and return it.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper

    def test_default_update(self):
        wrapper = self._default_update()
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Empty tuples: nothing is assigned and nothing is updated.
        nothing = ()
        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}
        def add_dict_attr(func):
            # gives the decorated function an empty dict for wraps() to update
            func.dict_attr = {}
            return func
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        # __name__ and __doc__ were not listed, so they stay the wrapper's own
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
300
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000301
Brett Cannon83e81842008-08-09 23:30:55 +0000302class TestReduce(unittest.TestCase):
303
304 def test_reduce(self):
305 class Squares:
306
307 def __init__(self, max):
308 self.max = max
309 self.sofar = []
310
311 def __len__(self): return len(self.sofar)
312
313 def __getitem__(self, i):
314 if not 0 <= i < self.max: raise IndexError
315 n = len(self.sofar)
316 while n <= i:
317 self.sofar.append(n*n)
318 n += 1
319 return self.sofar[i]
320
321 reduce = functools.reduce
322 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
323 self.assertEqual(
324 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
325 ['a','c','d','w']
326 )
327 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
328 self.assertEqual(
329 reduce(lambda x, y: x*y, range(2,21), 1L),
330 2432902008176640000L
331 )
332 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
333 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
334 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
335 self.assertRaises(TypeError, reduce)
336 self.assertRaises(TypeError, reduce, 42, 42)
337 self.assertRaises(TypeError, reduce, 42, 42, 42)
338 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
339 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
340 self.assertRaises(TypeError, reduce, 42, (42, 42))
341
class TestCmpToKey(unittest.TestCase):
    # Tests for functools.cmp_to_key().

    def test_cmp_to_key(self):
        # A reversed comparison function must produce a descending sort.
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_hash(self):
        # Hashing a key object is expected to raise TypeError.
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        # Hand assertRaises the callable and its argument separately.
        # Writing hash(k) here would run the call while the arguments are
        # being evaluated, so the TypeError would escape before assertRaises
        # had a chance to catch it.
        self.assertRaises(TypeError, hash, k)
355
class TestTotalOrdering(unittest.TestCase):
    # Tests for the functools.total_ordering class decorator.

    def _check_comparisons(self, cls, low, high):
        # Run the six comparisons shared by the tests below, building fresh
        # instances of ``cls`` from ``low`` and ``high`` for each one.
        self.assertTrue(cls(low) < cls(high))
        self.assertTrue(cls(high) > cls(low))
        self.assertTrue(cls(low) <= cls(high))
        self.assertTrue(cls(high) >= cls(low))
        self.assertTrue(cls(high) <= cls(high))
        self.assertTrue(cls(high) >= cls(high))

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_comparisons(A, 1, 2)

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_comparisons(A, 1, 2)

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_comparisons(A, 1, 2)

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self._check_comparisons(A, 1, 2)

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(str):
            pass
        self._check_comparisons(A, "a", "b")

    def test_no_operations_defined(self):
        # a class that defines no ordering method must be rejected
        class A:
            pass
        self.assertRaises(ValueError, functools.total_ordering, A)

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if not isinstance(other, TestTO):
                    return False
                return self.value == other.value
            def __lt__(self, other):
                if not isinstance(other, TestTO):
                    raise TypeError
                return self.value < other.value
        # comparing against a foreign type must surface the TypeError
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
455
def test_main(verbose=None):
    """Run the functools test classes listed below.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the classes are run five more times and the
    total reference count recorded after each run is printed.
    """
    # NOTE(review): TestCmpToKey is not listed here, so this entry point
    # does not run it.
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestWraps,
        TestReduce,
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = [None] * 5
    for run in xrange(len(counts)):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts[run] = sys.gettotalrefcount()
    print(counts)
477
# Executing this file as a script runs test_main() with verbose=True.
if __name__ == "__main__":
    test_main(verbose=True)