blob: d20bafe7dcffd6c202afa4796280e855cdbd9f02 [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
Benjamin Petersonee8712c2008-05-20 21:35:26 +00003from test import support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Jack Diederiche0cbd692009-04-01 04:27:09 +00005import pickle
Raymond Hettinger9c323f82005-02-28 19:39:44 +00006
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python stand-in that approximates functools.partial().

    Returns a plain function carrying ``func``, ``args`` and ``keywords``
    attributes.  Positional arguments given at call time are appended to the
    frozen ones, and call-time keyword arguments override frozen keywords.
    """
    # NOTE: the staticmethod wrapper matters when this object is stored as a
    # class attribute (``thetype = PythonPartial``): lookup through an
    # instance then yields the bare function rather than a bound method.
    def newfunc(*fargs, **fkeywords):
        merged = dict(keywords)
        merged.update(fkeywords)
        positional = args + fargs
        return func(*positional, **merged)

    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
18
def capture(*args, **kw):
    """Return the positional and keyword arguments received, as a pair."""
    received = (args, kw)
    return received
22
def signature(part):
    """Return (func, args, keywords, __dict__) for a partial object."""
    func, args, keywords = part.func, part.args, part.keywords
    return (func, args, keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000026
Raymond Hettinger9c323f82005-02-28 19:39:44 +000027class TestPartial(unittest.TestCase):
28
Thomas Wouters4d70c3d2006-06-08 14:42:34 +000029 thetype = functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +000030
31 def test_basic_examples(self):
32 p = self.thetype(capture, 1, 2, a=10, b=20)
33 self.assertEqual(p(3, 4, b=30, c=40),
34 ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
35 p = self.thetype(map, lambda x: x*10)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000036 self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
Raymond Hettinger9c323f82005-02-28 19:39:44 +000037
38 def test_attributes(self):
39 p = self.thetype(capture, 1, 2, a=10, b=20)
40 # attributes should be readable
41 self.assertEqual(p.func, capture)
42 self.assertEqual(p.args, (1, 2))
43 self.assertEqual(p.keywords, dict(a=10, b=20))
44 # attributes should not be writable
45 if not isinstance(self.thetype, type):
46 return
47 self.assertRaises(TypeError, setattr, p, 'func', map)
48 self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
49 self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
50
51 def test_argument_checking(self):
52 self.assertRaises(TypeError, self.thetype) # need at least a func arg
53 try:
54 self.thetype(2)()
55 except TypeError:
56 pass
57 else:
58 self.fail('First arg not checked for callability')
59
60 def test_protection_of_callers_dict_argument(self):
61 # a caller's dictionary should not be altered by partial
62 def func(a=10, b=20):
63 return a
64 d = {'a':3}
65 p = self.thetype(func, a=5)
66 self.assertEqual(p(**d), 3)
67 self.assertEqual(d, {'a':3})
68 p(b=7)
69 self.assertEqual(d, {'a':3})
70
71 def test_arg_combinations(self):
72 # exercise special code paths for zero args in either partial
73 # object or the caller
74 p = self.thetype(capture)
75 self.assertEqual(p(), ((), {}))
76 self.assertEqual(p(1,2), ((1,2), {}))
77 p = self.thetype(capture, 1, 2)
78 self.assertEqual(p(), ((1,2), {}))
79 self.assertEqual(p(3,4), ((1,2,3,4), {}))
80
81 def test_kw_combinations(self):
82 # exercise special code paths for no keyword args in
83 # either the partial object or the caller
84 p = self.thetype(capture)
85 self.assertEqual(p(), ((), {}))
86 self.assertEqual(p(a=1), ((), {'a':1}))
87 p = self.thetype(capture, a=1)
88 self.assertEqual(p(), ((), {'a':1}))
89 self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
90 # keyword args in the call override those in the partial object
91 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
92
93 def test_positional(self):
94 # make sure positional arguments are captured correctly
95 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
96 p = self.thetype(capture, *args)
97 expected = args + ('x',)
98 got, empty = p('x')
Georg Brandlab91fde2009-08-13 08:51:18 +000099 self.assertTrue(expected == got and empty == {})
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000100
101 def test_keyword(self):
102 # make sure keyword arguments are captured correctly
103 for a in ['a', 0, None, 3.5]:
104 p = self.thetype(capture, a=a)
105 expected = {'a':a,'x':None}
106 empty, got = p(x=None)
Georg Brandlab91fde2009-08-13 08:51:18 +0000107 self.assertTrue(expected == got and empty == ())
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000108
109 def test_no_side_effects(self):
110 # make sure there are no side effects that affect subsequent calls
111 p = self.thetype(capture, 0, a=1)
112 args1, kw1 = p(1, b=2)
Georg Brandlab91fde2009-08-13 08:51:18 +0000113 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000114 args2, kw2 = p()
Georg Brandlab91fde2009-08-13 08:51:18 +0000115 self.assertTrue(args2 == (0,) and kw2 == {'a':1})
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000116
117 def test_error_propagation(self):
118 def f(x, y):
119 x / y
120 self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
121 self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
122 self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
123 self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
124
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +0000125 def test_attributes(self):
126 p = self.thetype(hex)
127 try:
128 del p.__dict__
129 except TypeError:
130 pass
131 else:
132 self.fail('partial object allowed __dict__ to be deleted')
133
134 def test_weakref(self):
135 f = self.thetype(int, base=16)
136 p = proxy(f)
137 self.assertEqual(f.func, p.func)
138 f = None
139 self.assertRaises(ReferenceError, getattr, p, 'func')
140
Raymond Hettinger26e512a2005-03-11 06:48:49 +0000141 def test_with_bound_and_unbound_methods(self):
Guido van Rossumc1f779c2007-07-03 08:25:58 +0000142 data = list(map(str, range(10)))
Raymond Hettinger26e512a2005-03-11 06:48:49 +0000143 join = self.thetype(str.join, '')
144 self.assertEqual(join(data), '0123456789')
145 join = self.thetype(''.join)
146 self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000147
Jack Diederiche0cbd692009-04-01 04:27:09 +0000148 def test_pickle(self):
149 f = self.thetype(signature, 'asdf', bar=True)
150 f.add_something_to__dict__ = True
151 f_copy = pickle.loads(pickle.dumps(f))
152 self.assertEqual(signature(f), signature(f_copy))
153
class PartialSubclass(functools.partial):
    """Trivial subclass of functools.partial that adds no behaviour."""
156
class TestPartialSubclass(TestPartial):
    """Re-run the TestPartial checks against a functools.partial subclass."""

    thetype = PartialSubclass
160
class TestPythonPartial(TestPartial):
    """Re-run the TestPartial checks against the pure Python approximation."""

    thetype = PythonPartial

    def test_pickle(self):
        # the python version isn't picklable, so the inherited check is
        # replaced by a no-op
        return None
167
class TestUpdateWrapper(unittest.TestCase):
    """Checks for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Each assigned attribute must be the very same object on both sides
        for attr_name in assigned:
            self.assertIs(getattr(wrapper, attr_name),
                          getattr(wrapped, attr_name))
        # Each entry of an updated mapping on the wrapped function must be
        # found, by identity, in the wrapper's mapping of the same name
        for attr_name in updated:
            target = getattr(wrapper, attr_name)
            source = getattr(wrapped, attr_name)
            for key in source:
                self.assertIs(source[key], target[key])

    def test_default_update(self):
        def f(a: 'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper(b: 'This is the prior annotation'):
            pass

        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')
        # the wrapped function's annotations replace the wrapper's own
        annotations = wrapper.__annotations__
        self.assertEqual(annotations['a'], 'This is a new annotation')
        self.assertNotIn('b', annotations)

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        def wrapper():
            pass

        # empty "assigned" and "updated" tuples: nothing may be copied over
        nothing = ()
        functools.update_wrapper(wrapper, f, nothing, nothing)
        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def wrapper():
            pass
        wrapper.dict_attr = {}

        assign, update = ('attr',), ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        # only the explicitly requested attributes were transferred
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass

        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        doc = wrapper.__doc__
        self.assertTrue(doc.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000237
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper checks through the functools.wraps decorator."""

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        @functools.wraps(f)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'

        # empty "assigned" and "updated" tuples: nothing may be copied over
        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = {'a': 1, 'b': 2, 'c': 3}

        def add_dict_attr(f):
            # give the wrapper an empty mapping for wraps() to update
            f.dict_attr = {}
            return f

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        # only the explicitly requested attributes were transferred
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
285
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce()."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of square numbers, computed on demand and cached;
            # valid indexes are 0 <= i < max.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if i < 0 or i >= self.max:
                    raise IndexError
                for n in range(len(self.sofar), i + 1):
                    self.sofar.append(n*n)
                return self.sofar[i]

        def add(x, y):
            return x + y

        def mul(x, y):
            return x*y

        reduce_ = self.func

        # ordinary folds over strings, lists and numbers
        self.assertEqual(reduce_(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(reduce_(add, [['a', 'c'], [], ['d', 'w']], []),
                         ['a', 'c', 'd', 'w'])
        self.assertEqual(reduce_(mul, range(2, 8), 1), 5040)
        self.assertEqual(reduce_(mul, range(2, 21), 1), 2432902008176640000)
        self.assertEqual(reduce_(add, Squares(10)), 285)
        self.assertEqual(reduce_(add, Squares(10), 0), 285)
        self.assertEqual(reduce_(add, Squares(0), 0), 0)

        # bad call signatures and non-iterable / non-callable arguments
        self.assertRaises(TypeError, reduce_)
        self.assertRaises(TypeError, reduce_, 42, 42)
        self.assertRaises(TypeError, reduce_, 42, 42, 42)
        # func is never called with one item
        self.assertEqual(reduce_(42, "1"), "1")
        self.assertEqual(reduce_(42, "", "1"), "1")
        self.assertRaises(TypeError, reduce_, 42, (42, 42))
        # arg 2 must not be empty sequence with no initial value
        for empty in ([], "", ()):
            self.assertRaises(TypeError, reduce_, add, empty)
        self.assertRaises(TypeError, reduce_, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError

        self.assertRaises(RuntimeError, reduce_, add, TestFailingIter())

        # with an empty sequence the initial value is returned as-is
        self.assertEqual(reduce_(add, [], None), None)
        self.assertEqual(reduce_(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError

        self.assertRaises(ValueError, reduce_, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Old-style sequence yielding 0 .. n-1 through __getitem__ only
            def __init__(self, n):
                self.n = n

            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce_ = self.func

        self.assertEqual(reduce_(add, SequenceClass(5)), 10)
        self.assertEqual(reduce_(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce_, add, SequenceClass(0))
        self.assertEqual(reduce_(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce_(add, SequenceClass(1)), 0)
        self.assertEqual(reduce_(add, SequenceClass(1), 42), 42)

        # reducing a dict folds over its keys
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce_(add, d), "".join(d))
365
Guido van Rossumd8faa362007-04-27 19:54:29 +0000366
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000367
368
def test_main(verbose=None):
    """Run every functools test class, optionally reporting total refcounts.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the suite is run five more times and the total
    reference count recorded after each run is printed as a list.
    """
    import sys
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    # the result list is allocated up front and filled in place
    counts = [None] * 5
    for slot, _ in enumerate(counts):
        support.run_unittest(*test_classes)
        gc.collect()
        counts[slot] = sys.gettotalrefcount()
    print(counts)


if __name__ == '__main__':
    test_main(verbose=True)