blob: 31f8f7037df13e68403cc3480acc049a592dc903 [file] [log] [blame]
Nick Coghlanc649ec52006-05-29 12:43:05 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
3from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00005
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Freezes *func* together with leading positional *args* and default
    *keywords*.  Keywords supplied at call time override the frozen ones
    for that call only; the frozen dict itself is never mutated.  Like
    functools.partial, the returned callable exposes ``func``, ``args``
    and ``keywords`` attributes.

    Wrapped in staticmethod so that assigning it as a class attribute
    (``thetype = PythonPartial``) and calling it through ``self.thetype``
    does not bind it as a method.
    """
    def newfunc(*fargs, **fkeywords):
        # Work on a fresh dict so per-call keywords never leak into the
        # frozen defaults.
        call_kw = dict(keywords)
        call_kw.update(fkeywords)
        call_args = args + fargs
        return func(*call_args, **call_kw)

    for attr_name, value in (('func', func),
                             ('args', args),
                             ('keywords', keywords)):
        setattr(newfunc, attr_name, value)
    return newfunc
17
def capture(*args, **kw):
    """Echo a call back to the caller.

    Returns an ``(args, kw)`` pair: the positional arguments as a tuple
    and the keyword arguments as a dict, exactly as received.
    """
    result = (args, kw)
    return result
21
class TestPartial(unittest.TestCase):
    """Behavioural tests for a partial() implementation.

    The implementation under test is the class attribute ``thetype``;
    subclasses override it to re-run every test against another
    implementation.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        # list() so the check does not depend on map() returning a list.
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable -- only checked when thetype is
        # an actual type; a plain-function stand-in has ordinary writable
        # function attributes.
        if not isinstance(self.thetype, type):
            return
        self.assertRaises(TypeError, setattr, p, 'func', map)
        self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
        self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        # A non-callable must be rejected either at construction time or,
        # at the latest, when the partial object is called.
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the differing values.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x // y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # NOTE(review): assumes the object is reclaimed as soon as its last
        # strong reference is dropped (CPython reference counting); other
        # implementations may need an explicit gc.collect() here.
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        # A concrete list, so it can safely be consumed by both joins below.
        data = [str(i) for i in range(10)]
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000141
class PartialSubclass(functools.partial):
    """Empty subclass, used to check that partial() behaves identically
    when subclassed."""
144
class TestPartialSubclass(TestPartial):
    """Run every TestPartial check against a trivial partial subclass."""

    thetype = PartialSubclass
148
149
class TestPythonPartial(TestPartial):
    """Run every TestPartial check against the pure-Python approximation."""

    thetype = PythonPartial
153
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must refer to the very same
        object on both functions, and every key of each mapping named in
        *updated* must be present in the wrapper's mapping and bound to the
        identical value.  Failures name the offending attribute/key.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name),
                            '%s was not assigned' % (name,))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                # Report a missing key as a test failure, not a KeyError.
                self.assertTrue(key in wrapper_attr,
                                '%s[%r] missing from wrapper' % (name, key))
                self.assertTrue(wrapped_attr[key] is wrapper_attr[key],
                                '%s[%r] was not updated' % (name, key))

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        # update_wrapper merges into an existing mapping, so the wrapper
        # needs its own (empty) dict_attr to be updated.
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Nick Coghlan676725d2006-06-08 13:54:49 +0000219
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks using the functools.wraps
    decorator form."""

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Gives the wrapper its own empty dict_attr before wraps() merges
        # f.dict_attr into it (decorators apply bottom-up).
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Counterpart of the update_wrapper test for bug #1576241; without
        # this override the inherited version would just re-test
        # update_wrapper instead of the wraps decorator.
        @functools.wraps(max)
        def wrapper():
            pass
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
267
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000268
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce()."""

    def test_reduce(self):
        class Squares:
            """Lazily generated squares 0, 1, 4, ... below index *limit*.

            __len__ reports only how many squares have been generated so far
            (0 before iteration starts), so the results below are reachable
            only if reduce() iterates via __getitem__ until IndexError
            rather than relying on len().
            """

            def __init__(self, limit):
                self.max = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]

        reduce = functools.reduce
        self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
            ['a', 'c', 'd', 'w']
        )
        self.assertEqual(reduce(lambda x, y: x*y, range(2, 8), 1), 5040)
        # 20! exceeds the 32-bit range, exercising big-integer products.
        self.assertEqual(
            reduce(lambda x, y: x*y, range(2, 21), 1),
            2432902008176640000
        )
        self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
        self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
        self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
        self.assertRaises(TypeError, reduce)
        self.assertRaises(TypeError, reduce, 42, 42)
        self.assertRaises(TypeError, reduce, 42, 42, 42)
        # func is never called when there is only one item in total ...
        self.assertEqual(reduce(42, "1"), "1")
        # ... or when the sequence is empty and the initializer is returned.
        self.assertEqual(reduce(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce, 42, (42, 42))
        # An empty sequence with no initializer has nothing to return.
        self.assertRaises(TypeError, reduce, 42, "")
        self.assertRaises(TypeError, reduce, lambda x, y: x+y, Squares(0))
307 self.assertRaises(TypeError, reduce, 42, (42, 42))
308
309
310
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000311
312def test_main(verbose=None):
313 import sys
314 test_classes = (
315 TestPartial,
316 TestPartialSubclass,
317 TestPythonPartial,
Nick Coghlan676725d2006-06-08 13:54:49 +0000318 TestUpdateWrapper,
Brett Cannon83e81842008-08-09 23:30:55 +0000319 TestWraps,
320 TestReduce,
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000321 )
322 test_support.run_unittest(*test_classes)
323
324 # verify reference counting
325 if verbose and hasattr(sys, "gettotalrefcount"):
326 import gc
327 counts = [None] * 5
328 for i in xrange(len(counts)):
329 test_support.run_unittest(*test_classes)
330 gc.collect()
331 counts[i] = sys.gettotalrefcount()
332 print counts
333
# When executed as a script, run the whole suite with verbose=True, which
# also enables test_main's reference-count report when
# sys.gettotalrefcount is available.
if __name__ == '__main__':
    test_main(verbose=True)