# Source blob: 492b226a0d549b5ca323ae97f7c40d43eb620781
import asyncio
from contextlib import asynccontextmanager, AbstractAsyncContextManager, AsyncExitStack
import functools
from test import support
import unittest

from test.test_contextlib import TestBaseExitStack
Ilya Kulakov1aa094f2018-01-25 12:51:18 -08008
Jelle Zijlstra2e624692017-04-30 18:25:58 -07009
def _async_test(func):
    """Decorator to turn an async function into a test case.

    The returned synchronous wrapper calls *func* to obtain a coroutine,
    runs it to completion on a freshly created event loop and returns its
    result (exceptions raised by the coroutine propagate to the caller).
    The new loop is installed as the current event loop for the duration
    of the call; afterwards it is closed and the event loop policy is
    reset with ``asyncio.set_event_loop_policy(None)``.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Build the coroutine before touching any loop state: if calling
        # func() itself raises, there is nothing to clean up yet.
        coro = func(*args, **kwargs)
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            return loop.run_until_complete(coro)
        finally:
            # Always release the loop, even when the coroutine failed, and
            # reset the policy so the closed loop does not stay installed.
            loop.close()
            asyncio.set_event_loop_policy(None)
    return wrapper
23
24
class TestAbstractAsyncContextManager(unittest.TestCase):
    """Tests for the AbstractAsyncContextManager abstract base class."""

    # NOTE: test methods are documented with comments rather than docstrings
    # so that unittest keeps reporting them by method name.

    @_async_test
    async def test_enter(self):
        # A subclass that only implements __aexit__ relies on the inherited
        # __aenter__; the assertions below check that it returns the manager.
        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        manager = DefaultEnter()
        self.assertIs(await manager.__aenter__(), manager)

        async with manager as context:
            self.assertIs(manager, context)

    @_async_test
    async def test_async_gen_propagates_generator_exit(self):
        # A regression test for https://bugs.python.org/issue33786.

        @asynccontextmanager
        async def ctx():
            yield

        async def gen():
            async with ctx():
                yield 11

        ret = []
        exc = ValueError(22)
        with self.assertRaises(ValueError):
            async with ctx():
                # The loop body raises after the first item, leaving gen()
                # suspended inside its own ``async with ctx()`` block.
                async for val in gen():
                    ret.append(val)
                    raise exc

        self.assertEqual(ret, [11])

    def test_exit_is_abstract(self):
        # Without an __aexit__ implementation the class cannot be
        # instantiated.
        class MissingAexit(AbstractAsyncContextManager):
            pass

        with self.assertRaises(TypeError):
            MissingAexit()

    def test_structural_subclassing(self):
        # A class defining both protocol methods counts as a subclass
        # without inheriting from the ABC.
        class ManagerFromScratch:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_value, traceback):
                return None

        self.assertTrue(issubclass(ManagerFromScratch, AbstractAsyncContextManager))

        class DefaultEnter(AbstractAsyncContextManager):
            async def __aexit__(self, *args):
                await super().__aexit__(*args)

        self.assertTrue(issubclass(DefaultEnter, AbstractAsyncContextManager))

        # Setting either protocol method to None opts the class out again.
        class NoneAenter(ManagerFromScratch):
            __aenter__ = None

        self.assertFalse(issubclass(NoneAenter, AbstractAsyncContextManager))

        class NoneAexit(ManagerFromScratch):
            __aexit__ = None

        self.assertFalse(issubclass(NoneAexit, AbstractAsyncContextManager))
92
93
class AsyncContextManagerTestCase(unittest.TestCase):
    """Tests for the @asynccontextmanager decorator."""

    # NOTE: test methods are documented with comments rather than docstrings
    # so that unittest keeps reporting them by method name.

    @_async_test
    async def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it runs on exit,
        # and the yielded value is bound by ``as``.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            yield 42
            state.append(999)

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_finally(self):
        # The generator's ``finally`` clause runs when the body raises, and
        # the exception still propagates out of the ``async with``.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)

        with self.assertRaises(ZeroDivisionError):
            async with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_no_reraise(self):
        @asynccontextmanager
        async def whee():
            yield

        ctx = whee()
        await ctx.__aenter__()
        # Calling __aexit__ should not result in an exception
        self.assertFalse(await ctx.__aexit__(TypeError, TypeError("foo"), None))

    @_async_test
    async def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after an exception was thrown into
        # it must be reported as a RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except:  # deliberately bare: swallow whatever gets thrown in
                yield

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(TypeError, TypeError('foo'), None)

    @_async_test
    async def test_contextmanager_trap_no_yield(self):
        # A generator that never yields cannot be entered.
        @asynccontextmanager
        async def whoo():
            if False:
                yield

        ctx = whoo()
        with self.assertRaises(RuntimeError):
            await ctx.__aenter__()

    @_async_test
    async def test_contextmanager_trap_second_yield(self):
        # A generator that yields a second time on a normal exit is an error.
        @asynccontextmanager
        async def whoo():
            yield
            yield

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(RuntimeError):
            await ctx.__aexit__(None, None, None)

    @_async_test
    async def test_contextmanager_non_normalised(self):
        # __aexit__ is called with an exception type but no instance
        # (value=None); the generator must still see a RuntimeError.
        @asynccontextmanager
        async def whoo():
            try:
                yield
            except RuntimeError:
                raise SyntaxError

        ctx = whoo()
        await ctx.__aenter__()
        with self.assertRaises(SyntaxError):
            await ctx.__aexit__(RuntimeError, None, None)

    @_async_test
    async def test_contextmanager_except(self):
        # An exception handled inside the generator is suppressed for the
        # caller of ``async with``.
        state = []

        @asynccontextmanager
        async def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])

        async with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    @_async_test
    async def test_contextmanager_except_stopiter(self):
        # StopIteration / StopAsyncIteration raised in the body must come
        # back out as the very same exception object, not be swallowed.
        @asynccontextmanager
        async def woohoo():
            yield

        for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
            with self.subTest(type=type(stop_exc)):
                try:
                    async with woohoo():
                        raise stop_exc
                except Exception as ex:
                    self.assertIs(ex, stop_exc)
                else:
                    self.fail(f'{stop_exc} was suppressed')

    @_async_test
    async def test_contextmanager_wrap_runtimeerror(self):
        @asynccontextmanager
        async def woohoo():
            try:
                yield
            except Exception as exc:
                raise RuntimeError(f'caught {exc}') from exc

        with self.assertRaises(RuntimeError):
            async with woohoo():
                1 / 0

        # If the context manager wrapped StopAsyncIteration in a RuntimeError,
        # we also unwrap it, because we can't tell whether the wrapping was
        # done by the generator machinery or by the generator itself.
        with self.assertRaises(StopAsyncIteration):
            async with woohoo():
                raise StopAsyncIteration

    def _create_contextmanager_attribs(self):
        # Helper: build an @asynccontextmanager-decorated function ``baz``
        # that carries a custom attribute (foo='bar') and a docstring, for
        # the metadata tests below.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate

        @asynccontextmanager
        @attribs(foo='bar')
        async def baz(spam):
            """Whee!"""
            yield

        return baz

    def test_contextmanager_attribs(self):
        # Name and custom attributes of the wrapped function are preserved.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @support.requires_docstrings
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

    @support.requires_docstrings
    @_async_test
    async def test_instance_docstring_given_cm_docstring(self):
        # The context manager *instance* exposes the function's docstring.
        baz = self._create_contextmanager_attribs()(None)
        self.assertEqual(baz.__doc__, "Whee!")
        async with baz:
            pass  # suppress warning

    @_async_test
    async def test_keywords(self):
        # Ensure no keyword arguments are inhibited
        @asynccontextmanager
        async def woohoo(self, func, args, kwds):
            yield (self, func, args, kwds)

        async with woohoo(self=11, func=22, args=33, kwds=44) as target:
            self.assertEqual(target, (11, 22, 33, 44))
280
281
class TestAsyncExitStack(TestBaseExitStack, unittest.TestCase):
    """Tests for AsyncExitStack.

    ``exit_stack`` is set to a synchronous facade over AsyncExitStack.
    NOTE(review): the tests inherited from TestBaseExitStack presumably
    drive ``self.exit_stack`` through the plain ``with`` protocol -- confirm
    in test.test_contextlib.
    """

    class SyncAsyncExitStack(AsyncExitStack):
        """AsyncExitStack exposing close/__enter__/__exit__ synchronously."""

        @staticmethod
        def run_coroutine(coro):
            """Run *coro* on the current event loop and return its result.

            If the coroutine raised, the exception is re-raised with its
            original ``__context__`` preserved.
            """
            loop = asyncio.get_event_loop()

            # Bind the future explicitly to the loop we are about to run.
            f = asyncio.ensure_future(coro, loop=loop)
            f.add_done_callback(lambda f: loop.stop())
            loop.run_forever()

            exc = f.exception()

            if exc is None:
                return f.result()
            else:
                context = exc.__context__

                # Raising may overwrite exc.__context__ with whatever
                # exception is being handled by the caller; raise once,
                # restore the saved context, then raise for real.
                try:
                    raise exc
                except BaseException:
                    exc.__context__ = context
                    raise exc

        def close(self):
            return self.run_coroutine(self.aclose())

        def __enter__(self):
            return self.run_coroutine(self.__aenter__())

        def __exit__(self, *exc_details):
            return self.run_coroutine(self.__aexit__(*exc_details))

    exit_stack = SyncAsyncExitStack

    def setUp(self):
        # Give every test its own event loop (used by run_coroutine above)
        # and undo both the loop and the policy afterwards.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.addCleanup(self.loop.close)
        self.addCleanup(asyncio.set_event_loop_policy, None)

    # NOTE: test methods are documented with comments rather than docstrings
    # so that unittest keeps reporting them by method name.

    @_async_test
    async def test_async_callback(self):
        expected = [
            ((), {}),
            ((1,), {}),
            ((1, 2), {}),
            ((), dict(example=1)),
            ((1,), dict(example=1)),
            ((1, 2), dict(example=1)),
        ]
        result = []

        async def _exit(*args, **kwds):
            """Test metadata propagation"""
            result.append((args, kwds))

        async with AsyncExitStack() as stack:
            # Callbacks run in reverse registration order, so register them
            # reversed to observe ``expected`` order on exit.  The branches
            # deliberately exercise each distinct call shape.
            for args, kwds in reversed(expected):
                if args and kwds:
                    f = stack.push_async_callback(_exit, *args, **kwds)
                elif args:
                    f = stack.push_async_callback(_exit, *args)
                elif kwds:
                    f = stack.push_async_callback(_exit, **kwds)
                else:
                    f = stack.push_async_callback(_exit)
                self.assertIs(f, _exit)
            for wrapper in stack._exit_callbacks:
                self.assertIs(wrapper[1].__wrapped__, _exit)
                self.assertNotEqual(wrapper[1].__name__, _exit.__name__)
                # The original passed _exit.__doc__ as a second argument,
                # which assertIsNone treats as the failure *message*.
                self.assertIsNone(wrapper[1].__doc__)

        self.assertEqual(result, expected)

        result = []
        async with AsyncExitStack() as stack:
            # No callback at all is a TypeError ...
            with self.assertRaises(TypeError):
                stack.push_async_callback(arg=1)
            # ... as is calling the method on the class without an instance.
            with self.assertRaises(TypeError):
                self.exit_stack.push_async_callback(arg=2)
            # Passing the callback by keyword is expected to warn but work.
            with self.assertWarns(DeprecationWarning):
                stack.push_async_callback(callback=_exit, arg=3)
        self.assertEqual(result, [((), {'arg': 3})])

    @_async_test
    async def test_async_push(self):
        exc_raised = ZeroDivisionError

        async def _expect_exc(exc_type, exc, exc_tb):
            self.assertIs(exc_type, exc_raised)

        async def _suppress_exc(*exc_details):
            return True

        async def _expect_ok(exc_type, exc, exc_tb):
            self.assertIsNone(exc_type)
            self.assertIsNone(exc)
            self.assertIsNone(exc_tb)

        class ExitCM:
            def __init__(self, check_exc):
                self.check_exc = check_exc

            async def __aenter__(self):
                # ``self`` is an ExitCM here, not the TestCase, so the
                # original ``self.fail(...)`` would have raised
                # AttributeError; raise the failure directly instead.
                raise AssertionError("Should not be called!")

            async def __aexit__(self, *exc_details):
                await self.check_exc(*exc_details)

        async with self.exit_stack() as stack:
            # Exit callbacks run last-in-first-out: the entries pushed after
            # _suppress_exc see the ZeroDivisionError, the ones pushed before
            # it see a clean exit.
            stack.push_async_exit(_expect_ok)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_ok)
            cm = ExitCM(_expect_ok)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_suppress_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _suppress_exc)
            cm = ExitCM(_expect_exc)
            stack.push_async_exit(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            stack.push_async_exit(_expect_exc)
            self.assertIs(stack._exit_callbacks[-1][1], _expect_exc)
            1 / 0

    @_async_test
    async def test_async_enter_context(self):
        class TestCM:
            async def __aenter__(self):
                result.append(1)

            async def __aexit__(self, *exc_details):
                result.append(3)

        result = []
        cm = TestCM()

        async with AsyncExitStack() as stack:
            @stack.push_async_callback  # Registered first => cleaned up last
            async def _exit():
                result.append(4)
            self.assertIsNotNone(_exit)
            await stack.enter_async_context(cm)
            self.assertIs(stack._exit_callbacks[-1][1].__self__, cm)
            result.append(2)

        self.assertEqual(result, [1, 2, 3, 4])

    @_async_test
    async def test_async_exit_exception_chaining(self):
        # Ensure exception chaining matches the reference behaviour
        async def raise_exc(exc):
            raise exc

        saved_details = None

        async def suppress_exc(*exc_details):
            nonlocal saved_details
            saved_details = exc_details
            return True

        try:
            async with self.exit_stack() as stack:
                stack.push_async_callback(raise_exc, IndexError)
                stack.push_async_callback(raise_exc, KeyError)
                stack.push_async_callback(raise_exc, AttributeError)
                stack.push_async_exit(suppress_exc)
                stack.push_async_callback(raise_exc, ValueError)
                1 / 0
        except IndexError as exc:
            self.assertIsInstance(exc.__context__, KeyError)
            self.assertIsInstance(exc.__context__.__context__, AttributeError)
            # Inner exceptions were suppressed
            self.assertIsNone(exc.__context__.__context__.__context__)
        else:
            self.fail("Expected IndexError, but no exception was raised")
        # Check the inner exceptions
        inner_exc = saved_details[1]
        self.assertIsInstance(inner_exc, ValueError)
        self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
453 self.assertIsInstance(inner_exc.__context__, ZeroDivisionError)
454
455
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()