blob: dedfb1ebe28e82e84d2771a85001fca888590fe8 [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00003from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00006
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Build a plain-function stand-in for functools.partial.

    Each call prepends the stored positional *args* and merges the stored
    *keywords* with the call-time keywords (call-time values win).  The
    returned function carries ``func``, ``args`` and ``keywords`` attributes
    mirroring a real partial object.  The ``staticmethod`` wrapper keeps it
    from being bound as a method when stored as a class attribute.
    """
    def newfunc(*fargs, **fkeywords):
        merged = dict(keywords, **fkeywords)
        return func(*args, *fargs, **merged)
    newfunc.__dict__.update(func=func, args=args, keywords=keywords)
    return newfunc
18
def capture(*args, **kw):
    """Echo back the positional and keyword arguments as an (args, kw) pair.

    Serves as the target callable so a test can inspect exactly what a
    partial object forwarded.
    """
    received = args, kw
    return received
22
def signature(part):
    """Snapshot a partial object's state as a comparable tuple.

    Returns ``(func, args, keywords, __dict__)`` so that, for example, a
    pickled round-trip can be compared against the original object.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000026
class TestPartial(unittest.TestCase):
    """Behavioral tests for functools.partial.

    Subclasses rebind ``thetype`` to rerun every check against another
    partial implementation.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_attributes_unwritable(self):
        # Read-only attributes are only expected from real types; a
        # function-based stand-in (not an instance of ``type``) is skipped.
        if not isinstance(self.thetype, type):
            return
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # The exception raised when assigning a read-only attribute has
        # differed across interpreter versions (this check originally
        # expected TypeError), so accept either.
        readonly = (AttributeError, TypeError)
        self.assertRaises(readonly, setattr, p, 'func', map)
        self.assertRaises(readonly, setattr, p, 'args', (1, 2))
        self.assertRaises(readonly, setattr, p, 'keywords', dict(a=1, b=2))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_dict_not_deletable(self):
        # Named distinctly so it does not shadow test_attributes.
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        # Round-trips only data pickled in this process, so loads() is safe.
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
153
class PartialSubclass(functools.partial):
    """A subclass of functools.partial that overrides nothing.

    Lets the partial test suite confirm that subclassing alone does not
    change behavior.
    """
156
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial check against a functools.partial subclass."""

    thetype = PartialSubclass
160
class TestPythonPartial(TestPartial):
    """Rerun the TestPartial checks against the pure-Python PythonPartial."""

    thetype = PythonPartial

    # The python version isn't picklable (it returns a nested function), so
    # report the inherited pickle test as skipped rather than letting an
    # empty override pass vacuously.
    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass
167
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* carries *wrapped*'s metadata.

        Every name in *assigned* must refer to the identical object on both
        functions; for every mapping attribute named in *updated*, each of
        *wrapped*'s entries must appear (by identity) in *wrapper*'s mapping.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000233
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # Give the wrapper an empty mapping for wraps() to update into.
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
281
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce over sequences, iterators and mappings."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Produces i*i lazily on indexing.  __len__ reports only how many
            # squares have been produced so far, not self.max (presumably so
            # reduce cannot rely on len()).
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n * n)
                return self.sofar[i]

        def plus(x, y):
            return x + y

        def times(x, y):
            return x * y

        reduce = self.func
        self.assertEqual(reduce(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(times, range(2, 8), 1), 5040)
        self.assertEqual(reduce(times, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(plus, Squares(10)), 285)
        self.assertEqual(reduce(plus, Squares(10), 0), 285)
        self.assertEqual(reduce(plus, Squares(0), 0), 0)
        with self.assertRaises(TypeError):
            reduce()
        with self.assertRaises(TypeError):
            reduce(42, 42)
        with self.assertRaises(TypeError):
            reduce(42, 42, 42)
        # With only one item available the (non-callable) func is never called.
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")
        with self.assertRaises(TypeError):
            reduce(42, (42, 42))

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError

        with self.assertRaises(ValueError):
            reduce(42, BadSeq())

    def test_iterator_usage(self):
        """reduce() must drive __getitem__-only sequences through iteration."""
        from operator import add

        class SequenceClass:
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        reduce = self.func
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        with self.assertRaises(TypeError):
            reduce(add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
348
Guido van Rossumd8faa362007-04-27 19:54:29 +0000349
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000350
351
def test_main(verbose=None):
    """Run this module's test classes through test.support.

    When *verbose* is true and ``sys`` exposes ``gettotalrefcount`` (only
    present on some builds, e.g. debug builds), the suite is rerun five more
    times and the total reference count after each run is printed.
    """
    import sys
    suite_classes = (TestPartial, TestPartialSubclass, TestPythonPartial,
                     TestUpdateWrapper, TestWraps, TestReduce)
    support.run_unittest(*suite_classes)

    # Reference-count check: skipped unless requested and supported.
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    # Preallocated and filled by index rather than appended to -- presumably
    # so recording a sample does not itself shift the total between samples.
    counts = [None] * 5
    for i in range(len(counts)):
        support.run_unittest(*suite_classes)
        gc.collect()
        counts[i] = sys.gettotalrefcount()
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000373
if __name__ == '__main__':
    # Direct execution runs the suite verbosely.
    test_main(True)