| """Unit tests for contextlib.py, and other context managers.""" | 
 |  | 
 | from __future__ import with_statement | 
 |  | 
 | import os | 
 | import decimal | 
 | import tempfile | 
 | import unittest | 
 | import threading | 
 | from contextlib import *  # Tests __all__ | 
 |  | 
 | class ContextManagerTestCase(unittest.TestCase): | 
 |  | 
 |     def test_contextmanager_plain(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             yield 42 | 
 |             state.append(999) | 
 |         with woohoo() as x: | 
 |             self.assertEqual(state, [1]) | 
 |             self.assertEqual(x, 42) | 
 |             state.append(x) | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 |     def test_contextmanager_finally(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             try: | 
 |                 yield 42 | 
 |             finally: | 
 |                 state.append(999) | 
 |         try: | 
 |             with woohoo() as x: | 
 |                 self.assertEqual(state, [1]) | 
 |                 self.assertEqual(x, 42) | 
 |                 state.append(x) | 
 |                 raise ZeroDivisionError() | 
 |         except ZeroDivisionError: | 
 |             pass | 
 |         else: | 
 |             self.fail("Expected ZeroDivisionError") | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 |     def test_contextmanager_except(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def woohoo(): | 
 |             state.append(1) | 
 |             try: | 
 |                 yield 42 | 
 |             except ZeroDivisionError, e: | 
 |                 state.append(e.args[0]) | 
 |                 self.assertEqual(state, [1, 42, 999]) | 
 |         with woohoo() as x: | 
 |             self.assertEqual(state, [1]) | 
 |             self.assertEqual(x, 42) | 
 |             state.append(x) | 
 |             raise ZeroDivisionError(999) | 
 |         self.assertEqual(state, [1, 42, 999]) | 
 |  | 
 | class NestedTestCase(unittest.TestCase): | 
 |  | 
 |     # XXX This needs more work | 
 |  | 
 |     def test_nested(self): | 
 |         @contextmanager | 
 |         def a(): | 
 |             yield 1 | 
 |         @contextmanager | 
 |         def b(): | 
 |             yield 2 | 
 |         @contextmanager | 
 |         def c(): | 
 |             yield 3 | 
 |         with nested(a(), b(), c()) as (x, y, z): | 
 |             self.assertEqual(x, 1) | 
 |             self.assertEqual(y, 2) | 
 |             self.assertEqual(z, 3) | 
 |  | 
 |     def test_nested_cleanup(self): | 
 |         state = [] | 
 |         @contextmanager | 
 |         def a(): | 
 |             state.append(1) | 
 |             try: | 
 |                 yield 2 | 
 |             finally: | 
 |                 state.append(3) | 
 |         @contextmanager | 
 |         def b(): | 
 |             state.append(4) | 
 |             try: | 
 |                 yield 5 | 
 |             finally: | 
 |                 state.append(6) | 
 |         try: | 
 |             with nested(a(), b()) as (x, y): | 
 |                 state.append(x) | 
 |                 state.append(y) | 
 |                 1/0 | 
 |         except ZeroDivisionError: | 
 |             self.assertEqual(state, [1, 4, 2, 5, 6, 3]) | 
 |         else: | 
 |             self.fail("Didn't raise ZeroDivisionError") | 
 |  | 
 |     def test_nested_b_swallows(self): | 
 |         @contextmanager | 
 |         def a(): | 
 |             yield | 
 |         @contextmanager | 
 |         def b(): | 
 |             try: | 
 |                 yield | 
 |             except: | 
 |                 # Swallow the exception | 
 |                 pass | 
 |         try: | 
 |             with nested(a(), b()): | 
 |                 1/0 | 
 |         except ZeroDivisionError: | 
 |             self.fail("Didn't swallow ZeroDivisionError") | 
 |  | 
 |     def test_nested_break(self): | 
 |         @contextmanager | 
 |         def a(): | 
 |             yield | 
 |         state = 0 | 
 |         while True: | 
 |             state += 1 | 
 |             with nested(a(), a()): | 
 |                 break | 
 |             state += 10 | 
 |         self.assertEqual(state, 1) | 
 |  | 
 |     def test_nested_continue(self): | 
 |         @contextmanager | 
 |         def a(): | 
 |             yield | 
 |         state = 0 | 
 |         while state < 3: | 
 |             state += 1 | 
 |             with nested(a(), a()): | 
 |                 continue | 
 |             state += 10 | 
 |         self.assertEqual(state, 3) | 
 |  | 
 |     def test_nested_return(self): | 
 |         @contextmanager | 
 |         def a(): | 
 |             try: | 
 |                 yield | 
 |             except: | 
 |                 pass | 
 |         def foo(): | 
 |             with nested(a(), a()): | 
 |                 return 1 | 
 |             return 10 | 
 |         self.assertEqual(foo(), 1) | 
 |  | 
 | class ClosingTestCase(unittest.TestCase): | 
 |  | 
 |     # XXX This needs more work | 
 |  | 
 |     def test_closing(self): | 
 |         state = [] | 
 |         class C: | 
 |             def close(self): | 
 |                 state.append(1) | 
 |         x = C() | 
 |         self.assertEqual(state, []) | 
 |         with closing(x) as y: | 
 |             self.assertEqual(x, y) | 
 |         self.assertEqual(state, [1]) | 
 |  | 
 |     def test_closing_error(self): | 
 |         state = [] | 
 |         class C: | 
 |             def close(self): | 
 |                 state.append(1) | 
 |         x = C() | 
 |         self.assertEqual(state, []) | 
 |         try: | 
 |             with closing(x) as y: | 
 |                 self.assertEqual(x, y) | 
 |                 1/0 | 
 |         except ZeroDivisionError: | 
 |             self.assertEqual(state, [1]) | 
 |         else: | 
 |             self.fail("Didn't raise ZeroDivisionError") | 
 |  | 
 | class FileContextTestCase(unittest.TestCase): | 
 |  | 
 |     def testWithOpen(self): | 
 |         tfn = tempfile.mktemp() | 
 |         try: | 
 |             f = None | 
 |             with open(tfn, "w") as f: | 
 |                 self.failIf(f.closed) | 
 |                 f.write("Booh\n") | 
 |             self.failUnless(f.closed) | 
 |             f = None | 
 |             try: | 
 |                 with open(tfn, "r") as f: | 
 |                     self.failIf(f.closed) | 
 |                     self.assertEqual(f.read(), "Booh\n") | 
 |                     1/0 | 
 |             except ZeroDivisionError: | 
 |                 self.failUnless(f.closed) | 
 |             else: | 
 |                 self.fail("Didn't raise ZeroDivisionError") | 
 |         finally: | 
 |             try: | 
 |                 os.remove(tfn) | 
 |             except os.error: | 
 |                 pass | 
 |  | 
 | class LockContextTestCase(unittest.TestCase): | 
 |  | 
 |     def boilerPlate(self, lock, locked): | 
 |         self.failIf(locked()) | 
 |         with lock: | 
 |             self.failUnless(locked()) | 
 |         self.failIf(locked()) | 
 |         try: | 
 |             with lock: | 
 |                 self.failUnless(locked()) | 
 |                 1/0 | 
 |         except ZeroDivisionError: | 
 |             self.failIf(locked()) | 
 |         else: | 
 |             self.fail("Didn't raise ZeroDivisionError") | 
 |  | 
 |     def testWithLock(self): | 
 |         lock = threading.Lock() | 
 |         self.boilerPlate(lock, lock.locked) | 
 |  | 
 |     def testWithRLock(self): | 
 |         lock = threading.RLock() | 
 |         self.boilerPlate(lock, lock._is_owned) | 
 |  | 
 |     def testWithCondition(self): | 
 |         lock = threading.Condition() | 
 |         def locked(): | 
 |             return lock._is_owned() | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 |     def testWithSemaphore(self): | 
 |         lock = threading.Semaphore() | 
 |         def locked(): | 
 |             if lock.acquire(False): | 
 |                 lock.release() | 
 |                 return False | 
 |             else: | 
 |                 return True | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 |     def testWithBoundedSemaphore(self): | 
 |         lock = threading.BoundedSemaphore() | 
 |         def locked(): | 
 |             if lock.acquire(False): | 
 |                 lock.release() | 
 |                 return False | 
 |             else: | 
 |                 return True | 
 |         self.boilerPlate(lock, locked) | 
 |  | 
 | class DecimalContextTestCase(unittest.TestCase): | 
 |  | 
 |     # XXX Somebody should write more thorough tests for this | 
 |  | 
 |     def testBasic(self): | 
 |         ctx = decimal.getcontext() | 
 |         ctx.prec = save_prec = decimal.ExtendedContext.prec + 5 | 
 |         with decimal.ExtendedContext: | 
 |             self.assertEqual(decimal.getcontext().prec, | 
 |                              decimal.ExtendedContext.prec) | 
 |         self.assertEqual(decimal.getcontext().prec, save_prec) | 
 |         try: | 
 |             with decimal.ExtendedContext: | 
 |                 self.assertEqual(decimal.getcontext().prec, | 
 |                                  decimal.ExtendedContext.prec) | 
 |                 1/0 | 
 |         except ZeroDivisionError: | 
 |             self.assertEqual(decimal.getcontext().prec, save_prec) | 
 |         else: | 
 |             self.fail("Didn't raise ZeroDivisionError") | 
 |  | 
 |  | 
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()