blob: 01d6cd29e8a62c8c6e5578001b5e75b59c06d66c [file] [log] [blame]
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00001import functools
Raymond Hettinger9c323f82005-02-28 19:39:44 +00002import unittest
3from test import test_support
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +00004from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00005
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python stand-in for partial().

    Returns a closure that calls *func* with the stored positional
    arguments followed by the call-time ones, and with the stored keyword
    arguments overridden by the call-time ones.  The closure exposes the
    stored pieces as ``func``, ``args`` and ``keywords`` attributes.
    """
    def newfunc(*call_args, **call_keywords):
        # Merge into a fresh dict so the stored keywords are never mutated.
        merged = dict(keywords)
        merged.update(call_keywords)
        positional = args + call_args
        return func(*positional, **merged)

    newfunc.func, newfunc.args, newfunc.keywords = func, args, keywords
    return newfunc
17
def capture(*args, **kw):
    """Return the received arguments as an ``(args, kw)`` pair."""
    received = (args, kw)
    return received
21
22class TestPartial(unittest.TestCase):
23
Thomas Wouters4d70c3d2006-06-08 14:42:34 +000024 thetype = functools.partial
Raymond Hettinger9c323f82005-02-28 19:39:44 +000025
26 def test_basic_examples(self):
27 p = self.thetype(capture, 1, 2, a=10, b=20)
28 self.assertEqual(p(3, 4, b=30, c=40),
29 ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
30 p = self.thetype(map, lambda x: x*10)
31 self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])
32
33 def test_attributes(self):
34 p = self.thetype(capture, 1, 2, a=10, b=20)
35 # attributes should be readable
36 self.assertEqual(p.func, capture)
37 self.assertEqual(p.args, (1, 2))
38 self.assertEqual(p.keywords, dict(a=10, b=20))
39 # attributes should not be writable
40 if not isinstance(self.thetype, type):
41 return
42 self.assertRaises(TypeError, setattr, p, 'func', map)
43 self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
44 self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
45
46 def test_argument_checking(self):
47 self.assertRaises(TypeError, self.thetype) # need at least a func arg
48 try:
49 self.thetype(2)()
50 except TypeError:
51 pass
52 else:
53 self.fail('First arg not checked for callability')
54
55 def test_protection_of_callers_dict_argument(self):
56 # a caller's dictionary should not be altered by partial
57 def func(a=10, b=20):
58 return a
59 d = {'a':3}
60 p = self.thetype(func, a=5)
61 self.assertEqual(p(**d), 3)
62 self.assertEqual(d, {'a':3})
63 p(b=7)
64 self.assertEqual(d, {'a':3})
65
66 def test_arg_combinations(self):
67 # exercise special code paths for zero args in either partial
68 # object or the caller
69 p = self.thetype(capture)
70 self.assertEqual(p(), ((), {}))
71 self.assertEqual(p(1,2), ((1,2), {}))
72 p = self.thetype(capture, 1, 2)
73 self.assertEqual(p(), ((1,2), {}))
74 self.assertEqual(p(3,4), ((1,2,3,4), {}))
75
76 def test_kw_combinations(self):
77 # exercise special code paths for no keyword args in
78 # either the partial object or the caller
79 p = self.thetype(capture)
80 self.assertEqual(p(), ((), {}))
81 self.assertEqual(p(a=1), ((), {'a':1}))
82 p = self.thetype(capture, a=1)
83 self.assertEqual(p(), ((), {'a':1}))
84 self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
85 # keyword args in the call override those in the partial object
86 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
87
88 def test_positional(self):
89 # make sure positional arguments are captured correctly
90 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
91 p = self.thetype(capture, *args)
92 expected = args + ('x',)
93 got, empty = p('x')
94 self.failUnless(expected == got and empty == {})
95
96 def test_keyword(self):
97 # make sure keyword arguments are captured correctly
98 for a in ['a', 0, None, 3.5]:
99 p = self.thetype(capture, a=a)
100 expected = {'a':a,'x':None}
101 empty, got = p(x=None)
102 self.failUnless(expected == got and empty == ())
103
104 def test_no_side_effects(self):
105 # make sure there are no side effects that affect subsequent calls
106 p = self.thetype(capture, 0, a=1)
107 args1, kw1 = p(1, b=2)
108 self.failUnless(args1 == (0,1) and kw1 == {'a':1,'b':2})
109 args2, kw2 = p()
110 self.failUnless(args2 == (0,) and kw2 == {'a':1})
111
112 def test_error_propagation(self):
113 def f(x, y):
114 x / y
115 self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
116 self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
117 self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
118 self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
119
Raymond Hettingerc8b6d1b2005-03-08 06:14:50 +0000120 def test_attributes(self):
121 p = self.thetype(hex)
122 try:
123 del p.__dict__
124 except TypeError:
125 pass
126 else:
127 self.fail('partial object allowed __dict__ to be deleted')
128
129 def test_weakref(self):
130 f = self.thetype(int, base=16)
131 p = proxy(f)
132 self.assertEqual(f.func, p.func)
133 f = None
134 self.assertRaises(ReferenceError, getattr, p, 'func')
135
Raymond Hettinger26e512a2005-03-11 06:48:49 +0000136 def test_with_bound_and_unbound_methods(self):
137 data = map(str, range(10))
138 join = self.thetype(str.join, '')
139 self.assertEqual(join(data), '0123456789')
140 join = self.thetype(''.join)
141 self.assertEqual(join(data), '0123456789')
Tim Peterseba28be2005-03-28 01:08:02 +0000142
class PartialSubclass(functools.partial):
    """Subclass of functools.partial that adds nothing of its own."""
145
class TestPartialSubclass(TestPartial):
    """Run the inherited TestPartial checks against PartialSubclass."""

    thetype = PartialSubclass
149
150
class TestPythonPartial(TestPartial):
    """Run the inherited TestPartial checks against PythonPartial."""

    thetype = PythonPartial
154
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper()."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Shared helper: every attribute named in *assigned* must be the
        # very same object on both functions, and every key of each
        # *updated* mapping on *wrapped* must map to the same object on
        # *wrapper*.
        # Check attributes were assigned
        for name in assigned:
            self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name),
                            name)
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertTrue(wrapped_attr[key] is wrapper_attr[key],
                                (name, key))

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f)
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        # empty tuples: nothing is assigned and nothing is updated
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        # the wrapper needs its own mapping for the update step to fill
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
212
213
class TestWraps(TestUpdateWrapper):
    """Repeat the TestUpdateWrapper checks through the wraps() decorator."""

    def test_default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__doc__, 'This is a test')
        self.assertEqual(wrapper.attr, 'This is also a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        # empty tuples: nothing is assigned and nothing is updated
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        # add_dict_attr is applied first, so the function already carries
        # a dict_attr mapping by the time wraps() updates it
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
261
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce()."""

    func = functools.reduce

    def test_reduce(self):
        # Sequence of squares that fills in its values lazily and only
        # supports the __len__/__getitem__ protocol.
        class Squares:
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max: raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]

        self.assertEqual(self.func(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        # plain integer literals: the "L" suffix is not accepted by every
        # interpreter version and the values compare equal without it
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(lambda x, y: x+y, Squares(10)), 285)
        self.assertEqual(self.func(lambda x, y: x+y, Squares(10), 0), 285)
        self.assertEqual(self.func(lambda x, y: x+y, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))

        # an exception raised while fetching items must propagate
        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        # reducing a dict walks its keys
        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
328
329
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000330
331
def test_main(verbose=None):
    """Run every test class of this module through test_support.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the whole set of tests is run five more
    times and the total reference count recorded after each run is
    printed.
    """
    import sys
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestWraps,
        TestReduce
    )
    test_support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = [None] * 5
        # range() and the parenthesised print behave the same here as
        # xrange() and the print statement, and parse on newer interpreters
        for i in range(len(counts)):
            test_support.run_unittest(*test_classes)
            gc.collect()
            counts[i] = sys.gettotalrefcount()
        print(counts)
353
# Running the file directly executes the tests in verbose mode.
if __name__ == "__main__":
    test_main(verbose=True)