blob: 84e974a8db0a02bc02e39cdb4cf439bd25c1d3b0 [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00003from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00005
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Build a pure-Python stand-in for ``functools.partial``.

    The returned closure calls *func* with the frozen positional *args*
    followed by the call-time positionals. Keyword arguments are merged, and
    call-time keywords win over frozen ones. The frozen pieces are exposed as
    ``func``, ``args`` and ``keywords`` attributes, as on a real partial.

    Wrapped in ``staticmethod`` so that storing it as a class attribute
    (``thetype``) does not turn it into a bound method.
    """
    def call_with_frozen(*call_args, **call_keywords):
        # Work on a fresh dict so neither the frozen keywords nor the
        # caller's mapping are mutated.
        merged = dict(keywords)
        merged.update(call_keywords)
        positional = args + call_args
        return func(*positional, **merged)

    call_with_frozen.func = func
    call_with_frozen.args = args
    call_with_frozen.keywords = keywords
    return call_with_frozen
17
def capture(*args, **kw):
    """Return the call's arguments as a ``(positional_tuple, keyword_dict)`` pair.

    Used as the target callable in the partial tests so that exactly what a
    partial object forwards can be inspected.
    """
    positional, keyword = args, kw
    return (positional, keyword)
21
Guido van Rossumc1f779c2007-07-03 08:25:58 +000022
class TestPartial(unittest.TestCase):
    """Behavioural tests shared by every partial implementation.

    The implementation under test is the ``thetype`` class attribute.
    Subclasses of this TestCase override it to rerun the same suite against
    other partial-like callables.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        # attributes should be readable
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_attributes_unwritable(self):
        # Attributes should not be writable. This is only checked when
        # thetype is a class; a factory function may return objects whose
        # attributes are ordinary, writable ones.
        if not isinstance(self.thetype, type):
            return
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # Depending on the interpreter version, assigning to a read-only
        # attribute raises TypeError or AttributeError; accept either.
        readonly_errors = (TypeError, AttributeError)
        self.assertRaises(readonly_errors, setattr, p, 'func', map)
        self.assertRaises(readonly_errors, setattr, p, 'args', (1, 2))
        self.assertRaises(readonly_errors, setattr, p, 'keywords',
                          dict(a=1, b=2))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_dict_not_deletable(self):
        # Deleting __dict__ must be refused. This test needs its own name:
        # reusing test_attributes would shadow the attribute test.
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000143
class PartialSubclass(functools.partial):
    """Trivial subclass of functools.partial, so the partial tests also
    cover subclass instances."""
146
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial check against PartialSubclass instances."""
    thetype = PartialSubclass
150
151
class TestPythonPartial(TestPartial):
    """Rerun every TestPartial check against the pure-Python PythonPartial.

    PythonPartial is wrapped in staticmethod, so ``self.thetype`` yields the
    plain factory function rather than a bound method.
    """
    thetype = PythonPartial
155
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* mirrors *wrapped*.

        Each attribute named in *assigned* must be the very same object on
        both functions. For each attribute named in *updated*, every key of
        the wrapped mapping must map to the same object in the wrapper's
        mapping.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name),
                            'attribute %r was not assigned' % (name,))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertTrue(wrapped_attr[key] is wrapper_attr[key],
                                'key %r of %r was not updated' % (key, name))

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000221
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper scenarios through the functools.wraps decorator.

    check_wrapper and test_builtin_update are inherited from TestUpdateWrapper.
    """

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # Give the wrapper an (empty) mapping for wraps() to update.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
269
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, reached through the ``func`` attribute."""

    func = functools.reduce

    def test_reduce(self):
        class LazySquares:
            # Sequence of the first `limit` squares, computed on demand via
            # the old-style __getitem__ iteration protocol.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.limit:
                    raise IndexError
                while len(self.sofar) <= i:
                    self.sofar.append(len(self.sofar) ** 2)
                return self.sofar[i]

        plus = lambda x, y: x + y
        times = lambda x, y: x * y

        # (arguments to reduce, expected result)
        expectations = [
            ((plus, ['a', 'b', 'c'], ''), 'abc'),
            ((plus, [['a', 'c'], [], ['d', 'w']], []), ['a', 'c', 'd', 'w']),
            ((times, range(2, 8), 1), 5040),
            ((times, range(2, 21), 1), 2432902008176640000),
            ((plus, LazySquares(10)), 285),
            ((plus, LazySquares(10), 0), 285),
            ((plus, LazySquares(0), 0), 0),
            # func is never called when only one item is available
            ((42, "1"), "1"),
            ((42, "", "1"), "1"),
        ]
        for call_args, expected in expectations:
            self.assertEqual(self.func(*call_args), expected)

        # Missing arguments, non-iterable sequences and non-callable
        # functions (once two items are available) are all TypeErrors.
        for call_args in [(), (42, 42), (42, 42, 42), (42, (42, 42))]:
            self.assertRaises(TypeError, self.func, *call_args)

        class RaisingSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, RaisingSeq())

    def test_iterator_usage(self):
        # reduce() must consume arbitrary iterables: __getitem__-based
        # sequences and dicts (which iterate over their keys).
        class CountingSeq:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        self.assertEqual(self.func(add, CountingSeq(5)), 10)
        self.assertEqual(self.func(add, CountingSeq(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, CountingSeq(0))
        self.assertEqual(self.func(add, CountingSeq(0), 42), 42)
        self.assertEqual(self.func(add, CountingSeq(1)), 0)
        self.assertEqual(self.func(add, CountingSeq(1), 42), 42)

        words = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, words), "".join(words))
336
Guido van Rossumd8faa362007-04-27 19:54:29 +0000337
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000338
339
def test_main(verbose=None):
    """Run every test class in this module through support.run_unittest.

    When *verbose* is true and ``sys.gettotalrefcount`` exists, the suite is
    run five more times. After each run, gc.collect() is called and the total
    reference count is recorded. The five counts are then printed, to help
    spot reference leaks.
    """
    import sys
    test_classes = (TestPartial, TestPartialSubclass, TestPythonPartial,
                    TestUpdateWrapper, TestWraps, TestReduce)
    support.run_unittest(*test_classes)

    # verify reference counting (only where the interpreter exposes it)
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000361
# Executed directly: run the suite verbosely, which also enables the
# reference-count loop in test_main where the interpreter supports it.
if __name__ == "__main__":
    test_main(verbose=True)