| """Unit tests for contextlib.py, and other context managers.""" | 
 |  | 
 | import sys | 
 | import tempfile | 
 | import unittest | 
 | from contextlib import *  # Tests __all__ | 
 | from test import support | 
 | try: | 
 |     import threading | 
 | except ImportError: | 
 |     threading = None | 
 |  | 
 |  | 
 | class ContextManagerTestCase(unittest.TestCase): | 
 |  | 
 |     def test_contextmanager_plain(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             yield 42 | 
 |             state.append(999) | 
 |         with woohoo() as x: | 
 |             self.assertEqual(state, [1]) | 
 |             self.assertEqual(x, 42) | 
 |             state.append(x) | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 |     def test_contextmanager_finally(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             try: | 
 |                 yield 42 | 
 |             finally: | 
 |                 state.append(999) | 
 |         with self.assertRaises(ZeroDivisionError): | 
 |             with woohoo() as x: | 
 |                 self.assertEqual(state, [1]) | 
 |                 self.assertEqual(x, 42) | 
 |                 state.append(x) | 
 |                 raise ZeroDivisionError() | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 |     def test_contextmanager_no_reraise(self): | 
 |         @contextmanager | 
 |         def whee(): | 
 |             yield | 
 |         ctx = whee() | 
 |         ctx.__enter__() | 
 |         # Calling __exit__ should not result in an exception | 
 |         self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None)) | 
 |  | 
 |     def test_contextmanager_trap_yield_after_throw(self): | 
 |         @contextmanager | 
 |         def whoo(): | 
 |             try: | 
 |                 yield | 
 |             except: | 
 |                 yield | 
 |         ctx = whoo() | 
 |         ctx.__enter__() | 
 |         self.assertRaises( | 
 |             RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None | 
 |         ) | 
 |  | 
 |     def test_contextmanager_except(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             try: | 
 |                 yield 42 | 
 |             except ZeroDivisionError as e: | 
 |                 state.append(e.args[0]) | 
 |                 self.assertEqual(state, [1, 42, 999]) | 
 |         with woohoo() as x: | 
 |             self.assertEqual(state, [1]) | 
 |             self.assertEqual(x, 42) | 
 |             state.append(x) | 
 |             raise ZeroDivisionError(999) | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 |     def _create_contextmanager_attribs(self): | 
 |         def attribs(**kw): | 
 |             def decorate(func): | 
 |                 for k,v in kw.items(): | 
 |                     setattr(func,k,v) | 
 |                 return func | 
 |             return decorate | 
 |         @contextmanager | 
 |         @attribs(foo='bar') | 
 |         def baz(spam): | 
 |             """Whee!""" | 
 |         return baz | 
 |  | 
 |     def test_contextmanager_attribs(self): | 
 |         baz = self._create_contextmanager_attribs() | 
 |         self.assertEqual(baz.__name__,'baz') | 
 |         self.assertEqual(baz.foo, 'bar') | 
 |  | 
 |     @unittest.skipIf(sys.flags.optimize >= 2, | 
 |                      "Docstrings are omitted with -O2 and above") | 
 |     def test_contextmanager_doc_attrib(self): | 
 |         baz = self._create_contextmanager_attribs() | 
 |         self.assertEqual(baz.__doc__, "Whee!") | 
 |  | 
 | class ClosingTestCase(unittest.TestCase): | 
 |  | 
 |     # XXX This needs more work | 
 |  | 
 |     def test_closing(self): | 
 |         state = [] | 
 |         class C: | 
 |             def close(self): | 
 |                 state.append(1) | 
 |         x = C() | 
 |         self.assertEqual(state, []) | 
 |         with closing(x) as y: | 
 |             self.assertEqual(x, y) | 
 |         self.assertEqual(state, [1]) | 
 |  | 
 |     def test_closing_error(self): | 
 |         state = [] | 
 |         class C: | 
 |             def close(self): | 
 |                 state.append(1) | 
 |         x = C() | 
 |         self.assertEqual(state, []) | 
 |         with self.assertRaises(ZeroDivisionError): | 
 |             with closing(x) as y: | 
 |                 self.assertEqual(x, y) | 
 |                 1 / 0 | 
 |         self.assertEqual(state, [1]) | 
 |  | 
 | class FileContextTestCase(unittest.TestCase): | 
 |  | 
 |     def testWithOpen(self): | 
 |         tfn = tempfile.mktemp() | 
 |         try: | 
 |             f = None | 
 |             with open(tfn, "w") as f: | 
 |                 self.assertFalse(f.closed) | 
 |                 f.write("Booh\n") | 
 |             self.assertTrue(f.closed) | 
 |             f = None | 
 |             with self.assertRaises(ZeroDivisionError): | 
 |                 with open(tfn, "r") as f: | 
 |                     self.assertFalse(f.closed) | 
 |                     self.assertEqual(f.read(), "Booh\n") | 
 |                     1 / 0 | 
 |             self.assertTrue(f.closed) | 
 |         finally: | 
 |             support.unlink(tfn) | 
 |  | 
 | @unittest.skipUnless(threading, 'Threading required for this test.') | 
 | class LockContextTestCase(unittest.TestCase): | 
 |  | 
 |     def boilerPlate(self, lock, locked): | 
 |         self.assertFalse(locked()) | 
 |         with lock: | 
 |             self.assertTrue(locked()) | 
 |         self.assertFalse(locked()) | 
 |         with self.assertRaises(ZeroDivisionError): | 
 |             with lock: | 
 |                 self.assertTrue(locked()) | 
 |                 1 / 0 | 
 |         self.assertFalse(locked()) | 
 |  | 
 |     def testWithLock(self): | 
 |         lock = threading.Lock() | 
 |         self.boilerPlate(lock, lock.locked) | 
 |  | 
 |     def testWithRLock(self): | 
 |         lock = threading.RLock() | 
 |         self.boilerPlate(lock, lock._is_owned) | 
 |  | 
 |     def testWithCondition(self): | 
 |         lock = threading.Condition() | 
 |         def locked(): | 
 |             return lock._is_owned() | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 |     def testWithSemaphore(self): | 
 |         lock = threading.Semaphore() | 
 |         def locked(): | 
 |             if lock.acquire(False): | 
 |                 lock.release() | 
 |                 return False | 
 |             else: | 
 |                 return True | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 |     def testWithBoundedSemaphore(self): | 
 |         lock = threading.BoundedSemaphore() | 
 |         def locked(): | 
 |             if lock.acquire(False): | 
 |                 lock.release() | 
 |                 return False | 
 |             else: | 
 |                 return True | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 |  | 
 | class mycontext(ContextDecorator): | 
 |     started = False | 
 |     exc = None | 
 |     catch = False | 
 |  | 
 |     def __enter__(self): | 
 |         self.started = True | 
 |         return self | 
 |  | 
 |     def __exit__(self, *exc): | 
 |         self.exc = exc | 
 |         return self.catch | 
 |  | 
 |  | 
 | class TestContextDecorator(unittest.TestCase): | 
 |  | 
 |     def test_contextdecorator(self): | 
 |         context = mycontext() | 
 |         with context as result: | 
 |             self.assertIs(result, context) | 
 |             self.assertTrue(context.started) | 
 |  | 
 |         self.assertEqual(context.exc, (None, None, None)) | 
 |  | 
 |  | 
 |     def test_contextdecorator_with_exception(self): | 
 |         context = mycontext() | 
 |  | 
 |         with self.assertRaisesRegex(NameError, 'foo'): | 
 |             with context: | 
 |                 raise NameError('foo') | 
 |         self.assertIsNotNone(context.exc) | 
 |         self.assertIs(context.exc[0], NameError) | 
 |  | 
 |         context = mycontext() | 
 |         context.catch = True | 
 |         with context: | 
 |             raise NameError('foo') | 
 |         self.assertIsNotNone(context.exc) | 
 |         self.assertIs(context.exc[0], NameError) | 
 |  | 
 |  | 
 |     def test_decorator(self): | 
 |         context = mycontext() | 
 |  | 
 |         @context | 
 |         def test(): | 
 |             self.assertIsNone(context.exc) | 
 |             self.assertTrue(context.started) | 
 |         test() | 
 |         self.assertEqual(context.exc, (None, None, None)) | 
 |  | 
 |  | 
 |     def test_decorator_with_exception(self): | 
 |         context = mycontext() | 
 |  | 
 |         @context | 
 |         def test(): | 
 |             self.assertIsNone(context.exc) | 
 |             self.assertTrue(context.started) | 
 |             raise NameError('foo') | 
 |  | 
 |         with self.assertRaisesRegex(NameError, 'foo'): | 
 |             test() | 
 |         self.assertIsNotNone(context.exc) | 
 |         self.assertIs(context.exc[0], NameError) | 
 |  | 
 |  | 
 |     def test_decorating_method(self): | 
 |         context = mycontext() | 
 |  | 
 |         class Test(object): | 
 |  | 
 |             @context | 
 |             def method(self, a, b, c=None): | 
 |                 self.a = a | 
 |                 self.b = b | 
 |                 self.c = c | 
 |  | 
 |         # these tests are for argument passing when used as a decorator | 
 |         test = Test() | 
 |         test.method(1, 2) | 
 |         self.assertEqual(test.a, 1) | 
 |         self.assertEqual(test.b, 2) | 
 |         self.assertEqual(test.c, None) | 
 |  | 
 |         test = Test() | 
 |         test.method('a', 'b', 'c') | 
 |         self.assertEqual(test.a, 'a') | 
 |         self.assertEqual(test.b, 'b') | 
 |         self.assertEqual(test.c, 'c') | 
 |  | 
 |         test = Test() | 
 |         test.method(a=1, b=2) | 
 |         self.assertEqual(test.a, 1) | 
 |         self.assertEqual(test.b, 2) | 
 |  | 
 |  | 
 |     def test_typo_enter(self): | 
 |         class mycontext(ContextDecorator): | 
 |             def __unter__(self): | 
 |                 pass | 
 |             def __exit__(self, *exc): | 
 |                 pass | 
 |  | 
 |         with self.assertRaises(AttributeError): | 
 |             with mycontext(): | 
 |                 pass | 
 |  | 
 |  | 
 |     def test_typo_exit(self): | 
 |         class mycontext(ContextDecorator): | 
 |             def __enter__(self): | 
 |                 pass | 
 |             def __uxit__(self, *exc): | 
 |                 pass | 
 |  | 
 |         with self.assertRaises(AttributeError): | 
 |             with mycontext(): | 
 |                 pass | 
 |  | 
 |  | 
 |     def test_contextdecorator_as_mixin(self): | 
 |         class somecontext(object): | 
 |             started = False | 
 |             exc = None | 
 |  | 
 |             def __enter__(self): | 
 |                 self.started = True | 
 |                 return self | 
 |  | 
 |             def __exit__(self, *exc): | 
 |                 self.exc = exc | 
 |  | 
 |         class mycontext(somecontext, ContextDecorator): | 
 |             pass | 
 |  | 
 |         context = mycontext() | 
 |         @context | 
 |         def test(): | 
 |             self.assertIsNone(context.exc) | 
 |             self.assertTrue(context.started) | 
 |         test() | 
 |         self.assertEqual(context.exc, (None, None, None)) | 
 |  | 
 |  | 
 |     def test_contextmanager_as_decorator(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(y): | 
 |             state.append(y) | 
 |             yield | 
 |             state.append(999) | 
 |  | 
 |         @woohoo(1) | 
 |         def test(x): | 
 |             self.assertEqual(state, [1]) | 
 |             state.append(x) | 
 |         test('something') | 
 |         self.assertEqual(state, [1, 'something', 999]) | 
 |  | 
 |  | 
 | # This is needed to make the test actually run under regrtest.py! | 
 | def test_main(): | 
 |     support.run_unittest(__name__) | 
 |  | 
 | if __name__ == "__main__": | 
 |     test_main() |