blob: d1ce2a9c8efaaf854532459b963e5ec283769392 [file] [log] [blame]
import collections
import collections.abc
import functools
import pickle
import sys
import unittest
from random import choice
from test import support
from weakref import proxy
Raymond Hettinger9c323f82005-02-28 19:39:44 +00009
# Wrapped in staticmethod so that storing it as a class attribute
# (``thetype = PythonPartial``) and calling ``self.thetype(...)`` does not
# pass the test instance as the first argument.
@staticmethod
def PythonPartial(func, *args, **keywords):
    """Pure Python approximation of partial().

    Return a callable that invokes *func* with *args* followed by the
    call-time positional arguments, and with *keywords* overridden by the
    call-time keyword arguments.  The returned function exposes ``func``,
    ``args`` and ``keywords`` attributes.
    """
    def newfunc(*fargs, **fkeywords):
        # Copy first so the stored keywords are never mutated by a call.
        newkeywords = keywords.copy()
        newkeywords.update(fkeywords)
        return func(*(args + fargs), **newkeywords)
    newfunc.func = func
    newfunc.args = args
    newfunc.keywords = keywords
    return newfunc
21
def capture(*args, **kw):
    """Capture all positional and keyword arguments.

    Return them unchanged as an ``(args, kw)`` pair.
    """
    return args, kw
25
def signature(part):
    """Return the signature of a partial object.

    The result is the tuple ``(func, args, keywords, __dict__)`` read from
    *part*.
    """
    return (part.func, part.args, part.keywords, part.__dict__)
Guido van Rossumc1f779c2007-07-03 08:25:58 +000029
class TestPartial(unittest.TestCase):
    """Tests for functools.partial.

    Every test builds its partial objects through ``self.thetype`` so that
    subclasses can rebind ``thetype`` and rerun the same checks against a
    different implementation.
    """

    thetype = functools.partial

    def test_basic_examples(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.thetype(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.thetype(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))
        # attributes should not be writable
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.thetype(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.thetype)     # need at least a func arg
        try:
            self.thetype(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.thetype(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.thetype(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.thetype(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.thetype(capture, a=1)
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.thetype(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            # Separate assertions so a failure reports the offending values.
            self.assertEqual(got, expected)
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.thetype(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertEqual(got, expected)
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.thetype(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0,1))
        self.assertEqual(kw1, {'a':1,'b':2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a':1})

    def test_error_propagation(self):
        # an exception raised by the wrapped function must reach the caller,
        # however the arguments are split between partial and call
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)

    def test_weakref(self):
        f = self.thetype(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        # drop the only strong reference; the proxy must then be dead
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.thetype(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.thetype(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        kwargs_repr = ', '.join("%s=%r" % (k, v) for k, v in kwargs.items())
        if self.thetype is functools.partial:
            name = 'functools.partial'
        else:
            name = self.thetype.__name__

        f = self.thetype(capture)
        self.assertEqual('{}({!r})'.format(name, capture),
                         repr(f))

        f = self.thetype(capture, *args)
        self.assertEqual('{}({!r}, {})'.format(name, capture, args_repr),
                         repr(f))

        f = self.thetype(capture, **kwargs)
        self.assertEqual('{}({!r}, {})'.format(name, capture, kwargs_repr),
                         repr(f))

        f = self.thetype(capture, *args, **kwargs)
        self.assertEqual('{}({!r}, {}, {})'.format(name, capture, args_repr, kwargs_repr),
                         repr(f))

    def test_pickle(self):
        # a pickle round trip must preserve func, args, keywords and __dict__
        f = self.thetype(signature, 'asdf', bar=True)
        f.add_something_to__dict__ = True
        f_copy = pickle.loads(pickle.dumps(f))
        self.assertEqual(signature(f), signature(f_copy))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.thetype(object)
        self.assertRaisesRegex(SystemError,
                "new style getargs format but argument is not a tuple",
                f.__setstate__, BadSequence())
198
class PartialSubclass(functools.partial):
    """Trivial subclass of functools.partial that adds no behaviour."""
201
class TestPartialSubclass(TestPartial):
    """Rerun the TestPartial checks with PartialSubclass as the type."""

    thetype = PartialSubclass
205
class TestPythonPartial(TestPartial):
    """Rerun the TestPartial checks against the pure Python PythonPartial.

    Inherited tests that do not apply are disabled by rebinding their
    names to None.
    """

    thetype = PythonPartial

    # the python version has no nice repr
    test_repr = None

    # the python version isn't picklable
    test_pickle = test_setstate_refcount = None

    # the python version isn't a type
    test_attributes = None
Jack Diederiche0cbd692009-04-01 04:27:09 +0000218
class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Helper: verify that every attribute named in *assigned* is the very
        # same object on both functions, and that every key of each mapping
        # attribute named in *updated* was carried over to the wrapper.
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                self.assertIs(wrapped_attr[key], wrapper_attr[key])

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using update_wrapper's defaults.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        # only the explicitly listed attributes are assigned / updated
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})
Thomas Wouters73e5a5b2006-06-08 15:35:45 +0000324
class TestWraps(TestUpdateWrapper):
    """Tests for the functools.wraps decorator.

    Overrides the TestUpdateWrapper tests that build their wrapper by
    calling update_wrapper directly; the remaining tests are inherited.
    """

    def _default_update(self):
        # Build a (wrapper, wrapped) pair using wraps() with its defaults.
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f)
        def wrapper():
            pass
        self.check_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        # empty assigned/updated tuples must leave the wrapper untouched
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        # Applied before wraps() so the wrapper already owns a dict_attr
        # mapping for wraps() to update.
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
384
class TestReduce(unittest.TestCase):
    """Tests for functools.reduce."""

    func = functools.reduce

    def test_reduce(self):
        class Squares:
            # Sequence of squares computed lazily and remembered in ``sofar``.
            def __init__(self, limit):
                self.limit = limit
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.limit:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            # Iterable only through the __getitem__/IndexError protocol.
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))
463 self.assertEqual(self.func(add, d), "".join(d.keys()))
464
class TestCmpToKey(unittest.TestCase):
    """Tests for functools.cmp_to_key."""

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        def cmp2(x, y):
            return int(x) - int(y)
        key = functools.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = functools.cmp_to_key(cmp1, None)   # too many args
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # an exception raised by the comparison function itself propagates
        def cmp1(x, y):
            raise ZeroDivisionError
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # an exception raised while comparing the cmp result with zero
        # propagates too
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp1(x, y):
            return BadCmp()
        # Rebuild the key type: without this the previous ``key`` (still
        # bound to the raising cmp1) would be tested a second time and
        # BadCmp would never be exercised.
        key = functools.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            # ``<`` is required here: it evaluates ``BadCmp() < 0`` and so
            # reaches BadCmp.__lt__, which ``>`` would not.
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = functools.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=functools.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = functools.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        # The ABCs live in collections.abc; the ``collections.Hashable``
        # alias was removed in Python 3.10.
        self.assertNotIsInstance(k, collections.abc.Hashable)
Raymond Hettingerc50846a2010-04-05 18:56:31 +0000542
class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator.

    Each of the first four tests defines __eq__ plus a single ordering
    method and checks that all four ordering operators then work.
    """

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertLess(A(1), A(2))
        self.assertGreater(A(2), A(1))
        self.assertLessEqual(A(1), A(2))
        self.assertGreaterEqual(A(2), A(1))
        self.assertLessEqual(A(2), A(2))
        self.assertGreaterEqual(A(2), A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertLess(A(1), A(2))
        self.assertGreater(A(2), A(1))
        self.assertLessEqual(A(1), A(2))
        self.assertGreaterEqual(A(2), A(1))
        self.assertLessEqual(A(2), A(2))
        self.assertGreaterEqual(A(2), A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertLess(A(1), A(2))
        self.assertGreater(A(2), A(1))
        self.assertLessEqual(A(1), A(2))
        self.assertGreaterEqual(A(2), A(1))
        self.assertLessEqual(A(2), A(2))
        self.assertGreaterEqual(A(2), A(2))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertLess(A(1), A(2))
        self.assertGreater(A(2), A(1))
        self.assertLessEqual(A(1), A(2))
        self.assertGreaterEqual(A(2), A(1))
        self.assertLessEqual(A(2), A(2))
        self.assertGreaterEqual(A(2), A(2))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertLess(A(1), A(2))
        self.assertGreater(A(2), A(1))
        self.assertLessEqual(A(1), A(2))
        self.assertGreaterEqual(A(2), A(1))
        self.assertLessEqual(A(2), A(2))
        self.assertGreaterEqual(A(2), A(2))

    def test_no_operations_defined(self):
        # decorating a class with no ordering method at all is an error
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_bug_10042(self):
        @functools.total_ordering
        class TestTO:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, TestTO):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, TestTO):
                    return self.value < other.value
                raise TypeError
        # the TypeError raised by __lt__ must surface through the
        # generated __le__
        with self.assertRaises(TypeError):
            TestTO(8) <= ()
642
class TestLRU(unittest.TestCase):
    """Tests for functools.lru_cache."""

    def test_lru(self):
        def orig(x, y):
            return 3*x+y
        f = functools.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        # 25 possible argument pairs against a cache of 20 entries
        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertGreater(hits, misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # Test bypassing the cache
        self.assertIs(f.__wrapped__, orig)
        f.__wrapped__(x, y)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size zero (which means "never-cache")
        @functools.lru_cache(0)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 0)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 5)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 5)
        self.assertEqual(currsize, 0)

        # test size one
        @functools.lru_cache(1)
        def f():
            nonlocal f_cnt
            f_cnt += 1
            return 20
        self.assertEqual(f.cache_info().maxsize, 1)
        f_cnt = 0
        for i in range(5):
            self.assertEqual(f(), 20)
        self.assertEqual(f_cnt, 1)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 4)
        self.assertEqual(misses, 1)
        self.assertEqual(currsize, 1)

        # test size two
        @functools.lru_cache(2)
        def f(x):
            nonlocal f_cnt
            f_cnt += 1
            return x*10
        self.assertEqual(f.cache_info().maxsize, 2)
        f_cnt = 0
        for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
            #    *  *              *                          *
            # (the starred calls are the four expected cache misses)
            self.assertEqual(f(x), x*10)
        self.assertEqual(f_cnt, 4)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 12)
        self.assertEqual(misses, 4)
        self.assertEqual(currsize, 2)

    def test_lru_with_maxsize_none(self):
        @functools.lru_cache(maxsize=None)
        def fib(n):
            if n < 2:
                return n
            return fib(n-1) + fib(n-2)
        self.assertEqual([fib(n) for n in range(16)],
            [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
        fib.cache_clear()
        self.assertEqual(fib.cache_info(),
            functools._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))

    def test_lru_with_exceptions(self):
        # Verify that user_function exceptions get passed through without
        # creating a hard-to-read chained exception.
        # http://bugs.python.org/issue13177
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize)
            def func(i):
                return 'abc'[i]
            self.assertEqual(func(0), 'a')
            with self.assertRaises(IndexError) as cm:
                func(15)
            self.assertIsNone(cm.exception.__context__)
            # Verify that the previous exception did not result in a cached entry
            with self.assertRaises(IndexError):
                func(15)

    def test_lru_with_types(self):
        # with typed=True, 3 and 3.0 must be cached as distinct keys
        for maxsize in (None, 100):
            @functools.lru_cache(maxsize=maxsize, typed=True)
            def square(x):
                return x * x
            self.assertEqual(square(3), 9)
            self.assertEqual(type(square(3)), type(9))
            self.assertEqual(square(3.0), 9.0)
            self.assertEqual(type(square(3.0)), type(9.0))
            self.assertEqual(square(x=3), 9)
            self.assertEqual(type(square(x=3)), type(9))
            self.assertEqual(square(x=3.0), 9.0)
            self.assertEqual(type(square(x=3.0)), type(9.0))
            self.assertEqual(square.cache_info().hits, 4)
            self.assertEqual(square.cache_info().misses, 4)

    def test_need_for_rlock(self):
        # This will deadlock on an LRU cache that uses a regular lock

        @functools.lru_cache(maxsize=10)
        def test_func(x):
            'Used to demonstrate a reentrant lru_cache call within a single thread'
            return x

        class DoubleEq:
            'Demonstrate a reentrant lru_cache call within a single thread'
            def __init__(self, x):
                self.x = x
            def __hash__(self):
                return self.x
            def __eq__(self, other):
                # Re-enter the cached function from inside a key comparison.
                if self.x == 2:
                    test_func(DoubleEq(1))
                return self.x == other.x

        test_func(DoubleEq(1))                      # Load the cache
        test_func(DoubleEq(2))                      # Load the cache
        self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
                         DoubleEq(2))               # Verify the correct return value
803
804
def test_main(verbose=None):
    """Run every test class in this module.

    When *verbose* is true and the interpreter provides
    ``sys.gettotalrefcount``, the whole suite is run five more times and the
    total reference count recorded after each run is printed.
    """
    test_classes = (
        TestPartial,
        TestPartialSubclass,
        TestPythonPartial,
        TestUpdateWrapper,
        TestTotalOrdering,
        TestCmpToKey,
        TestWraps,
        TestReduce,
        TestLRU,
    )
    support.run_unittest(*test_classes)

    # verify reference counting
    if verbose and hasattr(sys, "gettotalrefcount"):
        import gc
        counts = []
        for _ in range(5):
            support.run_unittest(*test_classes)
            # Collect garbage first so only live references are counted.
            gc.collect()
            counts.append(sys.gettotalrefcount())
        print(counts)
Raymond Hettinger9c323f82005-02-28 19:39:44 +0000828
# Run the whole suite verbosely when this file is executed as a script.
if __name__ == '__main__':
    test_main(verbose=True)