blob: a3c8bfb9eca758009651cc199ab6a6a88f752044 [file] [log] [blame]
Andrew Svetlov4dd3e3f2019-05-29 12:33:59 +03001import asyncio
2import inspect
3
4from .case import TestCase
5
6
7
8class IsolatedAsyncioTestCase(TestCase):
9 # Names intentionally have a long prefix
10 # to reduce a chance of clashing with user-defined attributes
11 # from inherited test case
12 #
13 # The class doesn't call loop.run_until_complete(self.setUp()) and family
14 # but uses a different approach:
15 # 1. create a long-running task that reads self.setUp()
16 # awaitable from queue along with a future
17 # 2. await the awaitable object passing in and set the result
18 # into the future object
19 # 3. Outer code puts the awaitable and the future object into a queue
20 # with waiting for the future
21 # The trick is necessary because every run_until_complete() call
22 # creates a new task with embedded ContextVar context.
23 # To share contextvars between setUp(), test and tearDown() we need to execute
24 # them inside the same task.
25
26 # Note: the test case modifies event loop policy if the policy was not instantiated
27 # yet.
28 # asyncio.get_event_loop_policy() creates a default policy on demand but never
29 # returns None
30 # I believe this is not an issue in user level tests but python itself for testing
31 # should reset a policy in every test module
32 # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
33
34 def __init__(self, methodName='runTest'):
35 super().__init__(methodName)
36 self._asyncioTestLoop = None
37 self._asyncioCallsQueue = None
38
39 async def asyncSetUp(self):
40 pass
41
42 async def asyncTearDown(self):
43 pass
44
45 def addAsyncCleanup(self, func, /, *args, **kwargs):
46 # A trivial trampoline to addCleanup()
47 # the function exists because it has a different semantics
48 # and signature:
49 # addCleanup() accepts regular functions
50 # but addAsyncCleanup() accepts coroutines
51 #
52 # We intentionally don't add inspect.iscoroutinefunction() check
53 # for func argument because there is no way
54 # to check for async function reliably:
55 # 1. It can be "async def func()" iself
56 # 2. Class can implement "async def __call__()" method
57 # 3. Regular "def func()" that returns awaitable object
58 self.addCleanup(*(func, *args), **kwargs)
59
60 def _callSetUp(self):
61 self.setUp()
62 self._callAsync(self.asyncSetUp)
63
64 def _callTestMethod(self, method):
65 self._callMaybeAsync(method)
66
67 def _callTearDown(self):
68 self._callAsync(self.asyncTearDown)
69 self.tearDown()
70
71 def _callCleanup(self, function, *args, **kwargs):
72 self._callMaybeAsync(function, *args, **kwargs)
73
74 def _callAsync(self, func, /, *args, **kwargs):
75 assert self._asyncioTestLoop is not None
76 ret = func(*args, **kwargs)
77 assert inspect.isawaitable(ret)
78 fut = self._asyncioTestLoop.create_future()
79 self._asyncioCallsQueue.put_nowait((fut, ret))
80 return self._asyncioTestLoop.run_until_complete(fut)
81
82 def _callMaybeAsync(self, func, /, *args, **kwargs):
83 assert self._asyncioTestLoop is not None
84 ret = func(*args, **kwargs)
85 if inspect.isawaitable(ret):
86 fut = self._asyncioTestLoop.create_future()
87 self._asyncioCallsQueue.put_nowait((fut, ret))
88 return self._asyncioTestLoop.run_until_complete(fut)
89 else:
90 return ret
91
92 async def _asyncioLoopRunner(self):
93 queue = self._asyncioCallsQueue
94 while True:
95 query = await queue.get()
96 queue.task_done()
97 if query is None:
98 return
99 fut, awaitable = query
100 try:
101 ret = await awaitable
102 if not fut.cancelled():
103 fut.set_result(ret)
104 except asyncio.CancelledError:
105 raise
106 except Exception as ex:
107 if not fut.cancelled():
108 fut.set_exception(ex)
109
110 def _setupAsyncioLoop(self):
111 assert self._asyncioTestLoop is None
112 loop = asyncio.new_event_loop()
113 asyncio.set_event_loop(loop)
114 loop.set_debug(True)
115 self._asyncioTestLoop = loop
116 self._asyncioCallsQueue = asyncio.Queue(loop=loop)
117 self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner())
118
119 def _tearDownAsyncioLoop(self):
120 assert self._asyncioTestLoop is not None
121 loop = self._asyncioTestLoop
122 self._asyncioTestLoop = None
123 self._asyncioCallsQueue.put_nowait(None)
124 loop.run_until_complete(self._asyncioCallsQueue.join())
125
126 try:
127 # cancel all tasks
128 to_cancel = asyncio.all_tasks(loop)
129 if not to_cancel:
130 return
131
132 for task in to_cancel:
133 task.cancel()
134
135 loop.run_until_complete(
136 asyncio.gather(*to_cancel, loop=loop, return_exceptions=True))
137
138 for task in to_cancel:
139 if task.cancelled():
140 continue
141 if task.exception() is not None:
142 loop.call_exception_handler({
143 'message': 'unhandled exception during test shutdown',
144 'exception': task.exception(),
145 'task': task,
146 })
147 # shutdown asyncgens
148 loop.run_until_complete(loop.shutdown_asyncgens())
149 finally:
150 asyncio.set_event_loop(None)
151 loop.close()
152
153 def run(self, result=None):
154 self._setupAsyncioLoop()
155 try:
156 return super().run(result)
157 finally:
158 self._tearDownAsyncioLoop()