blob: b88e9b7981a044c730996111fcaa0e2ad5dc27bd [file] [log] [blame]
Nick Coghlanc649ec52006-05-29 12:43:05 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
3from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Jack Diederichd60c29e2009-03-31 23:46:48 +00005import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00006
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    The result prepends *args* to the positional arguments of each call and
    merges *keywords* with the call's keyword arguments; keywords given at
    call time win.  Like functools.partial, the returned callable exposes
    ``func``, ``args`` and ``keywords`` attributes.

    The staticmethod wrapper lets the factory be stored as a class attribute
    (``thetype = PythonPartial``) without ``self.thetype(...)`` binding it
    as a method.
    """
    def newfunc(*call_args, **call_kwargs):
        # Build a fresh dict per call so neither the stored keywords nor
        # the caller's mapping is mutated.
        merged = dict(keywords, **call_kwargs)
        return func(*(args + call_args), **merged)

    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
18
def capture(*args, **kw):
    """Echo back how this function was called.

    Returns the pair ``(positional_args_tuple, keyword_args_dict)``.
    """
    received = (args, kw)
    return received
22
def signature(part):
    """Return a tuple capturing a partial object's full state.

    The tuple holds ``func``, ``args``, ``keywords`` and the instance
    ``__dict__`` in that order, so comparing two signatures compares all
    four pieces at once (used to check pickle round-trips).
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
26
class TestPartial(unittest.TestCase):
    """Tests for functools.partial.

    Subclasses rerun the whole suite against another partial-like factory
    by overriding the ``thetype`` class attribute.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable -- only enforced when thetype is
        # a real class; a factory returning a plain function is exempt
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype) # need at least a func arg
        # A non-callable first argument must be rejected either when the
        # partial is built or, at the latest, when it is called.
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # separate assertions so a failure reports the actual values
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Dropping the only strong reference must invalidate the proxy; this
        # relies on the object being reclaimed immediately (as in CPython).
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        # A real list rather than map(): the data is consumed twice below,
        # which a one-shot iterator would not survive.
        data = [str(i) for i in range(10)]
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        # instance __dict__ contents must survive the round trip too
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
152
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass, used to rerun the partial tests
    against a subclass (see TestPartialSubclass)."""
155
class TestPartialSubclass(TestPartial):
    """Run every TestPartial test against a functools.partial subclass."""

    thetype = PartialSubclass
159
class TestPythonPartial(TestPartial):
    """Run the TestPartial suite against the pure-Python PythonPartial."""

    thetype = PythonPartial

    def test_pickle(self):
        # PythonPartial returns a nested function (a closure), which pickle
        # cannot serialize.  Report the test as skipped instead of letting
        # an empty body masquerade as a pass.
        self.skipTest("the pure Python version isn't picklable")
166
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s attributes.

        Every name in *assigned* must refer to the very same object on both
        functions.  For every name in *updated*, each key of the wrapped
        mapping must map to the same object in the wrapper's mapping (a
        missing key surfaces as KeyError).
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # empty assigned/updated sequences: nothing may be copied over
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        # the wrapper needs an existing mapping for dict_attr to be updated
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241 (update_wrapper applied to a builtin)
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Nick Coghlan676725d2006-06-08 13:54:49 +0000232
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator.

    check_wrapper and test_builtin_update are inherited unchanged from
    TestUpdateWrapper.
    """

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # empty assigned/updated sequences: nothing may be copied over
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        # Decorators apply bottom-up: add_dict_attr gives wrapper an empty
        # dict_attr first, so wraps() has an existing mapping to update.
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
280
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000281
Brett Cannon83e81842008-08-09 23:30:55 +0000282class TestReduce(unittest.TestCase):
283
284 def test_reduce(self):
285 class Squares:
286
287 def __init__(self, max):
288 self.max = max
289 self.sofar = []
290
291 def __len__(self): return len(self.sofar)
292
293 def __getitem__(self, i):
294 if not 0 <= i < self.max: raise IndexError
295 n = len(self.sofar)
296 while n <= i:
297 self.sofar.append(n*n)
298 n += 1
299 return self.sofar[i]
300
301 reduce = functools.reduce
302 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
303 self.assertEqual(
304 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
305 ['a','c','d','w']
306 )
307 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
308 self.assertEqual(
309 reduce(lambda x, y: x*y, range(2,21), 1L),
310 2432902008176640000L
311 )
312 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
313 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
314 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
315 self.assertRaises(TypeError, reduce)
316 self.assertRaises(TypeError, reduce, 42, 42)
317 self.assertRaises(TypeError, reduce, 42, 42, 42)
318 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
319 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
320 self.assertRaises(TypeError, reduce, 42, (42, 42))
321
322
323
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000324
def test_main(verbose=None):
    """Run every test class in this module.

    When *verbose* is true and ``sys.gettotalrefcount`` exists (present on
    CPython debug builds), the suite is rerun five times and the total
    reference count after each run is printed, to help spot leaks.
    """
    import sys
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce,
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in xrange(5):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
346
if __name__ == '__main__':
    # Run directly: verbose mode also enables test_main's refcount report
    # on interpreters that provide sys.gettotalrefcount.
    test_main(verbose=True)