blob: df335e8614a8f6d5223aa03b55ad7772224c852d [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
R. David Murray378c0cf2010-02-24 01:46:21 +00002import sys
Raymond Hettinger9c323f82005-02-28 19:39:44 +00003import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00004from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00005from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00006import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00007
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Return a closure that calls *func* with the stored positional *args*
    followed by the call-time positional arguments, and with the stored
    *keywords* updated (and overridden) by any call-time keywords.  The
    closure exposes ``func``, ``args`` and ``keywords`` attributes the way
    functools.partial objects do.

    Wrapped in staticmethod so that using it as a class attribute
    (``thetype = PythonPartial``) does not bind it as a method.
    """
    def newfunc(*fargs, **fkeywords):
        merged = dict(keywords)
        merged.update(fkeywords)
        positional = args + fargs
        return func(*positional, **merged)

    newfunc.keywords = keywords
    newfunc.args = args
    newfunc.func = func
    return newfunc
19
def capture(*args, **kw):
    """Return the received positional tuple and keyword dict as a pair."""
    result = (args, kw)
    return result
23
def signature(part):
    """Return (func, args, keywords, __dict__) describing a partial object.

    Two partial objects with equal signatures have equal state, which is
    what the pickle round-trip test compares.
    """
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000027
class TestPartial(unittest.TestCase):
    """Behavioral tests for functools.partial-like callables.

    ``thetype`` is the factory under test; subclasses rebind it so the same
    checks run against a partial subclass and the pure-Python stand-in.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; only checked for real types,
        # because a closure-based thetype has ordinary function attributes
        if not isinstance(self.thetype, type):
            return
        # CPython 3 reports assignment to a read-only member with
        # AttributeError; the TypeError originally expected here is still
        # accepted.
        readonly = (AttributeError, TypeError)
        self.assertRaises(readonly, setattr, p, 'func', map)
        self.assertRaises(readonly, setattr, p, 'args', (1, 2))
        self.assertRaises(readonly, setattr, p, 'keywords', dict(a=1, b=2))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a': a, 'x': None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_dict_not_deletable(self):
        # Must not share a name with test_attributes: a second method of the
        # same name would replace the first and keep it from ever running.
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_pickle(self):
        # a round trip must preserve func/args/keywords and the instance dict
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))
154
class PartialSubclass(functools.partial):
    """Trivial functools.partial subclass; adds no behavior of its own."""
157
class TestPartialSubclass(TestPartial):
    """Run the whole TestPartial suite against PartialSubclass."""

    thetype = PartialSubclass
161
class TestPythonPartial(TestPartial):
    """Run the TestPartial suite against the pure-Python PythonPartial."""

    thetype = PythonPartial

    # The closure returned by PythonPartial is a nested function, which
    # pickle cannot serialize; report the test as skipped rather than
    # letting an empty body count as a pass.
    @unittest.skip("the python version isn't picklable")
    def test_pickle(self):
        pass
168
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Every attribute named in *assigned* must be the identical object on
        both functions, and for each mapping attribute named in *updated*,
        every key present on *wrapped* must map to the identical object on
        *wrapper*.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        """Return (wrapper, f) after update_wrapper(wrapper, f) with defaults."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # only the explicitly listed attributes are assigned/updated
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000243
class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper checks using the functools.wraps decorator.

    _default_update returns a ``(wrapper, wrapped)`` pair, the same contract
    as the base class, so the inherited test_default_update and
    test_default_update_doc (including its -O2 skip) apply unchanged.
    """

    def _default_update(self):
        """Return (wrapper, f) where wrapper is decorated with wraps(f)."""
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # only the explicitly listed attributes are assigned/updated
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            # give the wrapper an empty mapping for wraps() to update
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
300
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, called through the ``func`` attribute."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of the first `limit` squares, computed on demand and
            # iterated through __getitem__ (the class defines no __iter__).
            def __init__(self, limit):
                self.limit = limit
                self.cache = []

            def __len__(self):
                return len(self.cache)

            def __getitem__(self, index):
                if index < 0 or index >= self.limit:
                    raise IndexError
                while len(self.cache) <= index:
                    size = len(self.cache)
                    self.cache.append(size * size)
                return self.cache[index]

        reduce = self.func

        def plus(x, y):
            return x + y

        def times(x, y):
            return x * y

        self.assertEqual(reduce(plus, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce(plus, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce(times, range(2, 8), 1), 5040)
        self.assertEqual(reduce(times, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce(plus, Squares(10)), 285)
        self.assertEqual(reduce(plus, Squares(10), 0), 285)
        self.assertEqual(reduce(plus, Squares(0), 0), 0)

        # with a single item the function argument is never called, so a
        # non-callable is accepted
        self.assertEqual(reduce(42, "1"), "1")
        self.assertEqual(reduce(42, "", "1"), "1")

        for bad_args in [(), (42, 42), (42, 42, 42), (42, (42, 42))]:
            self.assertRaises(TypeError, reduce, *bad_args)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class CountUpTo:
            # yields 0 .. n-1 via the __getitem__ iteration protocol
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce = self.func
        self.assertEqual(reduce(add, CountUpTo(5)), 10)
        self.assertEqual(reduce(add, CountUpTo(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, CountUpTo(0))
        self.assertEqual(reduce(add, CountUpTo(0), 42), 42)
        self.assertEqual(reduce(add, CountUpTo(1)), 0)
        self.assertEqual(reduce(add, CountUpTo(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d))
367
Guido van Rossumd8faa362007-04-27 19:54:29 +0000368
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000369
370
def test_main(verbose=None):
    """Run all test classes in this module.

    If *verbose* is true and the interpreter provides
    sys.gettotalrefcount, the suite is run five more times.  After each
    run the garbage collector is invoked and the total reference count
    is recorded; the five counts are then printed for comparison.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000391
if __name__ == '__main__':
    # verbose=True also enables test_main's repeated-run reference-count
    # report, which only happens when sys.gettotalrefcount is available.
    test_main(verbose=True)