blob: 07c1cdf8483c08a928f9cf5ceb32d13a42fe89a3 [file] [log] [blame]
Yury Selivanov75445082015-05-11 22:57:16 -04001import contextlib
Serhiy Storchaka609a2e12015-11-12 11:31:51 +02002import copy
Yury Selivanov8085b802015-05-18 12:50:52 -04003import inspect
Serhiy Storchaka609a2e12015-11-12 11:31:51 +02004import pickle
Yury Selivanov75445082015-05-11 22:57:16 -04005import sys
6import types
7import unittest
8import warnings
9from test import support
10
11
class AsyncYieldFrom:
    """Awaitable that re-yields every item of a wrapped iterable.

    ``self.obj`` is delegated to with ``yield from``, so values sent into
    (or thrown into) the awaiting coroutine are forwarded to it as well.
    """

    def __init__(self, obj):
        # The iterable (or generator) whose items are yielded on await.
        self.obj = obj

    def __await__(self):
        yield from self.obj
18
19
class AsyncYield:
    """Awaitable that suspends exactly once, yielding ``value``.

    Awaiting it hands ``value`` to whoever drives the coroutine; the
    await expression itself evaluates to ``None`` once resumed.
    """

    def __init__(self, value):
        # The single value produced on the one suspension point.
        self.value = value

    def __await__(self):
        yield self.value
26
27
def run_async(coro):
    """Drive *coro* to completion by repeatedly sending ``None`` into it.

    *coro* must be a generator or a native coroutine object.

    Returns a ``(yielded, result)`` pair: the list of every value the
    coroutine yielded while suspended, and the value it finally returned
    (``None`` when StopIteration carried no value).  Any exception other
    than StopIteration propagates to the caller unchanged.
    """
    assert coro.__class__ in {types.GeneratorType, types.CoroutineType}

    yielded = []
    while True:
        try:
            value = coro.send(None)
        except StopIteration as stop:
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
40
41
def run_async__await__(coro):
    """Drive a native coroutine to completion through its ``__await__`` iterator.

    The iterator is advanced alternately with ``send(None)`` (first step,
    third step, ...) and ``next()`` (second step, fourth step, ...), so
    both entry points of the object returned by ``__await__`` are used.

    Returns ``(yielded, result)``: the values yielded while suspended and
    the final return value (``None`` when StopIteration carried no value).
    """
    assert coro.__class__ is types.CoroutineType
    wrapper = coro.__await__()
    yielded = []
    use_next = False
    while True:
        try:
            value = next(wrapper) if use_next else wrapper.send(None)
        except StopIteration as stop:
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
        use_next = not use_next
59
60
@contextlib.contextmanager
def silence_coro_gc():
    """Ignore all warnings inside the block, then force a garbage collection.

    Used around code that deliberately creates coroutines which are never
    awaited.  ``support.gc_collect()`` runs while the "ignore" filter is
    still active, so any warnings emitted while those objects are collected
    are suppressed.  The original warning filters are restored on exit.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            yield
        finally:
            # Collect even when the body raised: otherwise leftover
            # coroutine objects would only be finalized after the filters
            # are restored, and their warnings would escape.
            support.gc_collect()
67
68
class AsyncBadSyntaxTest(unittest.TestCase):
    # Compile-time checks for 'async'/'await' handling.
    #
    # test_badsyntax_1..8 import helper modules (test.badsyntax_async1..8)
    # that are expected to fail with the SyntaxError message each test
    # matches.  test_badsyntax_9/10 compile source strings that must be
    # rejected, and test_goodsyntax_1 holds code that must be accepted.

    def test_badsyntax_1(self):
        with self.assertRaisesRegex(SyntaxError, "'await' outside"):
            import test.badsyntax_async1

    def test_badsyntax_2(self):
        with self.assertRaisesRegex(SyntaxError, "'await' outside"):
            import test.badsyntax_async2

    def test_badsyntax_3(self):
        with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
            import test.badsyntax_async3

    def test_badsyntax_4(self):
        with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
            import test.badsyntax_async4

    def test_badsyntax_5(self):
        with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
            import test.badsyntax_async5

    def test_badsyntax_6(self):
        with self.assertRaisesRegex(
            SyntaxError, "'yield' inside async function"):

            import test.badsyntax_async6

    def test_badsyntax_7(self):
        with self.assertRaisesRegex(
            SyntaxError, "'yield from' inside async function"):

            import test.badsyntax_async7

    def test_badsyntax_8(self):
        with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
            import test.badsyntax_async8

    def test_badsyntax_9(self):
        # 'await' inside a generator expression or a list/set/dict
        # comprehension in an 'async def' body must be rejected.  The
        # cases live in a set literal, so they run in arbitrary order.
        ns = {}
        for comp in {'(await a for a in b)',
                     '[await a for a in b]',
                     '{await a for a in b}',
                     '{await a: c for a in b}'}:

            with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'):
                # The body line is indented with a tab ('\t').
                exec('async def f():\n\t{}'.format(comp), ns, ns)

    def test_badsyntax_10(self):
        # Tests for issue 24619
        #
        # Every sample below must fail to compile.  They place 'async' and
        # 'await' as identifiers or expressions next to nested (async) defs,
        # and vary the layout: blank lines, comments, nesting depth, and
        # '\n' escapes that start a new column-0 line mid-string.
        # Presumably the layout itself is what is exercised, so keep these
        # strings exactly as written.

        samples = [
            """async def foo():
                   def bar(): pass
                   await = 1
            """,

            """async def foo():

                   def bar(): pass
                   await = 1
            """,

            """async def foo():
                   def bar(): pass
                   if 1:
                       await = 1
            """,

            """def foo():
                   async def bar(): pass
                   if 1:
                       await a
            """,

            """def foo():
                   async def bar(): pass
                   await a
            """,

            """def foo():
                   def baz(): pass
                   async def bar(): pass
                   await a
            """,

            """def foo():
                   def baz(): pass
                   # 456
                   async def bar(): pass
                   # 123
                   await a
            """,

            """async def foo():
                   def baz(): pass
                   # 456
                   async def bar(): pass
                   # 123
                   await = 2
            """,

            """def foo():

                   def baz(): pass

                   async def bar(): pass

                   await a
            """,

            """async def foo():

                   def baz(): pass

                   async def bar(): pass

                   await = 2
            """,

            """async def foo():
                   def async(): pass
            """,

            """async def foo():
                   def await(): pass
            """,

            """async def foo():
                   def bar():
                       await
            """,

            """async def foo():
                   return lambda async: await
            """,

            """async def foo():
                   return lambda a: await
            """,

            """await a()""",

            """async def foo(a=await b):
                   pass
            """,

            """async def foo(a:await b):
                   pass
            """,

            """def baz():
                   async def foo(a=await b):
                       pass
            """,

            """async def foo(async):
                   pass
            """,

            """async def foo():
                   def bar():
                        def baz():
                            async = 1
            """,

            """async def foo():
                   def bar():
                        def baz():
                            pass
                        async = 1
            """,

            """def foo():
                   async def bar():

                        async def baz():
                            pass

                        def baz():
                            42

                        async = 1
            """,

            """async def foo():
                   def bar():
                        def baz():
                            pass\nawait foo()
            """,

            """def foo():
                   def bar():
                        async def baz():
                            pass\nawait foo()
            """,

            """async def foo(await):
                   pass
            """,

            """def foo():

                   async def bar(): pass

                   await a
            """,

            """def foo():
                   async def bar():
                        pass\nawait a
            """]

        for code in samples:
            # subTest reports the offending sample if one compiles cleanly.
            with self.subTest(code=code), self.assertRaises(SyntaxError):
                compile(code, "<test>", "exec")

    def test_goodsyntax_1(self):
        # Tests for issue 24619
        #
        # Counterpart of test_badsyntax_10: this code must compile.  Here
        # 'await' is a plain identifier (parameter / variable) of ordinary
        # functions placed next to nested 'async def's with varied layout.
        # NOTE(review): relies on 'await' not being a reserved keyword in
        # the interpreter this file targets; confirm before running it on
        # a newer version.

        def foo(await):
            async def foo(): pass
            async def foo():
                pass
            return await + 1
        self.assertEqual(foo(10), 11)

        def foo(await):
            async def foo(): pass
            async def foo(): pass
            return await + 2
        self.assertEqual(foo(20), 22)

        def foo(await):

            async def foo(): pass

            async def foo(): pass

            return await + 2
        self.assertEqual(foo(20), 22)

        def foo(await):
            """spam"""
            async def foo(): \
                pass
            # 123
            async def foo(): pass
            # 456
            return await + 2
        self.assertEqual(foo(20), 22)

        def foo(await):
            def foo(): pass
            def foo(): pass
            async def bar(): return await_
            await_ = await
            try:
                bar().send(None)
            except StopIteration as ex:
                return ex.args[0]
        self.assertEqual(foo(42), 42)

        # Inside 'async def f' the 'await z' line is an await expression
        # (f is only compiled, never run, so z need not exist); the
        # 'await = 1' right after f's body is a plain assignment again.
        async def f():
            async def g(): pass
            await z
        await = 1
        self.assertTrue(inspect.iscoroutinefunction(f))
337
Yury Selivanov75445082015-05-11 22:57:16 -0400338
class TokenizerRegrTest(unittest.TestCase):
    """Regression tests for tokenizing long runs of definitions."""

    def test_oneline_defs(self):
        source = '\n'.join('def i{i}(): return {i}'.format(i=i)
                           for i in range(500))

        # 500 consecutive one-line defs must compile and run.
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # The same 500 defs followed by a single 'async def' must compile
        # too, and the trailing function must be a coroutine function.
        source = source + '\nasync def foo():\n    return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
359
360
class CoroutineTest(unittest.TestCase):
    """Runtime behavior of native coroutine objects.

    Covers code flags, iteration/await protocol errors, ``await`` on many
    kinds of awaitables, ``async with`` / ``async for`` protocol handling,
    and copy/pickle refusal.  Coroutines are driven by hand through the
    module helpers ``run_async`` / ``run_async__await__`` rather than an
    event loop.

    Coroutine objects that a test never starts are closed explicitly (or
    created under ``silence_coro_gc``): an unstarted coroutine that is
    simply dropped emits a "was never awaited" RuntimeWarning
    (see test_func_9).
    """

    def test_gen_1(self):
        # Plain generator functions do not get an __await__ attribute.
        def gen(): yield
        self.assertFalse(hasattr(gen, '__await__'))

    def test_func_1(self):
        async def foo():
            return 10

        f = foo()
        self.assertIsInstance(f, types.CoroutineType)
        # 'async def' sets CO_COROUTINE (and not CO_GENERATOR) on both the
        # function's code object and the coroutine's cr_code.
        self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
        self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
        self.assertEqual(run_async(f), ([], 10))

        self.assertEqual(run_async__await__(foo()), ([], 10))

        def bar(): pass
        self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))

    def test_func_2(self):
        # StopIteration escaping a coroutine body becomes a RuntimeError.
        async def foo():
            raise StopIteration

        with self.assertRaisesRegex(
            RuntimeError, "coroutine raised StopIteration"):

            run_async(foo())

    def test_func_3(self):
        async def foo():
            raise StopIteration

        # The coroutine is never started; close it so it does not warn.
        coro = foo()
        try:
            self.assertRegex(repr(coro), '^<coroutine object.* at 0x.*>$')
        finally:
            coro.close()

    def test_func_4(self):
        # Coroutine objects are not iterable through any of the usual
        # iteration entry points.
        async def foo():
            raise StopIteration

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        # None of these operations starts the coroutine, so one object can
        # be reused and then closed explicitly instead of leaking several
        # never-awaited coroutines.
        coro = foo()
        try:
            with check():
                list(coro)

            with check():
                tuple(coro)

            with check():
                sum(coro)

            with check():
                iter(coro)

            with check():
                for i in coro:
                    pass

            with check():
                [i for i in coro]
        finally:
            coro.close()

    def test_func_5(self):
        # A types.coroutine generator stays iterable; a native coroutine
        # awaiting it does not.
        @types.coroutine
        def bar():
            yield 1

        async def foo():
            await bar()

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        coro = foo()
        try:
            with check():
                for el in coro: pass
        finally:
            coro.close()

        # the following should pass without an error
        for el in bar():
            self.assertEqual(el, 1)
        self.assertEqual([el for el in bar()], [1])
        self.assertEqual(tuple(bar()), (1,))
        self.assertEqual(next(iter(bar())), 1)

    def test_func_6(self):
        # Values yielded by an awaited types.coroutine generator surface
        # through the outer coroutine's send().
        @types.coroutine
        def bar():
            yield 1
            yield 2

        async def foo():
            await bar()

        f = foo()
        self.assertEqual(f.send(None), 1)
        self.assertEqual(f.send(None), 2)
        with self.assertRaises(StopIteration):
            f.send(None)

    def test_func_7(self):
        # 'yield from' a coroutine is rejected in a plain generator
        # (compare test_func_8, where the generator is types.coroutine).
        async def bar():
            return 10

        def foo():
            yield from bar()

        with silence_coro_gc(), self.assertRaisesRegex(
            TypeError,
            "cannot 'yield from' a coroutine object in a non-coroutine generator"):

            list(foo())

    def test_func_8(self):
        @types.coroutine
        def bar():
            return (yield from foo())

        async def foo():
            return 'spam'

        self.assertEqual(run_async(bar()), ([], 'spam'))

    def test_func_9(self):
        # Dropping a coroutine that was never started emits a warning.
        async def foo(): pass

        with self.assertWarnsRegex(
            RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            support.gc_collect()

    def test_func_10(self):
        N = 0

        @types.coroutine
        def gen():
            nonlocal N
            try:
                a = yield
                yield (a ** 2)
            except ZeroDivisionError:
                N += 100
                raise
            finally:
                N += 1

        async def foo():
            await gen()

        # The __await__ wrapper is its own iterator and forwards send()
        # and close() to the awaited generator...
        coro = foo()
        aw = coro.__await__()
        self.assertIs(aw, iter(aw))
        next(aw)
        self.assertEqual(aw.send(10), 100)

        self.assertEqual(N, 0)
        aw.close()
        self.assertEqual(N, 1)

        # ...and throw() as well (three-argument form).
        coro = foo()
        aw = coro.__await__()
        next(aw)
        with self.assertRaises(ZeroDivisionError):
            aw.throw(ZeroDivisionError, None, None)
        self.assertEqual(N, 102)

    def test_func_11(self):
        async def func(): pass
        coro = func()
        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
        # initialized
        self.assertIn('__await__', dir(coro))
        self.assertIn('__iter__', dir(coro.__await__()))
        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
        coro.close() # avoid RuntimeWarning

    def test_func_12(self):
        # A running coroutine cannot re-enter itself via send().
        async def g():
            me.send(None)
            await foo
        me = g()
        with self.assertRaisesRegex(ValueError,
                                    "coroutine already executing"):
            me.send(None)

    def test_func_13(self):
        async def g():
            pass

        # The failed send() does not start the coroutine; close it so it
        # does not warn when collected.
        coro = g()
        try:
            with self.assertRaisesRegex(
                TypeError,
                "can't send non-None value to a just-started coroutine"):

                coro.send('spam')
        finally:
            coro.close()

    def test_func_14(self):
        # Awaiting again while handling GeneratorExit makes close() fail.
        @types.coroutine
        def gen():
            yield
        async def coro():
            try:
                await gen()
            except GeneratorExit:
                await gen()
        c = coro()
        c.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine ignored GeneratorExit"):
            c.close()

    def test_cr_await(self):
        # cr_await points at the awaited object only while suspended.
        @types.coroutine
        def a():
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)
            yield
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)

        async def c():
            await a()

        async def b():
            self.assertIsNone(coro_b.cr_await)
            await c()
            self.assertIsNone(coro_b.cr_await)

        coro_b = b()
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
        self.assertIsNone(coro_b.cr_await)

        coro_b.send(None)
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
        self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')

        with self.assertRaises(StopIteration):
            coro_b.send(None)  # complete coroutine
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
        self.assertIsNone(coro_b.cr_await)

    def test_corotype_1(self):
        ct = types.CoroutineType
        self.assertIn('into coroutine', ct.send.__doc__)
        self.assertIn('inside coroutine', ct.close.__doc__)
        self.assertIn('in coroutine', ct.throw.__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
        self.assertEqual(ct.__name__, 'coroutine')

        async def f(): pass
        c = f()
        self.assertIn('coroutine object', repr(c))
        c.close()

    def test_await_1(self):

        async def foo():
            await 1
        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
            run_async(foo())

    def test_await_2(self):
        async def foo():
            await []
        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
            run_async(foo())

    def test_await_3(self):
        async def foo():
            await AsyncYieldFrom([1, 2, 3])

        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))

    def test_await_4(self):
        async def bar():
            return 42

        async def foo():
            return await bar()

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_5(self):
        # __await__ must return an iterator, not None.
        class Awaitable:
            def __await__(self):
                return

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_await_6(self):
        class Awaitable:
            def __await__(self):
                return iter([52])

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([52], None))

    def test_await_7(self):
        class Awaitable:
            def __await__(self):
                yield 42
                return 100

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([42], 100))

    def test_await_8(self):
        class Awaitable:
            pass

        async def foo(): return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "object Awaitable can't be used in 'await' expression"):

            run_async(foo())

    def test_await_9(self):
        # 'await' binds correctly inside larger arithmetic expressions.
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            b = bar()
            # b is never awaited; close it so it does not emit a
            # "never awaited" RuntimeWarning when collected.
            b.close()

            db = {'b': lambda: wrap}

            class DB:
                b = wrap

            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))

    def test_await_10(self):
        async def baz():
            return 42

        async def bar():
            return baz()

        async def foo():
            return await (await bar())

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_11(self):
        # 'await' is allowed as a keyword-argument value and in a tuple.
        def ident(val):
            return val

        async def bar():
            return 'spam'

        async def foo():
            return ident(val=await bar())

        async def foo2():
            return await bar(), 'ham'

        self.assertEqual(run_async(foo()), ([], 'spam'))
        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))

    def test_await_12(self):
        # __await__ must not hand back a bare coroutine object.
        async def coro():
            return 'spam'

        class Awaitable:
            def __await__(self):
                return coro()

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, r"__await__\(\) returned a coroutine"):

            run_async(foo())

    def test_await_13(self):
        class Awaitable:
            def __await__(self):
                return self

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_await_14(self):
        class Wrapper:
            # Forces the interpreter to use CoroutineType.__await__
            def __init__(self, coro):
                assert coro.__class__ is types.CoroutineType
                self.coro = coro
            def __await__(self):
                return self.coro.__await__()

        class FutureLike:
            def __await__(self):
                return (yield)

        class Marker(Exception):
            pass

        async def coro1():
            try:
                return await FutureLike()
            except ZeroDivisionError:
                raise Marker
        async def coro2():
            return await Wrapper(coro1())

        # send() is forwarded through the wrapper to FutureLike...
        c = coro2()
        c.send(None)
        with self.assertRaisesRegex(StopIteration, 'spam'):
            c.send('spam')

        # ...and so is throw().
        c = coro2()
        c.send(None)
        with self.assertRaises(Marker):
            c.throw(ZeroDivisionError)

    def test_with_1(self):
        # Nested managers enter in order and exit in reverse; only
        # manager 'B' suppresses the exception.
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

    def test_with_2(self):
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_3(self):
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())

    def test_with_4(self):
        # A synchronous context manager is not accepted by 'async with'.
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())

    def test_with_6(self):
        class CM:
            def __aenter__(self):
                return 123

            def __aexit__(self, *e):
                return 456

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):
            # it's important that __aexit__ wasn't called
            run_async(foo())

    def test_with_7(self):
        # A non-awaitable __aexit__ result raises TypeError, chained to
        # the exception raised in the body.
        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 444

        async def foo():
            async with CM():
                1/0

        try:
            run_async(foo())
        except TypeError as exc:
            self.assertRegex(
                exc.args[0], "object int can't be used in 'await' expression")
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
        else:
            self.fail('invalid asynchronous context manager did not fail')


    def test_with_8(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1


        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):

            run_async(foo())

        # The body ran once before the bad __aexit__ result was detected.
        self.assertEqual(CNT, 1)


    def test_with_9(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)

    def test_with_10(self):
        # Exceptions from nested __aexit__ calls chain through __context__.
        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            async with CM():
                async with CM():
                    raise RuntimeError

        try:
            run_async(foo())
        except ZeroDivisionError as exc:
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
            self.assertIsInstance(exc.__context__.__context__, RuntimeError)
        else:
            self.fail('exception from __aexit__ did not propagate')

    def test_with_11(self):
        # If __aenter__ fails, neither the body nor __aexit__ runs.
        class CM:
            async def __aenter__(self):
                raise NotImplementedError

            async def __aexit__(self, *e):
                1/0

        async def foo():
            async with CM():
                raise RuntimeError

        try:
            run_async(foo())
        except NotImplementedError as exc:
            self.assertIsNone(exc.__context__)
        else:
            self.fail('exception from __aenter__ did not propagate')

    def test_with_12(self):
        # A truthy __aexit__ result suppresses the body's exception.
        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                return True

        async def foo():
            async with CM() as cm:
                self.assertIs(cm.__class__, CM)
                raise RuntimeError

        run_async(foo())

    def test_with_13(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        # __aexit__ cannot suppress an exception raised by __aenter__.
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)

    def test_for_1(self):
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            async for i1, i2 in AsyncIter():
                buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # __aiter__ is called exactly once per 'async for' loop; the
        # counter is cumulative across test1/test2/test3.
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        buffer = []
        async def test2():
            nonlocal buffer
            async for i in AsyncIter():
                buffer.append(i[0])
                if i[0] == 20:
                    break
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # One more loop, one more __aiter__ call; 'break' skips the else.
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, list(range(1, 21)) + ['end'])


        buffer = []
        async def test3():
            nonlocal buffer
            async for i in AsyncIter():
                if i[0] > 20:
                    continue
                buffer.append(i[0])
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # One more loop, one more __aiter__ call; normal exhaustion runs
        # the else clause.
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, list(range(1, 21)) +
                                 ['what?', 'end'])

    def test_for_2(self):
        # The failed 'async for' must not leak a reference to its operand.
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(tup), refs_before)

    def test_for_3(self):
        class I:
            def __aiter__(self):
                return self

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError,
            r"async for' received an invalid object.*__aiter.*\: I"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_4(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError,
            "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_5(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            async for i in I():
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError,
            "async for' received an invalid object.*__anext.*int"):

            run_async(foo())

    def test_for_6(self):
        # 'async for' nested in 'async with': I accumulates a distinct
        # decimal digit per event so the final value encodes the counts.
        I = 0

        class Manager:
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        manager = Manager()
        iterable = Iterable()
        mrefs_before = sys.getrefcount(manager)
        irefs_before = sys.getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 111011)

        self.assertEqual(sys.getrefcount(manager), mrefs_before)
        self.assertEqual(sys.getrefcount(iterable), irefs_before)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 333033)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        self.assertEqual(I, 20555255)

    def test_for_7(self):
        # An exception from __aiter__ aborts the loop before any body runs.
        CNT = 0
        class AI:
            async def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)

    def test_copy(self):
        # Neither a coroutine nor its __await__ wrapper can be copied.
        async def func(): pass
        coro = func()
        with self.assertRaises(TypeError):
            copy.copy(coro)

        aw = coro.__await__()
        try:
            with self.assertRaises(TypeError):
                copy.copy(aw)
        finally:
            aw.close()

    def test_pickle(self):
        # Neither a coroutine nor its __await__ wrapper can be pickled,
        # with any protocol.
        async def func(): pass
        coro = func()
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.assertRaises((TypeError, pickle.PicklingError)):
                pickle.dumps(coro, proto)

        aw = coro.__await__()
        try:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.assertRaises((TypeError, pickle.PicklingError)):
                    pickle.dumps(aw, proto)
        finally:
            aw.close()
1350
Yury Selivanov75445082015-05-11 22:57:16 -04001351
class CoroAsyncIOCompatTest(unittest.TestCase):
    """Native 'async with' works when the coroutine runs on an asyncio loop."""

    def test_asyncio_1(self):
        # support.import_module() turns a failed import into a skipped
        # test; asyncio cannot be imported on builds without threads.
        asyncio = support.import_module('asyncio')

        class MyException(Exception):
            pass

        events = []

        class CM:
            async def __aenter__(self):
                events.append(1)
                await asyncio.sleep(0.01)
                events.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                events.append(exc_type.__name__)

        async def f():
            async with CM():
                await asyncio.sleep(0.01)
                raise MyException
            events.append('unreachable')

        # Run on a private loop and always detach it afterwards so no
        # other test sees it as the current event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(f())
        except MyException:
            pass
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        # __aexit__ saw the exception and did not suppress it.
        self.assertEqual(events, [1, 2, 'MyException'])
1392
1393
class SysSetCoroWrapperTest(unittest.TestCase):
    """Behaviour of sys.set_coroutine_wrapper / sys.get_coroutine_wrapper."""

    def test_set_wrapper_1(self):
        async def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            # Record the object handed to the wrapper and return it unchanged.
            nonlocal wrapped
            wrapped = gen
            return gen

        self.assertIsNone(sys.get_coroutine_wrapper())

        sys.set_coroutine_wrapper(wrap)
        self.assertIs(sys.get_coroutine_wrapper(), wrap)
        try:
            f = foo()
            self.assertTrue(wrapped)

            self.assertEqual(run_async(f), ([], 'spam'))
        finally:
            sys.set_coroutine_wrapper(None)

        self.assertIsNone(sys.get_coroutine_wrapper())

        # With the wrapper cleared, creating a coroutine must not call it.
        wrapped = None
        with silence_coro_gc():
            foo()
        self.assertFalse(wrapped)

    def test_set_wrapper_2(self):
        # A non-callable wrapper is rejected and leaves the setting unchanged.
        self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            sys.set_coroutine_wrapper(1)
        self.assertIsNone(sys.get_coroutine_wrapper())

    def test_set_wrapper_3(self):
        async def foo():
            return 'spam'

        def wrapper(coro):
            # Creates a new coroutine while the wrapper is installed; the
            # test expects this to be reported as recursive wrapping.
            async def wrap(coro):
                return await coro
            return wrap(coro)

        sys.set_coroutine_wrapper(wrapper)
        try:
            # Raw strings: the pattern contains a regex escape (`\.`), which
            # is an invalid escape sequence in a plain string literal.
            with silence_coro_gc(), self.assertRaisesRegex(
                RuntimeError,
                r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
                r"recursively wrap .* wrap .*"):

                foo()
        finally:
            sys.set_coroutine_wrapper(None)

    def test_set_wrapper_4(self):
        # A generator-based coroutine (@types.coroutine) must not be passed
        # to the installed wrapper.
        @types.coroutine
        def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        sys.set_coroutine_wrapper(wrap)
        try:
            foo()
            self.assertIs(
                wrapped, None,
                "generator-based coroutine was wrapped via "
                "sys.set_coroutine_wrapper")
        finally:
            sys.set_coroutine_wrapper(None)
1471
Yury Selivanov75445082015-05-11 22:57:16 -04001472
class CAPITest(unittest.TestCase):
    """Exercise the C-level tp_await slot through _testcapi.awaitType."""

    def test_tp_await_1(self):
        # presumably awaitType(x) is an awaitable whose tp_await returns x
        # (see the error message checked in test_tp_await_3).
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        # The first send() must itself raise: tp_await hands back an int,
        # which is not an iterator, so there is no result to compare.
        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            foo().send(None)
1500
1501
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()