| import contextlib |
| import inspect |
| import sys |
| import types |
| import unittest |
| import warnings |
| from test import support |
| |
| |
class AsyncYieldFrom:
    """Awaitable that passes every item of a wrapped iterable up to the
    code driving the coroutine.

    ``await AsyncYieldFrom(obj)`` behaves like ``yield from obj``; the
    await expression itself evaluates to ``None``.
    """

    def __init__(self, obj):
        # The iterable (or generator) to delegate to.
        self.obj = obj

    def __await__(self):
        # Plain delegation with no ``return``: the wrapped iterable's own
        # return value is intentionally discarded.
        yield from self.obj
| |
| |
class AsyncYield:
    """Awaitable that suspends the awaiting coroutine exactly once.

    While suspended it hands ``value`` to whoever drives the coroutine.
    The await expression evaluates to whatever is sent back in.
    """

    def __init__(self, value):
        # The single value surfaced to the driver on suspension.
        self.value = value

    def __await__(self):
        yield self.value
| |
| |
def run_async(coro):
    """Drive *coro* to completion by repeatedly sending ``None`` into it.

    Returns a ``(yielded, result)`` tuple.  *yielded* is the list of
    values the coroutine yielded along the way.  *result* is its return
    value, or ``None`` when it returned nothing.  Any other exception
    raised by the coroutine propagates unchanged.
    """
    assert coro.__class__ is types.GeneratorType

    yielded = []
    while True:
        try:
            value = coro.send(None)
        except StopIteration as stop:
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
| |
| |
@contextlib.contextmanager
def silence_coro_gc():
    """Suppress every warning while the body runs, then force a GC pass.

    The garbage collection happens while the "ignore" filter is still
    active.  Warnings emitted when leftover coroutine objects are finalized
    (such as "coroutine ... was never awaited") are therefore swallowed as
    well.  The collection runs in a ``finally`` clause, so it also happens
    when the body raises.  Otherwise those finalizer warnings would escape
    later, outside the suppressed region.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            yield
        finally:
            support.gc_collect()
| |
| |
| class AsyncBadSyntaxTest(unittest.TestCase): |
| |
| def test_badsyntax_1(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async1 |
| |
| def test_badsyntax_2(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async2 |
| |
| def test_badsyntax_3(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async3 |
| |
| def test_badsyntax_4(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async4 |
| |
| def test_badsyntax_5(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async5 |
| |
| def test_badsyntax_6(self): |
| with self.assertRaisesRegex( |
| SyntaxError, "'yield' inside async function"): |
| |
| import test.badsyntax_async6 |
| |
| def test_badsyntax_7(self): |
| with self.assertRaisesRegex( |
| SyntaxError, "'yield from' inside async function"): |
| |
| import test.badsyntax_async7 |
| |
| def test_badsyntax_8(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async8 |
| |
| def test_badsyntax_9(self): |
| with self.assertRaisesRegex(SyntaxError, 'invalid syntax'): |
| import test.badsyntax_async9 |
| |
| |
class TokenizerRegrTest(unittest.TestCase):
    """Tokenizer regression tests around ``async def``."""

    def test_oneline_defs(self):
        source = '\n'.join('def i{i}(): return {i}'.format(i=i)
                           for i in range(500))

        # 500 consecutive one-line defs must compile and run.
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # The same 500 defs followed by a single 'async def' must work too.
        source += '\nasync def foo():\n return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
| |
| |
| class CoroutineTest(unittest.TestCase): |
| |
| def test_gen_1(self): |
| def gen(): yield |
| self.assertFalse(hasattr(gen, '__await__')) |
| |
| def test_func_1(self): |
| async def foo(): |
| return 10 |
| |
| f = foo() |
| self.assertIsInstance(f, types.GeneratorType) |
| self.assertTrue(bool(foo.__code__.co_flags & 0x80)) |
| self.assertTrue(bool(foo.__code__.co_flags & 0x20)) |
| self.assertTrue(bool(f.gi_code.co_flags & 0x80)) |
| self.assertTrue(bool(f.gi_code.co_flags & 0x20)) |
| self.assertEqual(run_async(f), ([], 10)) |
| |
| def bar(): pass |
| self.assertFalse(bool(bar.__code__.co_flags & 0x80)) |
| |
| def test_func_2(self): |
| async def foo(): |
| raise StopIteration |
| |
| with self.assertRaisesRegex( |
| RuntimeError, "generator raised StopIteration"): |
| |
| run_async(foo()) |
| |
| def test_func_3(self): |
| async def foo(): |
| raise StopIteration |
| |
| with silence_coro_gc(): |
| self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$') |
| |
| def test_func_4(self): |
| async def foo(): |
| raise StopIteration |
| |
| check = lambda: self.assertRaisesRegex( |
| TypeError, "coroutine-objects do not support iteration") |
| |
| with check(): |
| list(foo()) |
| |
| with check(): |
| tuple(foo()) |
| |
| with check(): |
| sum(foo()) |
| |
| with check(): |
| iter(foo()) |
| |
| with check(): |
| next(foo()) |
| |
| with silence_coro_gc(), check(): |
| for i in foo(): |
| pass |
| |
| with silence_coro_gc(), check(): |
| [i for i in foo()] |
| |
| def test_func_5(self): |
| @types.coroutine |
| def bar(): |
| yield 1 |
| |
| async def foo(): |
| await bar() |
| |
| check = lambda: self.assertRaisesRegex( |
| TypeError, "coroutine-objects do not support iteration") |
| |
| with check(): |
| for el in foo(): pass |
| |
| # the following should pass without an error |
| for el in bar(): |
| self.assertEqual(el, 1) |
| self.assertEqual([el for el in bar()], [1]) |
| self.assertEqual(tuple(bar()), (1,)) |
| self.assertEqual(next(iter(bar())), 1) |
| |
| def test_func_6(self): |
| @types.coroutine |
| def bar(): |
| yield 1 |
| yield 2 |
| |
| async def foo(): |
| await bar() |
| |
| f = foo() |
| self.assertEqual(f.send(None), 1) |
| self.assertEqual(f.send(None), 2) |
| with self.assertRaises(StopIteration): |
| f.send(None) |
| |
| def test_func_7(self): |
| async def bar(): |
| return 10 |
| |
| def foo(): |
| yield from bar() |
| |
| with silence_coro_gc(), self.assertRaisesRegex( |
| TypeError, |
| "cannot 'yield from' a coroutine object from a generator"): |
| |
| list(foo()) |
| |
| def test_func_8(self): |
| @types.coroutine |
| def bar(): |
| return (yield from foo()) |
| |
| async def foo(): |
| return 'spam' |
| |
| self.assertEqual(run_async(bar()), ([], 'spam') ) |
| |
| def test_func_9(self): |
| async def foo(): pass |
| |
| with self.assertWarnsRegex( |
| RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"): |
| |
| foo() |
| support.gc_collect() |
| |
| def test_await_1(self): |
| |
| async def foo(): |
| await 1 |
| with self.assertRaisesRegex(TypeError, "object int can.t.*await"): |
| run_async(foo()) |
| |
| def test_await_2(self): |
| async def foo(): |
| await [] |
| with self.assertRaisesRegex(TypeError, "object list can.t.*await"): |
| run_async(foo()) |
| |
| def test_await_3(self): |
| async def foo(): |
| await AsyncYieldFrom([1, 2, 3]) |
| |
| self.assertEqual(run_async(foo()), ([1, 2, 3], None)) |
| |
| def test_await_4(self): |
| async def bar(): |
| return 42 |
| |
| async def foo(): |
| return await bar() |
| |
| self.assertEqual(run_async(foo()), ([], 42)) |
| |
| def test_await_5(self): |
| class Awaitable: |
| def __await__(self): |
| return |
| |
| async def foo(): |
| return (await Awaitable()) |
| |
| with self.assertRaisesRegex( |
| TypeError, "__await__.*returned non-iterator of type"): |
| |
| run_async(foo()) |
| |
| def test_await_6(self): |
| class Awaitable: |
| def __await__(self): |
| return iter([52]) |
| |
| async def foo(): |
| return (await Awaitable()) |
| |
| self.assertEqual(run_async(foo()), ([52], None)) |
| |
| def test_await_7(self): |
| class Awaitable: |
| def __await__(self): |
| yield 42 |
| return 100 |
| |
| async def foo(): |
| return (await Awaitable()) |
| |
| self.assertEqual(run_async(foo()), ([42], 100)) |
| |
| def test_await_8(self): |
| class Awaitable: |
| pass |
| |
| async def foo(): |
| return (await Awaitable()) |
| |
| with self.assertRaisesRegex( |
| TypeError, "object Awaitable can't be used in 'await' expression"): |
| |
| run_async(foo()) |
| |
| def test_await_9(self): |
| def wrap(): |
| return bar |
| |
| async def bar(): |
| return 42 |
| |
| async def foo(): |
| b = bar() |
| |
| db = {'b': lambda: wrap} |
| |
| class DB: |
| b = wrap |
| |
| return (await bar() + await wrap()() + await db['b']()()() + |
| await bar() * 1000 + await DB.b()()) |
| |
| async def foo2(): |
| return -await bar() |
| |
| self.assertEqual(run_async(foo()), ([], 42168)) |
| self.assertEqual(run_async(foo2()), ([], -42)) |
| |
| def test_await_10(self): |
| async def baz(): |
| return 42 |
| |
| async def bar(): |
| return baz() |
| |
| async def foo(): |
| return await (await bar()) |
| |
| self.assertEqual(run_async(foo()), ([], 42)) |
| |
| def test_await_11(self): |
| def ident(val): |
| return val |
| |
| async def bar(): |
| return 'spam' |
| |
| async def foo(): |
| return ident(val=await bar()) |
| |
| async def foo2(): |
| return await bar(), 'ham' |
| |
| self.assertEqual(run_async(foo2()), ([], ('spam', 'ham'))) |
| |
| def test_await_12(self): |
| async def coro(): |
| return 'spam' |
| |
| class Awaitable: |
| def __await__(self): |
| return coro() |
| |
| async def foo(): |
| return await Awaitable() |
| |
| with self.assertRaisesRegex( |
| TypeError, "__await__\(\) returned a coroutine"): |
| |
| run_async(foo()) |
| |
| def test_await_13(self): |
| class Awaitable: |
| def __await__(self): |
| return self |
| |
| async def foo(): |
| return await Awaitable() |
| |
| with self.assertRaisesRegex( |
| TypeError, "__await__.*returned non-iterator of type"): |
| |
| run_async(foo()) |
| |
| def test_with_1(self): |
| class Manager: |
| def __init__(self, name): |
| self.name = name |
| |
| async def __aenter__(self): |
| await AsyncYieldFrom(['enter-1-' + self.name, |
| 'enter-2-' + self.name]) |
| return self |
| |
| async def __aexit__(self, *args): |
| await AsyncYieldFrom(['exit-1-' + self.name, |
| 'exit-2-' + self.name]) |
| |
| if self.name == 'B': |
| return True |
| |
| |
| async def foo(): |
| async with Manager("A") as a, Manager("B") as b: |
| await AsyncYieldFrom([('managers', a.name, b.name)]) |
| 1/0 |
| |
| f = foo() |
| result, _ = run_async(f) |
| |
| self.assertEqual( |
| result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B', |
| ('managers', 'A', 'B'), |
| 'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A'] |
| ) |
| |
| async def foo(): |
| async with Manager("A") as a, Manager("C") as c: |
| await AsyncYieldFrom([('managers', a.name, c.name)]) |
| 1/0 |
| |
| with self.assertRaises(ZeroDivisionError): |
| run_async(foo()) |
| |
| def test_with_2(self): |
| class CM: |
| def __aenter__(self): |
| pass |
| |
| async def foo(): |
| async with CM(): |
| pass |
| |
| with self.assertRaisesRegex(AttributeError, '__aexit__'): |
| run_async(foo()) |
| |
| def test_with_3(self): |
| class CM: |
| def __aexit__(self): |
| pass |
| |
| async def foo(): |
| async with CM(): |
| pass |
| |
| with self.assertRaisesRegex(AttributeError, '__aenter__'): |
| run_async(foo()) |
| |
| def test_with_4(self): |
| class CM: |
| def __enter__(self): |
| pass |
| |
| def __exit__(self): |
| pass |
| |
| async def foo(): |
| async with CM(): |
| pass |
| |
| with self.assertRaisesRegex(AttributeError, '__aexit__'): |
| run_async(foo()) |
| |
| def test_with_5(self): |
| # While this test doesn't make a lot of sense, |
| # it's a regression test for an early bug with opcodes |
| # generation |
| |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| async def __aexit__(self, *exc): |
| pass |
| |
| async def func(): |
| async with CM(): |
| assert (1, ) == 1 |
| |
| with self.assertRaises(AssertionError): |
| run_async(func()) |
| |
| def test_with_6(self): |
| class CM: |
| def __aenter__(self): |
| return 123 |
| |
| def __aexit__(self, *e): |
| return 456 |
| |
| async def foo(): |
| async with CM(): |
| pass |
| |
| with self.assertRaisesRegex( |
| TypeError, "object int can't be used in 'await' expression"): |
| # it's important that __aexit__ wasn't called |
| run_async(foo()) |
| |
| def test_with_7(self): |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| def __aexit__(self, *e): |
| return 444 |
| |
| async def foo(): |
| async with CM(): |
| 1/0 |
| |
| try: |
| run_async(foo()) |
| except TypeError as exc: |
| self.assertRegex( |
| exc.args[0], "object int can't be used in 'await' expression") |
| self.assertTrue(exc.__context__ is not None) |
| self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) |
| else: |
| self.fail('invalid asynchronous context manager did not fail') |
| |
| |
| def test_with_8(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| def __aexit__(self, *e): |
| return 456 |
| |
| async def foo(): |
| nonlocal CNT |
| async with CM(): |
| CNT += 1 |
| |
| |
| with self.assertRaisesRegex( |
| TypeError, "object int can't be used in 'await' expression"): |
| |
| run_async(foo()) |
| |
| self.assertEqual(CNT, 1) |
| |
| |
| def test_with_9(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| async def __aexit__(self, *e): |
| 1/0 |
| |
| async def foo(): |
| nonlocal CNT |
| async with CM(): |
| CNT += 1 |
| |
| with self.assertRaises(ZeroDivisionError): |
| run_async(foo()) |
| |
| self.assertEqual(CNT, 1) |
| |
| def test_with_10(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| async def __aexit__(self, *e): |
| 1/0 |
| |
| async def foo(): |
| nonlocal CNT |
| async with CM(): |
| async with CM(): |
| raise RuntimeError |
| |
| try: |
| run_async(foo()) |
| except ZeroDivisionError as exc: |
| self.assertTrue(exc.__context__ is not None) |
| self.assertTrue(isinstance(exc.__context__, ZeroDivisionError)) |
| self.assertTrue(isinstance(exc.__context__.__context__, |
| RuntimeError)) |
| else: |
| self.fail('exception from __aexit__ did not propagate') |
| |
| def test_with_11(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| raise NotImplementedError |
| |
| async def __aexit__(self, *e): |
| 1/0 |
| |
| async def foo(): |
| nonlocal CNT |
| async with CM(): |
| raise RuntimeError |
| |
| try: |
| run_async(foo()) |
| except NotImplementedError as exc: |
| self.assertTrue(exc.__context__ is None) |
| else: |
| self.fail('exception from __aenter__ did not propagate') |
| |
| def test_with_12(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| return self |
| |
| async def __aexit__(self, *e): |
| return True |
| |
| async def foo(): |
| nonlocal CNT |
| async with CM() as cm: |
| self.assertIs(cm.__class__, CM) |
| raise RuntimeError |
| |
| run_async(foo()) |
| |
| def test_with_13(self): |
| CNT = 0 |
| |
| class CM: |
| async def __aenter__(self): |
| 1/0 |
| |
| async def __aexit__(self, *e): |
| return True |
| |
| async def foo(): |
| nonlocal CNT |
| CNT += 1 |
| async with CM(): |
| CNT += 1000 |
| CNT += 10000 |
| |
| with self.assertRaises(ZeroDivisionError): |
| run_async(foo()) |
| self.assertEqual(CNT, 1) |
| |
| def test_for_1(self): |
| aiter_calls = 0 |
| |
| class AsyncIter: |
| def __init__(self): |
| self.i = 0 |
| |
| async def __aiter__(self): |
| nonlocal aiter_calls |
| aiter_calls += 1 |
| return self |
| |
| async def __anext__(self): |
| self.i += 1 |
| |
| if not (self.i % 10): |
| await AsyncYield(self.i * 10) |
| |
| if self.i > 100: |
| raise StopAsyncIteration |
| |
| return self.i, self.i |
| |
| |
| buffer = [] |
| async def test1(): |
| async for i1, i2 in AsyncIter(): |
| buffer.append(i1 + i2) |
| |
| yielded, _ = run_async(test1()) |
| # Make sure that __aiter__ was called only once |
| self.assertEqual(aiter_calls, 1) |
| self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) |
| self.assertEqual(buffer, [i*2 for i in range(1, 101)]) |
| |
| |
| buffer = [] |
| async def test2(): |
| nonlocal buffer |
| async for i in AsyncIter(): |
| buffer.append(i[0]) |
| if i[0] == 20: |
| break |
| else: |
| buffer.append('what?') |
| buffer.append('end') |
| |
| yielded, _ = run_async(test2()) |
| # Make sure that __aiter__ was called only once |
| self.assertEqual(aiter_calls, 2) |
| self.assertEqual(yielded, [100, 200]) |
| self.assertEqual(buffer, [i for i in range(1, 21)] + ['end']) |
| |
| |
| buffer = [] |
| async def test3(): |
| nonlocal buffer |
| async for i in AsyncIter(): |
| if i[0] > 20: |
| continue |
| buffer.append(i[0]) |
| else: |
| buffer.append('what?') |
| buffer.append('end') |
| |
| yielded, _ = run_async(test3()) |
| # Make sure that __aiter__ was called only once |
| self.assertEqual(aiter_calls, 3) |
| self.assertEqual(yielded, [i * 100 for i in range(1, 11)]) |
| self.assertEqual(buffer, [i for i in range(1, 21)] + |
| ['what?', 'end']) |
| |
| def test_for_2(self): |
| tup = (1, 2, 3) |
| refs_before = sys.getrefcount(tup) |
| |
| async def foo(): |
| async for i in tup: |
| print('never going to happen') |
| |
| with self.assertRaisesRegex( |
| TypeError, "async for' requires an object.*__aiter__.*tuple"): |
| |
| run_async(foo()) |
| |
| self.assertEqual(sys.getrefcount(tup), refs_before) |
| |
| def test_for_3(self): |
| class I: |
| def __aiter__(self): |
| return self |
| |
| aiter = I() |
| refs_before = sys.getrefcount(aiter) |
| |
| async def foo(): |
| async for i in aiter: |
| print('never going to happen') |
| |
| with self.assertRaisesRegex( |
| TypeError, |
| "async for' received an invalid object.*__aiter.*\: I"): |
| |
| run_async(foo()) |
| |
| self.assertEqual(sys.getrefcount(aiter), refs_before) |
| |
| def test_for_4(self): |
| class I: |
| async def __aiter__(self): |
| return self |
| |
| def __anext__(self): |
| return () |
| |
| aiter = I() |
| refs_before = sys.getrefcount(aiter) |
| |
| async def foo(): |
| async for i in aiter: |
| print('never going to happen') |
| |
| with self.assertRaisesRegex( |
| TypeError, |
| "async for' received an invalid object.*__anext__.*tuple"): |
| |
| run_async(foo()) |
| |
| self.assertEqual(sys.getrefcount(aiter), refs_before) |
| |
| def test_for_5(self): |
| class I: |
| async def __aiter__(self): |
| return self |
| |
| def __anext__(self): |
| return 123 |
| |
| async def foo(): |
| async for i in I(): |
| print('never going to happen') |
| |
| with self.assertRaisesRegex( |
| TypeError, |
| "async for' received an invalid object.*__anext.*int"): |
| |
| run_async(foo()) |
| |
| def test_for_6(self): |
| I = 0 |
| |
| class Manager: |
| async def __aenter__(self): |
| nonlocal I |
| I += 10000 |
| |
| async def __aexit__(self, *args): |
| nonlocal I |
| I += 100000 |
| |
| class Iterable: |
| def __init__(self): |
| self.i = 0 |
| |
| async def __aiter__(self): |
| return self |
| |
| async def __anext__(self): |
| if self.i > 10: |
| raise StopAsyncIteration |
| self.i += 1 |
| return self.i |
| |
| ############## |
| |
| manager = Manager() |
| iterable = Iterable() |
| mrefs_before = sys.getrefcount(manager) |
| irefs_before = sys.getrefcount(iterable) |
| |
| async def main(): |
| nonlocal I |
| |
| async with manager: |
| async for i in iterable: |
| I += 1 |
| I += 1000 |
| |
| run_async(main()) |
| self.assertEqual(I, 111011) |
| |
| self.assertEqual(sys.getrefcount(manager), mrefs_before) |
| self.assertEqual(sys.getrefcount(iterable), irefs_before) |
| |
| ############## |
| |
| async def main(): |
| nonlocal I |
| |
| async with Manager(): |
| async for i in Iterable(): |
| I += 1 |
| I += 1000 |
| |
| async with Manager(): |
| async for i in Iterable(): |
| I += 1 |
| I += 1000 |
| |
| run_async(main()) |
| self.assertEqual(I, 333033) |
| |
| ############## |
| |
| async def main(): |
| nonlocal I |
| |
| async with Manager(): |
| I += 100 |
| async for i in Iterable(): |
| I += 1 |
| else: |
| I += 10000000 |
| I += 1000 |
| |
| async with Manager(): |
| I += 100 |
| async for i in Iterable(): |
| I += 1 |
| else: |
| I += 10000000 |
| I += 1000 |
| |
| run_async(main()) |
| self.assertEqual(I, 20555255) |
| |
| def test_for_7(self): |
| CNT = 0 |
| class AI: |
| async def __aiter__(self): |
| 1/0 |
| async def foo(): |
| nonlocal CNT |
| async for i in AI(): |
| CNT += 1 |
| CNT += 10 |
| with self.assertRaises(ZeroDivisionError): |
| run_async(foo()) |
| self.assertEqual(CNT, 0) |
| |
| |
class CoroAsyncIOCompatTest(unittest.TestCase):
    """Native coroutines and ``async with`` running under asyncio."""

    def test_asyncio_1(self):
        import asyncio

        class MyException(Exception):
            pass

        buffer = []

        class CM:
            async def __aenter__(self):
                buffer.append(1)
                await asyncio.sleep(0.01)
                buffer.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                buffer.append(exc_type.__name__)

        async def f():
            async with CM():
                await asyncio.sleep(0.01)
                raise MyException
            # Sentinel: reached only if __aexit__ wrongly suppressed the
            # exception, in which case the buffer check below fails.
            buffer.append('unreachable')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # __aexit__ returns None, so MyException must propagate out of
            # the event loop instead of being silently tolerated.
            with self.assertRaises(MyException):
                loop.run_until_complete(f())
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(buffer, [1, 2, 'MyException'])
| |
| |
| class SysSetCoroWrapperTest(unittest.TestCase): |
| |
| def test_set_wrapper_1(self): |
| async def foo(): |
| return 'spam' |
| |
| wrapped = None |
| def wrap(gen): |
| nonlocal wrapped |
| wrapped = gen |
| return gen |
| |
| self.assertIsNone(sys.get_coroutine_wrapper()) |
| |
| sys.set_coroutine_wrapper(wrap) |
| self.assertIs(sys.get_coroutine_wrapper(), wrap) |
| try: |
| f = foo() |
| self.assertTrue(wrapped) |
| |
| self.assertEqual(run_async(f), ([], 'spam')) |
| finally: |
| sys.set_coroutine_wrapper(None) |
| |
| self.assertIsNone(sys.get_coroutine_wrapper()) |
| |
| wrapped = None |
| with silence_coro_gc(): |
| foo() |
| self.assertFalse(wrapped) |
| |
| def test_set_wrapper_2(self): |
| self.assertIsNone(sys.get_coroutine_wrapper()) |
| with self.assertRaisesRegex(TypeError, "callable expected, got int"): |
| sys.set_coroutine_wrapper(1) |
| self.assertIsNone(sys.get_coroutine_wrapper()) |
| |
| |
class CAPITest(unittest.TestCase):
    """Tests for the C-level ``tp_await`` slot, via ``_testcapi.awaitType``.

    NOTE(review): judging from these tests, ``awaitType(obj)`` appears to
    return *obj* from its ``tp_await`` slot.  Confirm against _testcapi.
    """

    def test_tp_await_1(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        # The first send() raises, so there is no result to compare.
        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            foo().send(None)
| |
| |
# Run the whole suite when executed directly as a script.
if __name__ == '__main__':
    unittest.main()