blob: 55c549f6ab244a0fdaff075aaff977ffcf6da0c5 [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
3from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00005
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Return a function that calls *func* with *args* followed by the
    call-time positional arguments, and with *keywords* overridden by any
    call-time keyword arguments.  The frozen pieces are exposed as the
    ``func``, ``args`` and ``keywords`` attributes, as on partial objects.

    Wrapped in staticmethod so that, when stored as a class attribute
    (``thetype``), it is retrieved as a plain function rather than bound.
    """
    def newfunc(*call_args, **call_keywords):
        # Merge into a fresh dict on every call so that neither the frozen
        # keywords nor the caller's dictionary is ever modified.
        merged = dict(keywords)
        merged.update(call_keywords)
        return func(*(args + call_args), **merged)
    newfunc.keywords = keywords
    newfunc.args = args
    newfunc.func = func
    return newfunc
17
def capture(*args, **kw):
    """Echo back every argument received.

    Returns an ``(args, kw)`` pair: the tuple of positional arguments and
    the dict of keyword arguments.
    """
    captured = (args, kw)
    return captured
21
class TestPartial(unittest.TestCase):
    """Behavioural tests for partial().

    The implementation under test is ``thetype``; subclasses rerun the
    whole suite against other implementations by overriding it.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        # map() may return a lazy iterator (it does in Python 3), so
        # materialize the result before comparing it with a list.
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable; this only holds for real
        # types, not for the plain-function stand-in, so skip that case
        if not isinstance(self.thetype, type):
            return
        # Depending on the interpreter version, assigning a read-only
        # member raises TypeError (older) or AttributeError (Python 3).
        readonly_errors = (TypeError, AttributeError)
        self.assertRaises(readonly_errors, setattr, p, 'func', map)
        self.assertRaises(readonly_errors, setattr, p, 'args', (1, 2))
        self.assertRaises(readonly_errors, setattr, p, 'keywords',
                          dict(a=1, b=2))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_dict_not_deletable(self):
        # This needs a name distinct from test_attributes: a second method
        # with that name would replace the first, so it would never run.
        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # Assumes the object is reclaimed as soon as its last strong
        # reference is dropped (true under CPython reference counting).
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        # A list, not a map() iterator: the data is consumed twice below.
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000142
class PartialSubclass(functools.partial):
    """Subclass of functools.partial that adds no behaviour of its own.

    It exists so the partial() tests can be rerun against a subclass.
    """
145
class TestPartialSubclass(TestPartial):
    """Rerun every TestPartial test against a subclass of partial."""
    thetype = PartialSubclass
149
150
class TestPythonPartial(TestPartial):
    """Rerun every TestPartial test against the pure-Python PythonPartial.

    PythonPartial is wrapped in staticmethod, so ``self.thetype`` yields
    the plain function rather than a bound method.
    """
    thetype = PythonPartial
154
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        """Assert that *wrapper* received *wrapped*'s metadata.

        Each attribute named in *assigned* must be the very same object on
        both functions.  For each mapping attribute named in *updated*,
        every key of the wrapped mapping must map to the same object in
        the wrapper's mapping.
        """
        # Check attributes were assigned
        for name in assigned:
            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertTrue(wrapped_attr[key] is wrapper_attr[key])

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # Empty assigned/updated: nothing should be copied over.
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        # update_wrapper merges into an existing mapping, so it must exist.
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000220
class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper() checks using the functools.wraps decorator.

    check_wrapper() and test_builtin_update() are inherited unchanged.
    """

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # Empty assigned/updated: nothing should be copied over.
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Decorators apply bottom-up, so the mapping that wraps() merges
        # into is already present when wraps() runs.
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
268
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce, accessed through the ``func`` attribute."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            """Sequence of the first ``max`` perfect squares, built lazily.

            Squares are computed on demand and cached in ``sofar``; note
            that ``len()`` reports how many have been computed so far.
            """
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                # Grow the cache up to and including index i.
                start = len(self.sofar)
                self.sofar.extend(k * k for k in range(start, i + 1))
                return self.sofar[i]

        def plus(x, y):
            return x + y

        def times(x, y):
            return x * y

        reduce_ = self.func
        # (arguments to reduce, expected result)
        cases = [
            ((plus, ['a', 'b', 'c'], ''), 'abc'),
            ((plus, [['a', 'c'], [], ['d', 'w']], []), ['a','c','d','w']),
            ((times, range(2,8), 1), 5040),
            ((times, range(2,21), 1), 2432902008176640000),
            ((plus, Squares(10)), 285),
            ((plus, Squares(10), 0), 285),
            ((plus, Squares(0), 0), 0),
            # func is never called with one item
            ((42, "1"), "1"),
            ((42, "", "1"), "1"),
        ]
        for call_args, expected in cases:
            self.assertEqual(reduce_(*call_args), expected)

        for bad_args in [(), (42, 42), (42, 42, 42), (42, (42, 42))]:
            self.assertRaises(TypeError, reduce_, *bad_args)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, reduce_, 42, BadSeq())

    def test_iterator_usage(self):
        """reduce() must drive the __getitem__-based iteration protocol."""
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if not 0 <= i < self.n:
                    raise IndexError
                return i

        from operator import add
        reduce_ = self.func
        cases = [
            ((add, SequenceClass(5)), 10),
            ((add, SequenceClass(5), 42), 52),
            ((add, SequenceClass(0), 42), 42),
            ((add, SequenceClass(1)), 0),
            ((add, SequenceClass(1), 42), 42),
        ]
        for call_args, expected in cases:
            self.assertEqual(reduce_(*call_args), expected)
        # An empty sequence without an initializer has no result.
        self.assertRaises(TypeError, reduce_, add, SequenceClass(0))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce_(add, d), "".join(d.keys()))
335
Guido van Rossumd8faa362007-04-27 19:54:29 +0000336
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000337
338
def test_main(verbose=None):
    """Run all functools test cases, optionally followed by a refcount check.

    When *verbose* is true and ``sys`` exposes ``gettotalrefcount``, the
    suite is run five more times.  The total reference count after each run
    is then printed, presumably so that steady growth can reveal a leak.
    """
    import sys
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if not (verbose and hasattr(sys, "gettotalrefcount")):
        return
    import gc
    counts = []
    for _ in range(5):
        test_support.run_unittest(*test_classes)
        gc.collect()
        counts.append(sys.gettotalrefcount())
    print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000360
if __name__ == '__main__':
    # Running the module directly turns on verbose mode, which also enables
    # test_main's reference-count check on interpreters that support it.
    test_main(verbose=True)