| import asyncio |
| from contextlib import asynccontextmanager |
| import functools |
| from test import support |
| import unittest |
| |
| |
| def _async_test(func): |
| """Decorator to turn an async function into a test case.""" |
| @functools.wraps(func) |
| def wrapper(*args, **kwargs): |
| coro = func(*args, **kwargs) |
| loop = asyncio.new_event_loop() |
| asyncio.set_event_loop(loop) |
| try: |
| return loop.run_until_complete(coro) |
| finally: |
| loop.close() |
| asyncio.set_event_loop(None) |
| return wrapper |
| |
| |
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for ``contextlib.asynccontextmanager``.

    Coroutine test methods are wrapped with ``_async_test`` so that
    ``unittest`` can run them synchronously.
    """

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound
        # by ``as``, and code after the yield runs on a clean exit.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # A ``finally`` around the yield runs even when the body raises,
        # and the body's exception still propagates.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield

        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception; it returns
        # a false value instead of re-raising the passed-in exception.
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except BaseException:
                # Deliberately swallow whatever __aexit__ throws in and
                # yield again; the context manager must reject this.
                yield

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        @asynccontextmanager
        async def whoo():
            # The unreachable yield makes this an async generator that
            # finishes without ever yielding.
            if False:
                yield

        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        @asynccontextmanager
        async def whoo():
            yield
            yield

        ctx = whoo()
        await ctx.__aenter__()
        # Yielding again on a clean exit must be reported as an error.
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        # Only an exception type is passed (value is None); the generator
        # must still receive a RuntimeError, and the exception it raises in
        # response must propagate.
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # The generator can catch an exception raised in the body, which
        # then does not propagate out of ``async with``.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration raised in the body must
        # propagate as the very same object, not be suppressed.
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        # A RuntimeError that the generator raises in place of an ordinary
        # exception propagates, with the original exception as its cause.
        with self.assertRaises(RuntimeError) as cm:
            async with woohoo():
                1 / 0
        self.assertIsInstance(cm.exception.__cause__, ZeroDivisionError)

        # If the generator wraps StopAsyncIteration in a RuntimeError, the
        # context manager unwraps it and the original exception propagates.
        # It does this because it can't tell whether the wrapping was done
        # by the generator machinery or by the generator itself.
        stop_exc = StopAsyncIteration()
        with self.assertRaises(StopAsyncIteration) as cm:
            async with woohoo():
                raise stop_exc
        self.assertIs(cm.exception, stop_exc)

    def _create_contextmanager_attribs(self):
        def attribs(**kw):
            # Decorator factory: sets each keyword as a function attribute.
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield
        return baz

    def test_contextmanager_attribs(self):
        # The decorated factory keeps the wrapped function's name and its
        # custom attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
| |
| |
if __name__ == '__main__':
    # Executed as a script: hand control to the unittest command-line runner.
    unittest.main()