blob: 3ba2f2383c81b46f92e179ef1acdf281c31e8a35 [file] [log] [blame]
Yury Selivanov75445082015-05-11 22:57:16 -04001import contextlib
Yury Selivanov8085b802015-05-18 12:50:52 -04002import inspect
Yury Selivanov75445082015-05-11 22:57:16 -04003import sys
4import types
5import unittest
6import warnings
7from test import support
8
9
class AsyncYieldFrom:
    """Awaitable that re-yields every item of ``obj`` to whoever drives it."""

    def __init__(self, obj):
        # The iterable whose items are surfaced, one per suspension.
        self.obj = obj

    def __await__(self):
        source = self.obj
        yield from source
16
17
class AsyncYield:
    """Awaitable that suspends exactly once, yielding ``value``."""

    def __init__(self, value):
        # The single object handed to the driver on suspension.
        self.value = value

    def __await__(self):
        payload = self.value
        yield payload
24
25
def run_async(coro):
    """Drive *coro* to completion by repeatedly sending ``None``.

    *coro* must be a generator or a native coroutine object.

    Returns a ``(yielded, result)`` tuple: every value the coroutine yielded,
    in order, and the first argument of the final ``StopIteration`` (or
    ``None`` when it carried no arguments).  Any other exception propagates.
    """
    assert coro.__class__ in {types.GeneratorType, types.CoroutineType}

    yielded = []
    while True:
        try:
            value = coro.send(None)
        except StopIteration as stop:
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(value)
38
39
def run_async__await__(coro):
    """Drive a native coroutine through its ``__await__()`` iterator.

    Steps alternate between ``send(None)`` and ``next()``, starting with
    ``send``, so both entry points of the returned iterator are exercised.

    Returns a ``(yielded, result)`` tuple: the values produced along the way
    and the first argument of the final ``StopIteration`` (``None`` if it had
    no arguments).
    """
    assert coro.__class__ is types.CoroutineType
    waiter = coro.__await__()
    collected = []
    use_next = False
    while True:
        try:
            step = next(waiter) if use_next else waiter.send(None)
        except StopIteration as stop:
            return collected, (stop.args[0] if stop.args else None)
        collected.append(step)
        use_next = not use_next
57
58
@contextlib.contextmanager
def silence_coro_gc():
    """Suppress all warnings for the duration of the ``with`` body.

    On exit, a garbage collection is forced while warnings are still being
    ignored.  Abandoned coroutines (for example
    "coroutine ... was never awaited" RuntimeWarnings) therefore warn inside
    the silenced region rather than later in an unrelated test.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            yield
        finally:
            # Collect even when the body raised.  Otherwise coroutines
            # abandoned by the failing body would be finalized after the
            # "ignore" filter has been removed, leaking their warnings.
            support.gc_collect()
65
66
67class AsyncBadSyntaxTest(unittest.TestCase):
68
69 def test_badsyntax_1(self):
Yury Selivanov8fb307c2015-07-22 13:33:45 +030070 with self.assertRaisesRegex(SyntaxError, "'await' outside"):
Yury Selivanov75445082015-05-11 22:57:16 -040071 import test.badsyntax_async1
72
73 def test_badsyntax_2(self):
Yury Selivanov8fb307c2015-07-22 13:33:45 +030074 with self.assertRaisesRegex(SyntaxError, "'await' outside"):
Yury Selivanov75445082015-05-11 22:57:16 -040075 import test.badsyntax_async2
76
77 def test_badsyntax_3(self):
78 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
79 import test.badsyntax_async3
80
81 def test_badsyntax_4(self):
82 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
83 import test.badsyntax_async4
84
85 def test_badsyntax_5(self):
86 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
87 import test.badsyntax_async5
88
89 def test_badsyntax_6(self):
90 with self.assertRaisesRegex(
91 SyntaxError, "'yield' inside async function"):
92
93 import test.badsyntax_async6
94
95 def test_badsyntax_7(self):
96 with self.assertRaisesRegex(
97 SyntaxError, "'yield from' inside async function"):
98
99 import test.badsyntax_async7
100
101 def test_badsyntax_8(self):
102 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
103 import test.badsyntax_async8
104
105 def test_badsyntax_9(self):
Yury Selivanov9dec0352015-06-30 12:49:04 -0400106 ns = {}
107 for comp in {'(await a for a in b)',
108 '[await a for a in b]',
109 '{await a for a in b}',
110 '{await a: c for a in b}'}:
111
Yury Selivanov86cd7d62015-06-30 12:51:12 -0400112 with self.assertRaisesRegex(SyntaxError, 'await.*in comprehen'):
Yury Selivanov9dec0352015-06-30 12:49:04 -0400113 exec('async def f():\n\t{}'.format(comp), ns, ns)
114
Yury Selivanov8fb307c2015-07-22 13:33:45 +0300115 def test_badsyntax_10(self):
116 # Tests for issue 24619
117
118 samples = [
119 """async def foo():
120 def bar(): pass
121 await = 1
122 """,
123
124 """async def foo():
125
126 def bar(): pass
127 await = 1
128 """,
129
130 """async def foo():
131 def bar(): pass
132 if 1:
133 await = 1
134 """,
135
136 """def foo():
137 async def bar(): pass
138 if 1:
139 await a
140 """,
141
142 """def foo():
143 async def bar(): pass
144 await a
145 """,
146
147 """def foo():
148 def baz(): pass
149 async def bar(): pass
150 await a
151 """,
152
153 """def foo():
154 def baz(): pass
155 # 456
156 async def bar(): pass
157 # 123
158 await a
159 """,
160
161 """async def foo():
162 def baz(): pass
163 # 456
164 async def bar(): pass
165 # 123
166 await = 2
167 """,
168
169 """def foo():
170
171 def baz(): pass
172
173 async def bar(): pass
174
175 await a
176 """,
177
178 """async def foo():
179
180 def baz(): pass
181
182 async def bar(): pass
183
184 await = 2
185 """,
186
187 """async def foo():
188 def async(): pass
189 """,
190
191 """async def foo():
192 def await(): pass
193 """,
194
195 """async def foo():
196 def bar():
197 await
198 """,
199
200 """async def foo():
201 return lambda async: await
202 """,
203
204 """async def foo():
205 return lambda a: await
206 """,
207
208 """async def foo(a: await b):
209 pass
210 """,
211
212 """def baz():
213 async def foo(a: await b):
214 pass
215 """,
216
217 """async def foo(async):
218 pass
219 """,
220
221 """async def foo():
222 def bar():
223 def baz():
224 async = 1
225 """,
226
227 """async def foo():
228 def bar():
229 def baz():
230 pass
231 async = 1
232 """,
233
234 """def foo():
235 async def bar():
236
237 async def baz():
238 pass
239
240 def baz():
241 42
242
243 async = 1
244 """,
245
246 """async def foo():
247 def bar():
248 def baz():
249 pass\nawait foo()
250 """,
251
252 """def foo():
253 def bar():
254 async def baz():
255 pass\nawait foo()
256 """,
257
258 """async def foo(await):
259 pass
260 """,
261
262 """def foo():
263
264 async def bar(): pass
265
266 await a
267 """,
268
269 """def foo():
270 async def bar():
271 pass\nawait a
272 """]
273
274 ns = {}
275 for code in samples:
276 with self.subTest(code=code), self.assertRaises(SyntaxError):
277 exec(code, ns, ns)
278
279 def test_goodsyntax_1(self):
280 # Tests for issue 24619
281
282 def foo(await):
283 async def foo(): pass
284 async def foo():
285 pass
286 return await + 1
287 self.assertEqual(foo(10), 11)
288
289 def foo(await):
290 async def foo(): pass
291 async def foo(): pass
292 return await + 2
293 self.assertEqual(foo(20), 22)
294
295 def foo(await):
296
297 async def foo(): pass
298
299 async def foo(): pass
300
301 return await + 2
302 self.assertEqual(foo(20), 22)
303
304 def foo(await):
305 """spam"""
306 async def foo(): \
307 pass
308 # 123
309 async def foo(): pass
310 # 456
311 return await + 2
312 self.assertEqual(foo(20), 22)
313
314 def foo(await):
315 def foo(): pass
316 def foo(): pass
317 async def bar(): return await_
318 await_ = await
319 try:
320 bar().send(None)
321 except StopIteration as ex:
322 return ex.args[0]
323 self.assertEqual(foo(42), 42)
324
325 async def f():
326 async def g(): pass
327 await z
328 self.assertTrue(inspect.iscoroutinefunction(f))
329
Yury Selivanov75445082015-05-11 22:57:16 -0400330
class TokenizerRegrTest(unittest.TestCase):
    """Tokenizer regression tests for long runs of one-line definitions."""

    def test_oneline_defs(self):
        # 500 consecutive one-line defs must compile on their own...
        source = '\n'.join('def i{i}(): return {i}'.format(i=i)
                           for i in range(500))

        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # ...and must still compile when one 'async def' follows them.
        source += '\nasync def foo():\n return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
351
352
class CoroutineTest(unittest.TestCase):
    """Runtime behaviour of native coroutines.

    Covers creation, ``await``, ``async with`` and ``async for``.
    """

    def test_gen_1(self):
        def gen(): yield
        self.assertFalse(hasattr(gen, '__await__'))

    def test_func_1(self):
        async def foo():
            return 10

        f = foo()
        self.assertIsInstance(f, types.CoroutineType)
        self.assertTrue(bool(foo.__code__.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(foo.__code__.co_flags & inspect.CO_GENERATOR))
        self.assertTrue(bool(f.cr_code.co_flags & inspect.CO_COROUTINE))
        self.assertFalse(bool(f.cr_code.co_flags & inspect.CO_GENERATOR))
        self.assertEqual(run_async(f), ([], 10))

        self.assertEqual(run_async__await__(foo()), ([], 10))

        def bar(): pass
        self.assertFalse(bool(bar.__code__.co_flags & inspect.CO_COROUTINE))

    def test_func_2(self):
        async def foo():
            raise StopIteration

        with self.assertRaisesRegex(
            RuntimeError, "coroutine raised StopIteration"):

            run_async(foo())

    def test_func_3(self):
        async def foo():
            raise StopIteration

        with silence_coro_gc():
            self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')

    def test_func_4(self):
        async def foo():
            raise StopIteration

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            list(foo())

        with check():
            tuple(foo())

        with check():
            sum(foo())

        with check():
            iter(foo())

        with silence_coro_gc(), check():
            for i in foo():
                pass

        with silence_coro_gc(), check():
            [i for i in foo()]

    def test_func_5(self):
        @types.coroutine
        def bar():
            yield 1

        async def foo():
            await bar()

        check = lambda: self.assertRaisesRegex(
            TypeError, "'coroutine' object is not iterable")

        with check():
            for el in foo(): pass

        # the following should pass without an error
        for el in bar():
            self.assertEqual(el, 1)
        self.assertEqual([el for el in bar()], [1])
        self.assertEqual(tuple(bar()), (1,))
        self.assertEqual(next(iter(bar())), 1)

    def test_func_6(self):
        @types.coroutine
        def bar():
            yield 1
            yield 2

        async def foo():
            await bar()

        f = foo()
        self.assertEqual(f.send(None), 1)
        self.assertEqual(f.send(None), 2)
        with self.assertRaises(StopIteration):
            f.send(None)

    def test_func_7(self):
        async def bar():
            return 10

        def foo():
            yield from bar()

        with silence_coro_gc(), self.assertRaisesRegex(
            TypeError,
            "cannot 'yield from' a coroutine object in a non-coroutine generator"):

            list(foo())

    def test_func_8(self):
        @types.coroutine
        def bar():
            return (yield from foo())

        async def foo():
            return 'spam'

        self.assertEqual(run_async(bar()), ([], 'spam'))

    def test_func_9(self):
        async def foo(): pass

        with self.assertWarnsRegex(
            RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            support.gc_collect()

    def test_func_10(self):
        N = 0

        @types.coroutine
        def gen():
            nonlocal N
            try:
                a = yield
                yield (a ** 2)
            except ZeroDivisionError:
                N += 100
                raise
            finally:
                N += 1

        async def foo():
            await gen()

        coro = foo()
        aw = coro.__await__()
        self.assertIs(aw, iter(aw))
        next(aw)
        self.assertEqual(aw.send(10), 100)

        self.assertEqual(N, 0)
        aw.close()
        self.assertEqual(N, 1)

        coro = foo()
        aw = coro.__await__()
        next(aw)
        with self.assertRaises(ZeroDivisionError):
            aw.throw(ZeroDivisionError, None, None)
        self.assertEqual(N, 102)

    def test_func_11(self):
        async def func(): pass
        coro = func()
        # Test that PyCoro_Type and _PyCoroWrapper_Type types were properly
        # initialized
        self.assertIn('__await__', dir(coro))
        self.assertIn('__iter__', dir(coro.__await__()))
        self.assertIn('coroutine_wrapper', repr(coro.__await__()))
        coro.close()  # avoid RuntimeWarning

    def test_func_12(self):
        async def g():
            i = me.send(None)
            await foo
        me = g()
        with self.assertRaisesRegex(ValueError,
                                    "coroutine already executing"):
            me.send(None)

    def test_func_13(self):
        async def g():
            pass
        with self.assertRaisesRegex(
            TypeError,
            "can't send non-None value to a just-started coroutine"):

            g().send('spam')

    def test_func_14(self):
        @types.coroutine
        def gen():
            yield
        async def coro():
            try:
                await gen()
            except GeneratorExit:
                await gen()
        c = coro()
        c.send(None)
        with self.assertRaisesRegex(RuntimeError,
                                    "coroutine ignored GeneratorExit"):
            c.close()

    def test_cr_await(self):
        @types.coroutine
        def a():
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)
            yield
            self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_RUNNING)
            self.assertIsNone(coro_b.cr_await)

        async def c():
            await a()

        async def b():
            self.assertIsNone(coro_b.cr_await)
            await c()
            self.assertIsNone(coro_b.cr_await)

        coro_b = b()
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CREATED)
        self.assertIsNone(coro_b.cr_await)

        coro_b.send(None)
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_SUSPENDED)
        self.assertEqual(coro_b.cr_await.cr_await.gi_code.co_name, 'a')

        with self.assertRaises(StopIteration):
            coro_b.send(None)  # complete coroutine
        self.assertEqual(inspect.getcoroutinestate(coro_b), inspect.CORO_CLOSED)
        self.assertIsNone(coro_b.cr_await)

    def test_corotype_1(self):
        ct = types.CoroutineType
        self.assertIn('into coroutine', ct.send.__doc__)
        self.assertIn('inside coroutine', ct.close.__doc__)
        self.assertIn('in coroutine', ct.throw.__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__name__'].__doc__)
        self.assertIn('of the coroutine', ct.__dict__['__qualname__'].__doc__)
        self.assertEqual(ct.__name__, 'coroutine')

        async def f(): pass
        c = f()
        self.assertIn('coroutine object', repr(c))
        c.close()

    def test_await_1(self):

        async def foo():
            await 1
        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
            run_async(foo())

    def test_await_2(self):
        async def foo():
            await []
        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
            run_async(foo())

    def test_await_3(self):
        async def foo():
            await AsyncYieldFrom([1, 2, 3])

        self.assertEqual(run_async(foo()), ([1, 2, 3], None))
        self.assertEqual(run_async__await__(foo()), ([1, 2, 3], None))

    def test_await_4(self):
        async def bar():
            return 42

        async def foo():
            return await bar()

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_5(self):
        class Awaitable:
            def __await__(self):
                return

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_await_6(self):
        class Awaitable:
            def __await__(self):
                return iter([52])

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([52], None))

    def test_await_7(self):
        class Awaitable:
            def __await__(self):
                yield 42
                return 100

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([42], 100))

    def test_await_8(self):
        class Awaitable:
            pass

        async def foo(): return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "object Awaitable can't be used in 'await' expression"):

            run_async(foo())

    def test_await_9(self):
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            b = bar()

            db = {'b': lambda: wrap}

            class DB:
                b = wrap

            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))

    def test_await_10(self):
        async def baz():
            return 42

        async def bar():
            return baz()

        async def foo():
            return await (await bar())

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_11(self):
        def ident(val):
            return val

        async def bar():
            return 'spam'

        async def foo():
            return ident(val=await bar())

        async def foo2():
            return await bar(), 'ham'

        # 'await' as a keyword-argument value; previously defined but never
        # actually exercised.
        self.assertEqual(run_async(foo()), ([], 'spam'))
        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))

    def test_await_12(self):
        async def coro():
            return 'spam'

        class Awaitable:
            def __await__(self):
                return coro()

        async def foo():
            return await Awaitable()

        # Raw string: '\(' is a regex escape, not a valid string escape.
        with self.assertRaisesRegex(
            TypeError, r"__await__\(\) returned a coroutine"):

            run_async(foo())

    def test_await_13(self):
        class Awaitable:
            def __await__(self):
                return self

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_await_14(self):
        class Wrapper:
            # Forces the interpreter to use CoroutineType.__await__
            def __init__(self, coro):
                assert coro.__class__ is types.CoroutineType
                self.coro = coro
            def __await__(self):
                return self.coro.__await__()

        class FutureLike:
            def __await__(self):
                return (yield)

        class Marker(Exception):
            pass

        async def coro1():
            try:
                return await FutureLike()
            except ZeroDivisionError:
                raise Marker
        async def coro2():
            return await Wrapper(coro1())

        c = coro2()
        c.send(None)
        with self.assertRaisesRegex(StopIteration, 'spam'):
            c.send('spam')

        c = coro2()
        c.send(None)
        with self.assertRaises(Marker):
            c.throw(ZeroDivisionError)

    def test_with_1(self):
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                if self.name == 'B':
                    return True


        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

    def test_with_2(self):
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_3(self):
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())

    def test_with_4(self):
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())

    def test_with_6(self):
        class CM:
            def __aenter__(self):
                return 123

            def __aexit__(self, *e):
                return 456

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):
            # it's important that __aexit__ wasn't called
            run_async(foo())

    def test_with_7(self):
        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 444

        async def foo():
            async with CM():
                1/0

        try:
            run_async(foo())
        except TypeError as exc:
            self.assertRegex(
                exc.args[0], "object int can't be used in 'await' expression")
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
        else:
            self.fail('invalid asynchronous context manager did not fail')


    def test_with_8(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1


        with self.assertRaisesRegex(
            TypeError, "object int can't be used in 'await' expression"):

            run_async(foo())

        self.assertEqual(CNT, 1)


    def test_with_9(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)

    def test_with_10(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                async with CM():
                    raise RuntimeError

        try:
            run_async(foo())
        except ZeroDivisionError as exc:
            # Outer __aexit__ failure chains onto the inner one, which in
            # turn chains onto the original RuntimeError.
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
            self.assertIsInstance(exc.__context__.__context__, RuntimeError)
        else:
            self.fail('exception from __aexit__ did not propagate')

    def test_with_11(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                raise NotImplementedError

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                raise RuntimeError

        try:
            run_async(foo())
        except NotImplementedError as exc:
            self.assertIsNone(exc.__context__)
        else:
            self.fail('exception from __aenter__ did not propagate')

    def test_with_12(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            async with CM() as cm:
                self.assertIs(cm.__class__, CM)
                raise RuntimeError

        run_async(foo())

    def test_with_13(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)

    def test_for_1(self):
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            async for i1, i2 in AsyncIter():
                buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        buffer = []
        async def test2():
            nonlocal buffer
            async for i in AsyncIter():
                buffer.append(i[0])
                if i[0] == 20:
                    break
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # __aiter__ is called once per loop (aiter_calls is cumulative)
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        buffer = []
        async def test3():
            nonlocal buffer
            async for i in AsyncIter():
                if i[0] > 20:
                    continue
                buffer.append(i[0])
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # __aiter__ is called once per loop (aiter_calls is cumulative)
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])

    def test_for_2(self):
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(tup), refs_before)

    def test_for_3(self):
        class I:
            def __aiter__(self):
                return self

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        # Raw string: '\:' is a regex escape, not a valid string escape.
        with self.assertRaisesRegex(
            TypeError,
            r"async for' received an invalid object.*__aiter.*\: I"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_4(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError,
            "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_5(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            async for i in I():
                print('never going to happen')

        with self.assertRaisesRegex(
            TypeError,
            "async for' received an invalid object.*__anext.*int"):

            run_async(foo())

    def test_for_6(self):
        I = 0

        class Manager:
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        manager = Manager()
        iterable = Iterable()
        mrefs_before = sys.getrefcount(manager)
        irefs_before = sys.getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 111011)

        self.assertEqual(sys.getrefcount(manager), mrefs_before)
        self.assertEqual(sys.getrefcount(iterable), irefs_before)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 333033)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        self.assertEqual(I, 20555255)

    def test_for_7(self):
        CNT = 0
        class AI:
            async def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)
1314
Yury Selivanov75445082015-05-11 22:57:16 -04001315
class CoroAsyncIOCompatTest(unittest.TestCase):
    """Native coroutines and ``async with`` driven by an asyncio loop."""

    def test_asyncio_1(self):
        import asyncio

        class MyException(Exception):
            pass

        buffer = []

        class CM:
            async def __aenter__(self):
                buffer.append(1)
                await asyncio.sleep(0.01)
                buffer.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                buffer.append(exc_type.__name__)

        async def f():
            async with CM() as c:
                await asyncio.sleep(0.01)
                raise MyException
            buffer.append('unreachable')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # __aexit__ returns None, so MyException must propagate out of
            # the 'async with' block and out of run_until_complete().
            # Assert that instead of silently swallowing it.
            with self.assertRaises(MyException):
                loop.run_until_complete(f())
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(buffer, [1, 2, 'MyException'])
1354
1355
class SysSetCoroWrapperTest(unittest.TestCase):
    """Tests for sys.set_coroutine_wrapper() / sys.get_coroutine_wrapper()."""

    def test_set_wrapper_1(self):
        async def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        self.assertIsNone(sys.get_coroutine_wrapper())

        sys.set_coroutine_wrapper(wrap)
        self.assertIs(sys.get_coroutine_wrapper(), wrap)
        try:
            f = foo()
            # wrap() returns its argument unchanged, so the object handed
            # to the wrapper must be exactly what foo() returned.
            self.assertIs(wrapped, f)

            self.assertEqual(run_async(f), ([], 'spam'))
        finally:
            sys.set_coroutine_wrapper(None)

        self.assertIsNone(sys.get_coroutine_wrapper())

        wrapped = None
        with silence_coro_gc():
            foo()
        self.assertFalse(wrapped)

    def test_set_wrapper_2(self):
        self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            sys.set_coroutine_wrapper(1)
        self.assertIsNone(sys.get_coroutine_wrapper())

    def test_set_wrapper_3(self):
        async def foo():
            return 'spam'

        def wrapper(coro):
            async def wrap(coro):
                return await coro
            return wrap(coro)

        sys.set_coroutine_wrapper(wrapper)
        try:
            # Raw strings: '\.' is a regex escape, not a valid string escape.
            with silence_coro_gc(), self.assertRaisesRegex(
                RuntimeError,
                r"coroutine wrapper.*\.wrapper at 0x.*attempted to "
                r"recursively wrap .* wrap .*"):

                foo()
        finally:
            sys.set_coroutine_wrapper(None)

    def test_set_wrapper_4(self):
        @types.coroutine
        def foo():
            return 'spam'

        wrapped = None
        def wrap(gen):
            nonlocal wrapped
            wrapped = gen
            return gen

        sys.set_coroutine_wrapper(wrap)
        try:
            foo()
            self.assertIs(
                wrapped, None,
                "generator-based coroutine was wrapped via "
                "sys.set_coroutine_wrapper")
        finally:
            sys.set_coroutine_wrapper(None)
1433
Yury Selivanov75445082015-05-11 22:57:16 -04001434
class CAPITest(unittest.TestCase):
    """Tests for the C-level await slot, via ``_testcapi.awaitType``."""

    def test_tp_await_1(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        # The send() itself is expected to raise; wrapping it in an
        # assertEqual that can never be evaluated only misled readers.
        with self.assertRaisesRegex(
            TypeError, "__await__.*returned non-iterator of type 'int'"):
            foo().send(None)
1462
1463
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's CLI runner.
    unittest.main()