blob: e52ed91a585c8a68bab1ce08ca265dce3f3a249c [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00004import tempfile
5import unittest
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +00007from test import support
Victor Stinner45df8202010-04-28 22:31:17 +00008try:
9 import threading
10except ImportError:
11 threading = None
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012
Florent Xicluna41fe6152010-04-02 18:52:12 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Setup runs before the body, teardown after it, and the yielded
        # value is what ``as`` binds.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            yield 42
            log.append(999)

        with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
        self.assertEqual(log, [1, 42, 999])

    def test_contextmanager_finally(self):
        # The generator's ``finally`` clause still runs when the body raises,
        # and the exception keeps propagating.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            try:
                yield 42
            finally:
                log.append(999)

        with self.assertRaises(ZeroDivisionError):
            with woohoo() as value:
                self.assertEqual(log, [1])
                self.assertEqual(value, 42)
                log.append(value)
                raise ZeroDivisionError()
        self.assertEqual(log, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # When the generator lets the thrown exception escape unchanged,
        # __exit__ reports "not handled" with a falsy value instead of
        # raising the exception itself.
        @contextmanager
        def whee():
            yield

        cm = whee()
        cm.__enter__()
        handled = cm.__exit__(TypeError, TypeError("foo"), None)
        self.assertFalse(handled)

    def test_contextmanager_trap_yield_after_throw(self):
        # Yielding a second time after catching the thrown exception breaks
        # the one-yield protocol and must surface as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except BaseException:
                yield

        cm = whoo()
        cm.__enter__()
        with self.assertRaises(RuntimeError):
            cm.__exit__(TypeError, TypeError("foo"), None)

    def test_contextmanager_except(self):
        # The generator may catch the body's exception; the with-statement
        # then finishes normally.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            try:
                yield 42
            except ZeroDivisionError as err:
                log.append(err.args[0])
                self.assertEqual(log, [1, 42, 999])

        with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(log, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager around a function that carries an extra
        # attribute (foo='bar') and a docstring, so the tests below can check
        # what metadata the wrapper exposes.
        def attribs(**kw):
            def decorate(func):
                for attr_name, attr_value in kw.items():
                    setattr(func, attr_name, attr_value)
                return func
            return decorate

        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__name__, 'baz')
        self.assertEqual(factory.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        factory = self._create_contextmanager_attribs()
        self.assertEqual(factory.__doc__, "Whee!")
108
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # TODO: coverage is still thin -- only a normal exit and a single
    # exception path are exercised below.

    def test_closing(self):
        # closing() binds the wrapped object itself and calls close() on exit.
        calls = []

        class C:
            def close(self):
                calls.append(1)

        obj = C()
        self.assertEqual(calls, [])
        with closing(obj) as bound:
            self.assertEqual(obj, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() still runs when the body raises; the exception propagates.
        calls = []

        class C:
            def close(self):
                calls.append(1)

        obj = C()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(obj) as bound:
                self.assertEqual(obj, bound)
                1 / 0
        self.assertEqual(calls, [1])
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000136
class FileContextTestCase(unittest.TestCase):
    """Checks that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        """The file is open inside the with-block and closed after it,
        both on a normal exit and when the body raises."""
        import os
        # mkstemp() creates the file atomically.  The deprecated mktemp()
        # only picks a name, leaving a window in which another process could
        # create that path first.
        fd, tfn = tempfile.mkstemp()
        try:
            # Only the name is needed: the test reopens the file via open().
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 / 0
            self.assertTrue(f.closed)
        finally:
            support.unlink(tfn)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000156
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronisation primitives support ``with``."""

    def boilerPlate(self, lock, locked):
        # Shared scenario: *lock* is free before the block, held inside it
        # and free again afterwards -- also when the body raises.  *locked*
        # is a zero-argument callable reporting whether *lock* is held.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_held(sem):
        # Semaphores have no "is held" query, so probe with a non-blocking
        # acquire and give the slot back immediately if the probe succeeds.
        if not sem.acquire(False):
            return True
        sem.release()
        return False

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        mutex = threading.RLock()
        self.boilerPlate(mutex, mutex._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, lambda: self._semaphore_held(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, lambda: self._semaphore_held(sem))
204
Michael Foordb3a89842010-06-30 12:17:50 +0000205
class mycontext(ContextDecorator):
    """Recording ContextDecorator used by the TestContextDecorator tests.

    On entry it sets ``started`` and returns itself.  On exit it stores the
    arguments passed to ``__exit__`` in ``exc`` and returns ``catch``, so
    setting ``catch = True`` makes it swallow the exception.
    """

    # Class-level defaults; instances shadow them once the manager is used.
    started, exc, catch = False, None, False

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_details):
        self.exc = exc_details
        return self.catch
217 return self.catch
218
219
220class TestContextDecorator(unittest.TestCase):
221
222 def test_contextdecorator(self):
223 context = mycontext()
224 with context as result:
225 self.assertIs(result, context)
226 self.assertTrue(context.started)
227
228 self.assertEqual(context.exc, (None, None, None))
229
230
231 def test_contextdecorator_with_exception(self):
232 context = mycontext()
233
Ezio Melottied3a7d22010-12-01 02:32:32 +0000234 with self.assertRaisesRegex(NameError, 'foo'):
Michael Foordb3a89842010-06-30 12:17:50 +0000235 with context:
236 raise NameError('foo')
237 self.assertIsNotNone(context.exc)
238 self.assertIs(context.exc[0], NameError)
239
240 context = mycontext()
241 context.catch = True
242 with context:
243 raise NameError('foo')
244 self.assertIsNotNone(context.exc)
245 self.assertIs(context.exc[0], NameError)
246
247
248 def test_decorator(self):
249 context = mycontext()
250
251 @context
252 def test():
253 self.assertIsNone(context.exc)
254 self.assertTrue(context.started)
255 test()
256 self.assertEqual(context.exc, (None, None, None))
257
258
259 def test_decorator_with_exception(self):
260 context = mycontext()
261
262 @context
263 def test():
264 self.assertIsNone(context.exc)
265 self.assertTrue(context.started)
266 raise NameError('foo')
267
Ezio Melottied3a7d22010-12-01 02:32:32 +0000268 with self.assertRaisesRegex(NameError, 'foo'):
Michael Foordb3a89842010-06-30 12:17:50 +0000269 test()
270 self.assertIsNotNone(context.exc)
271 self.assertIs(context.exc[0], NameError)
272
273
274 def test_decorating_method(self):
275 context = mycontext()
276
277 class Test(object):
278
279 @context
280 def method(self, a, b, c=None):
281 self.a = a
282 self.b = b
283 self.c = c
284
285 # these tests are for argument passing when used as a decorator
286 test = Test()
287 test.method(1, 2)
288 self.assertEqual(test.a, 1)
289 self.assertEqual(test.b, 2)
290 self.assertEqual(test.c, None)
291
292 test = Test()
293 test.method('a', 'b', 'c')
294 self.assertEqual(test.a, 'a')
295 self.assertEqual(test.b, 'b')
296 self.assertEqual(test.c, 'c')
297
298 test = Test()
299 test.method(a=1, b=2)
300 self.assertEqual(test.a, 1)
301 self.assertEqual(test.b, 2)
302
303
304 def test_typo_enter(self):
305 class mycontext(ContextDecorator):
306 def __unter__(self):
307 pass
308 def __exit__(self, *exc):
309 pass
310
311 with self.assertRaises(AttributeError):
312 with mycontext():
313 pass
314
315
316 def test_typo_exit(self):
317 class mycontext(ContextDecorator):
318 def __enter__(self):
319 pass
320 def __uxit__(self, *exc):
321 pass
322
323 with self.assertRaises(AttributeError):
324 with mycontext():
325 pass
326
327
328 def test_contextdecorator_as_mixin(self):
329 class somecontext(object):
330 started = False
331 exc = None
332
333 def __enter__(self):
334 self.started = True
335 return self
336
337 def __exit__(self, *exc):
338 self.exc = exc
339
340 class mycontext(somecontext, ContextDecorator):
341 pass
342
343 context = mycontext()
344 @context
345 def test():
346 self.assertIsNone(context.exc)
347 self.assertTrue(context.started)
348 test()
349 self.assertEqual(context.exc, (None, None, None))
350
351
352 def test_contextmanager_as_decorator(self):
Michael Foordb3a89842010-06-30 12:17:50 +0000353 @contextmanager
354 def woohoo(y):
355 state.append(y)
356 yield
357 state.append(999)
358
Nick Coghlan0ded3e32011-05-05 23:49:25 +1000359 state = []
Michael Foordb3a89842010-06-30 12:17:50 +0000360 @woohoo(1)
361 def test(x):
362 self.assertEqual(state, [1])
363 state.append(x)
364 test('something')
365 self.assertEqual(state, [1, 'something', 999])
366
Nick Coghlan0ded3e32011-05-05 23:49:25 +1000367 # Issue #11647: Ensure the decorated function is 'reusable'
368 state = []
369 test('something else')
370 self.assertEqual(state, [1, 'something else', 999])
371
Michael Foordb3a89842010-06-30 12:17:50 +0000372
Nick Coghlan3267a302012-05-21 22:54:43 +1000373class TestExitStack(unittest.TestCase):
374
375 def test_no_resources(self):
376 with ExitStack():
377 pass
378
379 def test_callback(self):
380 expected = [
381 ((), {}),
382 ((1,), {}),
383 ((1,2), {}),
384 ((), dict(example=1)),
385 ((1,), dict(example=1)),
386 ((1,2), dict(example=1)),
387 ]
388 result = []
389 def _exit(*args, **kwds):
390 """Test metadata propagation"""
391 result.append((args, kwds))
392 with ExitStack() as stack:
393 for args, kwds in reversed(expected):
394 if args and kwds:
395 f = stack.callback(_exit, *args, **kwds)
396 elif args:
397 f = stack.callback(_exit, *args)
398 elif kwds:
399 f = stack.callback(_exit, **kwds)
400 else:
401 f = stack.callback(_exit)
402 self.assertIs(f, _exit)
403 for wrapper in stack._exit_callbacks:
404 self.assertIs(wrapper.__wrapped__, _exit)
405 self.assertNotEqual(wrapper.__name__, _exit.__name__)
406 self.assertIsNone(wrapper.__doc__, _exit.__doc__)
407 self.assertEqual(result, expected)
408
409 def test_push(self):
410 exc_raised = ZeroDivisionError
411 def _expect_exc(exc_type, exc, exc_tb):
412 self.assertIs(exc_type, exc_raised)
413 def _suppress_exc(*exc_details):
414 return True
415 def _expect_ok(exc_type, exc, exc_tb):
416 self.assertIsNone(exc_type)
417 self.assertIsNone(exc)
418 self.assertIsNone(exc_tb)
419 class ExitCM(object):
420 def __init__(self, check_exc):
421 self.check_exc = check_exc
422 def __enter__(self):
423 self.fail("Should not be called!")
424 def __exit__(self, *exc_details):
425 self.check_exc(*exc_details)
426 with ExitStack() as stack:
427 stack.push(_expect_ok)
428 self.assertIs(stack._exit_callbacks[-1], _expect_ok)
429 cm = ExitCM(_expect_ok)
430 stack.push(cm)
431 self.assertIs(stack._exit_callbacks[-1].__self__, cm)
432 stack.push(_suppress_exc)
433 self.assertIs(stack._exit_callbacks[-1], _suppress_exc)
434 cm = ExitCM(_expect_exc)
435 stack.push(cm)
436 self.assertIs(stack._exit_callbacks[-1].__self__, cm)
437 stack.push(_expect_exc)
438 self.assertIs(stack._exit_callbacks[-1], _expect_exc)
439 stack.push(_expect_exc)
440 self.assertIs(stack._exit_callbacks[-1], _expect_exc)
441 1/0
442
443 def test_enter_context(self):
444 class TestCM(object):
445 def __enter__(self):
446 result.append(1)
447 def __exit__(self, *exc_details):
448 result.append(3)
449
450 result = []
451 cm = TestCM()
452 with ExitStack() as stack:
453 @stack.callback # Registered first => cleaned up last
454 def _exit():
455 result.append(4)
456 self.assertIsNotNone(_exit)
457 stack.enter_context(cm)
458 self.assertIs(stack._exit_callbacks[-1].__self__, cm)
459 result.append(2)
460 self.assertEqual(result, [1, 2, 3, 4])
461
462 def test_close(self):
463 result = []
464 with ExitStack() as stack:
465 @stack.callback
466 def _exit():
467 result.append(1)
468 self.assertIsNotNone(_exit)
469 stack.close()
470 result.append(2)
471 self.assertEqual(result, [1, 2])
472
473 def test_pop_all(self):
474 result = []
475 with ExitStack() as stack:
476 @stack.callback
477 def _exit():
478 result.append(3)
479 self.assertIsNotNone(_exit)
480 new_stack = stack.pop_all()
481 result.append(1)
482 result.append(2)
483 new_stack.close()
484 self.assertEqual(result, [1, 2, 3])
485
Nick Coghlanc73e8c22012-05-31 23:49:26 +1000486 def test_exit_raise(self):
487 with self.assertRaises(ZeroDivisionError):
488 with ExitStack() as stack:
489 stack.push(lambda *exc: False)
490 1/0
491
492 def test_exit_suppress(self):
493 with ExitStack() as stack:
494 stack.push(lambda *exc: True)
495 1/0
496
497 def test_exit_exception_chaining_reference(self):
498 # Sanity check to make sure that ExitStack chaining matches
499 # actual nested with statements
500 class RaiseExc:
501 def __init__(self, exc):
502 self.exc = exc
503 def __enter__(self):
504 return self
505 def __exit__(self, *exc_details):
506 raise self.exc
507
Nick Coghlan77452fc2012-06-01 22:48:32 +1000508 class RaiseExcWithContext:
509 def __init__(self, outer, inner):
510 self.outer = outer
511 self.inner = inner
512 def __enter__(self):
513 return self
514 def __exit__(self, *exc_details):
515 try:
516 raise self.inner
517 except:
518 raise self.outer
519
Nick Coghlanc73e8c22012-05-31 23:49:26 +1000520 class SuppressExc:
521 def __enter__(self):
522 return self
523 def __exit__(self, *exc_details):
524 type(self).saved_details = exc_details
525 return True
526
527 try:
528 with RaiseExc(IndexError):
Nick Coghlan77452fc2012-06-01 22:48:32 +1000529 with RaiseExcWithContext(KeyError, AttributeError):
530 with SuppressExc():
531 with RaiseExc(ValueError):
532 1 / 0
Nick Coghlanc73e8c22012-05-31 23:49:26 +1000533 except IndexError as exc:
534 self.assertIsInstance(exc.__context__, KeyError)
535 self.assertIsInstance(exc.__context__.__context__, AttributeError)
536 # Inner exceptions were suppressed
537 self.assertIsNone(exc.__context__.__context__.__context__)
538 else:
539 self.fail("Expected IndexError, but no exception was raised")
540 # Check the inner exceptions
541 inner_exc = SuppressExc.saved_details[1]
542 self.assertIsInstance(inner_exc, ValueError)
543 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
544
545 def test_exit_exception_chaining(self):
546 # Ensure exception chaining matches the reference behaviour
547 def raise_exc(exc):
548 raise exc
549
550 saved_details = None
551 def suppress_exc(*exc_details):
552 nonlocal saved_details
553 saved_details = exc_details
554 return True
555
556 try:
557 with ExitStack() as stack:
558 stack.callback(raise_exc, IndexError)
559 stack.callback(raise_exc, KeyError)
560 stack.callback(raise_exc, AttributeError)
561 stack.push(suppress_exc)
562 stack.callback(raise_exc, ValueError)
563 1 / 0
564 except IndexError as exc:
565 self.assertIsInstance(exc.__context__, KeyError)
566 self.assertIsInstance(exc.__context__.__context__, AttributeError)
Nick Coghlan77452fc2012-06-01 22:48:32 +1000567 # Inner exceptions were suppressed
568 self.assertIsNone(exc.__context__.__context__.__context__)
Nick Coghlanc73e8c22012-05-31 23:49:26 +1000569 else:
570 self.fail("Expected IndexError, but no exception was raised")
571 # Check the inner exceptions
572 inner_exc = saved_details[1]
573 self.assertIsInstance(inner_exc, ValueError)
574 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
575
576 def test_exit_exception_chaining_suppress(self):
577 with ExitStack() as stack:
578 stack.push(lambda *exc: True)
579 stack.push(lambda *exc: 1/0)
580 stack.push(lambda *exc: {}[1])
581
Nick Coghlana5bd2a12012-06-01 00:00:38 +1000582 def test_excessive_nesting(self):
583 # The original implementation would die with RecursionError here
584 with ExitStack() as stack:
585 for i in range(10000):
586 stack.callback(int)
587
Nick Coghlan3267a302012-05-21 22:54:43 +1000588 def test_instance_bypass(self):
589 class Example(object): pass
590 cm = Example()
591 cm.__exit__ = object()
592 stack = ExitStack()
593 self.assertRaises(AttributeError, stack.enter_context, cm)
594 stack.push(cm)
595 self.assertIs(stack._exit_callbacks[-1], cm)
596
597
# regrtest.py runs this module through test_main(); without it the tests
# below would not be executed there.
def test_main():
    """Run every test case defined in this module via test.support."""
    this_module = __name__
    support.run_unittest(this_module)


if __name__ == "__main__":
    test_main()