blob: e79896a9b8e954cd49a1f2b3a6c3b5a4f7dcee9c [file] [log] [blame]
Yury Selivanov75445082015-05-11 22:57:16 -04001import contextlib
Yury Selivanov8085b802015-05-18 12:50:52 -04002import inspect
Yury Selivanov75445082015-05-11 22:57:16 -04003import sys
4import types
5import unittest
6import warnings
7from test import support
8
9
class AsyncYieldFrom:
    """Awaitable that delegates to ``obj`` via ``yield from``.

    Every item produced by ``obj`` is passed straight through to whoever
    drives the awaiting coroutine; the await expression itself evaluates
    to ``None``.
    """

    def __init__(self, obj):
        self.obj = obj

    def __await__(self):
        source = self.obj
        # ``yield from`` (not a for-loop) so send()/throw() are forwarded
        # to ``source`` as well.
        yield from source
16
17
class AsyncYield:
    """Awaitable that suspends exactly once, yielding ``value``.

    Whatever is sent back on resumption is discarded, and the await
    expression evaluates to ``None``.
    """

    def __init__(self, value):
        self.value = value

    def __await__(self):
        emitted = self.value
        yield emitted
24
25
def run_async(coro):
    """Drive ``coro`` to completion by repeatedly sending ``None`` into it.

    Returns a ``(yielded, result)`` pair: ``yielded`` lists every value the
    coroutine yielded along the way, in order, and ``result`` is its return
    value (``None`` when it returned nothing).  Exceptions other than
    ``StopIteration`` propagate to the caller.
    """
    assert coro.__class__ is types.GeneratorType

    yielded = []
    while True:
        try:
            step = coro.send(None)
        except StopIteration as stop:
            # The return value travels as the first StopIteration argument.
            return yielded, (stop.args[0] if stop.args else None)
        yielded.append(step)
38
39
@contextlib.contextmanager
def silence_coro_gc():
    """Ignore every warning inside the managed block, then force a GC pass.

    The collection happens while the "ignore" filter is still installed, so
    warnings emitted when unreferenced coroutine objects are collected
    (e.g. "coroutine ... was never awaited") do not leak out.
    """
    guard = warnings.catch_warnings()
    with guard:
        warnings.simplefilter(action="ignore")
        yield
        # Only reached when the managed block exits normally: an exception
        # raised inside it propagates out of the ``yield`` above.
        support.gc_collect()
46
47
48class AsyncBadSyntaxTest(unittest.TestCase):
49
50 def test_badsyntax_1(self):
51 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
52 import test.badsyntax_async1
53
54 def test_badsyntax_2(self):
55 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
56 import test.badsyntax_async2
57
58 def test_badsyntax_3(self):
59 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
60 import test.badsyntax_async3
61
62 def test_badsyntax_4(self):
63 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
64 import test.badsyntax_async4
65
66 def test_badsyntax_5(self):
67 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
68 import test.badsyntax_async5
69
70 def test_badsyntax_6(self):
71 with self.assertRaisesRegex(
72 SyntaxError, "'yield' inside async function"):
73
74 import test.badsyntax_async6
75
76 def test_badsyntax_7(self):
77 with self.assertRaisesRegex(
78 SyntaxError, "'yield from' inside async function"):
79
80 import test.badsyntax_async7
81
82 def test_badsyntax_8(self):
83 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
84 import test.badsyntax_async8
85
86 def test_badsyntax_9(self):
87 with self.assertRaisesRegex(SyntaxError, 'invalid syntax'):
88 import test.badsyntax_async9
89
90
class TokenizerRegrTest(unittest.TestCase):
    """Regression tests for tokenizing ``async def`` after many defs."""

    def test_oneline_defs(self):
        source = '\n'.join(
            'def i{i}(): return {i}'.format(i=i) for i in range(500))

        # 500 consecutive one-line defs must compile and run.
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)

        # The same 500 defs followed by a single 'async def' must also be
        # accepted, and the last one must be a coroutine function.
        source += '\nasync def foo():\n return'
        namespace = {}
        exec(source, namespace, namespace)
        self.assertEqual(namespace['i499'](), 499)
        self.assertTrue(inspect.iscoroutinefunction(namespace['foo']))
111
112
class CoroutineTest(unittest.TestCase):
    """Core semantics of ``async def``, ``await``, ``async with`` and
    ``async for``.

    Coroutines are driven synchronously with ``run_async`` (no event loop),
    so every value an awaitable yields shows up in the returned buffer.
    """

    def test_gen_1(self):
        # Plain generator functions must not grow an ``__await__``.
        def gen(): yield
        self.assertFalse(hasattr(gen, '__await__'))

    def test_func_1(self):
        async def foo():
            return 10

        f = foo()
        self.assertIsInstance(f, types.GeneratorType)
        # NOTE(review): 0x80 / 0x20 are presumably CO_COROUTINE and
        # CO_GENERATOR for this interpreter version -- confirm.  Both must
        # be set on the function's code and on the coroutine's gi_code.
        self.assertTrue(bool(foo.__code__.co_flags & 0x80))
        self.assertTrue(bool(foo.__code__.co_flags & 0x20))
        self.assertTrue(bool(f.gi_code.co_flags & 0x80))
        self.assertTrue(bool(f.gi_code.co_flags & 0x20))
        self.assertEqual(run_async(f), ([], 10))

        # A plain function must not carry the 0x80 flag.
        def bar(): pass
        self.assertFalse(bool(bar.__code__.co_flags & 0x80))

    def test_func_2(self):
        async def foo():
            raise StopIteration

        # A StopIteration escaping a coroutine is turned into RuntimeError.
        with self.assertRaisesRegex(
                RuntimeError, "generator raised StopIteration"):

            run_async(foo())

    def test_func_3(self):
        async def foo():
            raise StopIteration

        with silence_coro_gc():
            self.assertRegex(repr(foo()), '^<coroutine object.* at 0x.*>$')

    def test_func_4(self):
        async def foo():
            raise StopIteration

        # Coroutine objects must refuse every form of plain iteration.
        check = lambda: self.assertRaisesRegex(
            TypeError, "coroutine-objects do not support iteration")

        with check():
            list(foo())

        with check():
            tuple(foo())

        with check():
            sum(foo())

        with check():
            iter(foo())

        with check():
            next(foo())

        with silence_coro_gc(), check():
            for i in foo():
                pass

        with silence_coro_gc(), check():
            [i for i in foo()]

    def test_func_5(self):
        @types.coroutine
        def bar():
            yield 1

        async def foo():
            await bar()

        check = lambda: self.assertRaisesRegex(
            TypeError, "coroutine-objects do not support iteration")

        with check():
            for el in foo(): pass

        # A generator decorated with types.coroutine stays plainly iterable.
        for el in bar():
            self.assertEqual(el, 1)
        self.assertEqual([el for el in bar()], [1])
        self.assertEqual(tuple(bar()), (1,))
        self.assertEqual(next(iter(bar())), 1)

    def test_func_6(self):
        @types.coroutine
        def bar():
            yield 1
            yield 2

        async def foo():
            await bar()

        # Values yielded by the awaited generator surface through send().
        f = foo()
        self.assertEqual(f.send(None), 1)
        self.assertEqual(f.send(None), 2)
        with self.assertRaises(StopIteration):
            f.send(None)

    def test_func_7(self):
        async def bar():
            return 10

        def foo():
            yield from bar()

        with silence_coro_gc(), self.assertRaisesRegex(
                TypeError,
                "cannot 'yield from' a coroutine object from a generator"):

            list(foo())

    def test_func_8(self):
        # ... whereas a types.coroutine generator may 'yield from' one.
        @types.coroutine
        def bar():
            return (yield from foo())

        async def foo():
            return 'spam'

        self.assertEqual(run_async(bar()), ([], 'spam'))

    def test_func_9(self):
        async def foo(): pass

        with self.assertWarnsRegex(
                RuntimeWarning, "coroutine '.*test_func_9.*foo' was never awaited"):

            foo()
            support.gc_collect()

    def test_await_1(self):

        async def foo():
            await 1
        with self.assertRaisesRegex(TypeError, "object int can.t.*await"):
            run_async(foo())

    def test_await_2(self):
        async def foo():
            await []
        with self.assertRaisesRegex(TypeError, "object list can.t.*await"):
            run_async(foo())

    def test_await_3(self):
        async def foo():
            await AsyncYieldFrom([1, 2, 3])

        self.assertEqual(run_async(foo()), ([1, 2, 3], None))

    def test_await_4(self):
        async def bar():
            return 42

        async def foo():
            return await bar()

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_5(self):
        class Awaitable:
            def __await__(self):
                return

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_await_6(self):
        class Awaitable:
            def __await__(self):
                return iter([52])

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([52], None))

    def test_await_7(self):
        class Awaitable:
            def __await__(self):
                yield 42
                return 100

        async def foo():
            return (await Awaitable())

        self.assertEqual(run_async(foo()), ([42], 100))

    def test_await_8(self):
        class Awaitable:
            pass

        async def foo():
            return (await Awaitable())

        with self.assertRaisesRegex(
                TypeError, "object Awaitable can't be used in 'await' expression"):

            run_async(foo())

    def test_await_9(self):
        def wrap():
            return bar

        async def bar():
            return 42

        async def foo():
            # NOTE(review): ``b`` is never awaited, so it presumably emits a
            # "never awaited" RuntimeWarning when collected.
            b = bar()

            db = {'b': lambda: wrap}

            class DB:
                b = wrap

            # 'await' must bind correctly inside arithmetic and call chains:
            # 42 + 42 + 42 + 42 * 1000 + 42 == 42168.
            return (await bar() + await wrap()() + await db['b']()()() +
                    await bar() * 1000 + await DB.b()())

        async def foo2():
            return -await bar()

        self.assertEqual(run_async(foo()), ([], 42168))
        self.assertEqual(run_async(foo2()), ([], -42))

    def test_await_10(self):
        async def baz():
            return 42

        async def bar():
            return baz()

        async def foo():
            return await (await bar())

        self.assertEqual(run_async(foo()), ([], 42))

    def test_await_11(self):
        def ident(val):
            return val

        async def bar():
            return 'spam'

        async def foo():
            return ident(val=await bar())

        async def foo2():
            return await bar(), 'ham'

        # 'await' used as a keyword-argument value and as a tuple element.
        self.assertEqual(run_async(foo()), ([], 'spam'))
        self.assertEqual(run_async(foo2()), ([], ('spam', 'ham')))

    def test_await_12(self):
        async def coro():
            return 'spam'

        class Awaitable:
            def __await__(self):
                return coro()

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
                TypeError, r"__await__\(\) returned a coroutine"):

            run_async(foo())

    def test_await_13(self):
        class Awaitable:
            def __await__(self):
                return self

        async def foo():
            return await Awaitable()

        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type"):

            run_async(foo())

    def test_with_1(self):
        class Manager:
            def __init__(self, name):
                self.name = name

            async def __aenter__(self):
                await AsyncYieldFrom(['enter-1-' + self.name,
                                      'enter-2-' + self.name])
                return self

            async def __aexit__(self, *args):
                await AsyncYieldFrom(['exit-1-' + self.name,
                                      'exit-2-' + self.name])

                # Only manager "B" suppresses the exception.
                if self.name == 'B':
                    return True

        async def foo():
            async with Manager("A") as a, Manager("B") as b:
                await AsyncYieldFrom([('managers', a.name, b.name)])
                1/0

        f = foo()
        result, _ = run_async(f)

        # Enter in order, exit in reverse order; B swallows the 1/0.
        self.assertEqual(
            result, ['enter-1-A', 'enter-2-A', 'enter-1-B', 'enter-2-B',
                     ('managers', 'A', 'B'),
                     'exit-1-B', 'exit-2-B', 'exit-1-A', 'exit-2-A']
        )

        async def foo():
            async with Manager("A") as a, Manager("C") as c:
                await AsyncYieldFrom([('managers', a.name, c.name)])
                1/0

        # Neither A nor C suppresses, so the error propagates.
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

    def test_with_2(self):
        class CM:
            def __aenter__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_3(self):
        class CM:
            def __aexit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aenter__'):
            run_async(foo())

    def test_with_4(self):
        # A synchronous context manager is not usable with 'async with'.
        class CM:
            def __enter__(self):
                pass

            def __exit__(self):
                pass

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(AttributeError, '__aexit__'):
            run_async(foo())

    def test_with_5(self):
        # While this test doesn't make a lot of sense,
        # it's a regression test for an early bug with opcodes
        # generation

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc):
                pass

        async def func():
            async with CM():
                assert (1, ) == 1

        with self.assertRaises(AssertionError):
            run_async(func())

    def test_with_6(self):
        class CM:
            def __aenter__(self):
                return 123

            def __aexit__(self, *e):
                return 456

        async def foo():
            async with CM():
                pass

        with self.assertRaisesRegex(
                TypeError, "object int can't be used in 'await' expression"):
            # it's important that __aexit__ wasn't called
            run_async(foo())

    def test_with_7(self):
        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 444

        async def foo():
            async with CM():
                1/0

        try:
            run_async(foo())
        except TypeError as exc:
            # Awaiting the non-awaitable __aexit__ result fails, and that
            # failure is chained onto the original ZeroDivisionError.
            self.assertRegex(
                exc.args[0], "object int can't be used in 'await' expression")
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
        else:
            self.fail('invalid asynchronous context manager did not fail')

    def test_with_8(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            def __aexit__(self, *e):
                return 456

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaisesRegex(
                TypeError, "object int can't be used in 'await' expression"):

            run_async(foo())

        # The body ran exactly once before the bad __aexit__ result failed.
        self.assertEqual(CNT, 1)

    def test_with_9(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            nonlocal CNT
            async with CM():
                CNT += 1

        with self.assertRaises(ZeroDivisionError):
            run_async(foo())

        self.assertEqual(CNT, 1)

    def test_with_10(self):
        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                1/0

        async def foo():
            async with CM():
                async with CM():
                    raise RuntimeError

        try:
            run_async(foo())
        except ZeroDivisionError as exc:
            # outer __aexit__ error -> inner __aexit__ error -> RuntimeError
            self.assertIsNotNone(exc.__context__)
            self.assertIsInstance(exc.__context__, ZeroDivisionError)
            self.assertIsInstance(exc.__context__.__context__, RuntimeError)
        else:
            self.fail('exception from __aexit__ did not propagate')

    def test_with_11(self):
        class CM:
            async def __aenter__(self):
                raise NotImplementedError

            async def __aexit__(self, *e):
                1/0

        async def foo():
            async with CM():
                raise RuntimeError

        try:
            run_async(foo())
        except NotImplementedError as exc:
            # __aexit__ must not run when __aenter__ fails, so nothing is
            # chained onto the error.
            self.assertIsNone(exc.__context__)
        else:
            self.fail('exception from __aenter__ did not propagate')

    def test_with_12(self):
        class CM:
            async def __aenter__(self):
                return self

            async def __aexit__(self, *e):
                return True

        async def foo():
            async with CM() as cm:
                self.assertIs(cm.__class__, CM)
                raise RuntimeError

        # __aexit__ returns True, suppressing the RuntimeError.
        run_async(foo())

    def test_with_13(self):
        CNT = 0

        class CM:
            async def __aenter__(self):
                1/0

            async def __aexit__(self, *e):
                return True

        async def foo():
            nonlocal CNT
            CNT += 1
            async with CM():
                CNT += 1000
            CNT += 10000

        # A failing __aenter__ skips the body and is not suppressed by
        # __aexit__, so only the first increment happens.
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 1)

    def test_for_1(self):
        aiter_calls = 0

        class AsyncIter:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                nonlocal aiter_calls
                aiter_calls += 1
                return self

            async def __anext__(self):
                self.i += 1

                # Suspend (yield to run_async) on every 10th step.
                if not (self.i % 10):
                    await AsyncYield(self.i * 10)

                if self.i > 100:
                    raise StopAsyncIteration

                return self.i, self.i


        buffer = []
        async def test1():
            async for i1, i2 in AsyncIter():
                buffer.append(i1 + i2)

        yielded, _ = run_async(test1())
        # Make sure that __aiter__ was called only once
        self.assertEqual(aiter_calls, 1)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i*2 for i in range(1, 101)])


        buffer = []
        async def test2():
            nonlocal buffer
            async for i in AsyncIter():
                buffer.append(i[0])
                if i[0] == 20:
                    break
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test2())
        # Exactly one more __aiter__ call (the count is cumulative), and
        # 'break' skips the 'else' clause.
        self.assertEqual(aiter_calls, 2)
        self.assertEqual(yielded, [100, 200])
        self.assertEqual(buffer, [i for i in range(1, 21)] + ['end'])


        buffer = []
        async def test3():
            nonlocal buffer
            async for i in AsyncIter():
                if i[0] > 20:
                    continue
                buffer.append(i[0])
            else:
                buffer.append('what?')
            buffer.append('end')

        yielded, _ = run_async(test3())
        # Exactly one more __aiter__ call (the count is cumulative), and a
        # loop exhausted without 'break' runs the 'else' clause.
        self.assertEqual(aiter_calls, 3)
        self.assertEqual(yielded, [i * 100 for i in range(1, 11)])
        self.assertEqual(buffer, [i for i in range(1, 21)] +
                                 ['what?', 'end'])

    def test_for_2(self):
        tup = (1, 2, 3)
        refs_before = sys.getrefcount(tup)

        async def foo():
            async for i in tup:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError, "async for' requires an object.*__aiter__.*tuple"):

            run_async(foo())

        # The failed 'async for' must not leak a reference.
        self.assertEqual(sys.getrefcount(tup), refs_before)

    def test_for_3(self):
        class I:
            def __aiter__(self):
                return self

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                r"async for' received an invalid object.*__aiter.*\: I"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_4(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return ()

        aiter = I()
        refs_before = sys.getrefcount(aiter)

        async def foo():
            async for i in aiter:
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext__.*tuple"):

            run_async(foo())

        self.assertEqual(sys.getrefcount(aiter), refs_before)

    def test_for_5(self):
        class I:
            async def __aiter__(self):
                return self

            def __anext__(self):
                return 123

        async def foo():
            async for i in I():
                print('never going to happen')

        with self.assertRaisesRegex(
                TypeError,
                "async for' received an invalid object.*__anext.*int"):

            run_async(foo())

    def test_for_6(self):
        I = 0

        class Manager:
            async def __aenter__(self):
                nonlocal I
                I += 10000

            async def __aexit__(self, *args):
                nonlocal I
                I += 100000

        class Iterable:
            def __init__(self):
                self.i = 0

            async def __aiter__(self):
                return self

            async def __anext__(self):
                if self.i > 10:
                    raise StopAsyncIteration
                self.i += 1
                return self.i

        ##############

        manager = Manager()
        iterable = Iterable()
        mrefs_before = sys.getrefcount(manager)
        irefs_before = sys.getrefcount(iterable)

        async def main():
            nonlocal I

            async with manager:
                async for i in iterable:
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 111011)

        self.assertEqual(sys.getrefcount(manager), mrefs_before)
        self.assertEqual(sys.getrefcount(iterable), irefs_before)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

            async with Manager():
                async for i in Iterable():
                    I += 1
            I += 1000

        run_async(main())
        self.assertEqual(I, 333033)

        ##############

        async def main():
            nonlocal I

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

            async with Manager():
                I += 100
                async for i in Iterable():
                    I += 1
                else:
                    I += 10000000
            I += 1000

        run_async(main())
        self.assertEqual(I, 20555255)

    def test_for_7(self):
        CNT = 0
        class AI:
            async def __aiter__(self):
                1/0
        async def foo():
            nonlocal CNT
            async for i in AI():
                CNT += 1
            CNT += 10
        # A failing __aiter__ aborts before the loop body or anything after.
        with self.assertRaises(ZeroDivisionError):
            run_async(foo())
        self.assertEqual(CNT, 0)
918 self.assertEqual(CNT, 0)
919
Yury Selivanov75445082015-05-11 22:57:16 -0400920
class CoroAsyncIOCompatTest(unittest.TestCase):
    """``async with`` must work with coroutines run by a real asyncio loop."""

    def test_asyncio_1(self):
        import asyncio

        class MyException(Exception):
            pass

        buffer = []

        class CM:
            async def __aenter__(self):
                buffer.append(1)
                await asyncio.sleep(0.01)
                buffer.append(2)
                return self

            async def __aexit__(self, exc_type, exc_val, exc_tb):
                await asyncio.sleep(0.01)
                buffer.append(exc_type.__name__)

        async def f():
            async with CM() as c:
                await asyncio.sleep(0.01)
                raise MyException
            buffer.append('unreachable')

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            # __aexit__ returns None, so MyException must escape the
            # coroutine. Assert that, rather than silently accepting a run
            # that finished without raising.
            with self.assertRaises(MyException):
                loop.run_until_complete(f())
        finally:
            loop.close()
            asyncio.set_event_loop(None)

        self.assertEqual(buffer, [1, 2, 'MyException'])
958 self.assertEqual(buffer, [1, 2, 'MyException'])
959
960
class SysSetCoroWrapperTest(unittest.TestCase):
    """Tests for ``sys.set_coroutine_wrapper`` / ``sys.get_coroutine_wrapper``."""

    def test_set_wrapper_1(self):
        async def foo():
            return 'spam'

        # Every object handed to the wrapper is recorded here.
        seen = []

        def wrap(gen):
            seen.append(gen)
            return gen

        self.assertIsNone(sys.get_coroutine_wrapper())

        sys.set_coroutine_wrapper(wrap)
        self.assertIs(sys.get_coroutine_wrapper(), wrap)
        try:
            coro = foo()
            self.assertTrue(seen)

            self.assertEqual(run_async(coro), ([], 'spam'))
        finally:
            # Always uninstall the wrapper so other tests are unaffected.
            sys.set_coroutine_wrapper(None)

        self.assertIsNone(sys.get_coroutine_wrapper())

        # Once uninstalled, creating a coroutine must not call the wrapper.
        seen.clear()
        with silence_coro_gc():
            foo()
        self.assertFalse(seen)

    def test_set_wrapper_2(self):
        # A non-callable wrapper is rejected and nothing gets installed.
        self.assertIsNone(sys.get_coroutine_wrapper())
        with self.assertRaisesRegex(TypeError, "callable expected, got int"):
            sys.set_coroutine_wrapper(1)
        self.assertIsNone(sys.get_coroutine_wrapper())
996 self.assertIsNone(sys.get_coroutine_wrapper())
997
998
class CAPITest(unittest.TestCase):
    """Tests for the C-level ``tp_await`` slot via ``_testcapi.awaitType``.

    NOTE(review): ``awaitType(x)`` presumably builds an awaitable whose
    ``__await__`` returns ``x`` unchanged -- confirm in _testcapimodule.c.
    """

    def test_tp_await_1(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(iter([1]))
            return (await future)

        self.assertEqual(foo().send(None), 1)

    def test_tp_await_2(self):
        # Test tp_await to __await__ mapping
        from _testcapi import awaitType as at
        future = at(iter([1]))
        self.assertEqual(next(future.__await__()), 1)

    def test_tp_await_3(self):
        from _testcapi import awaitType as at

        async def foo():
            future = at(1)
            return (await future)

        # The first send() raises, so there is no result to compare. Call it
        # directly instead of wrapping it in an assertEqual that can never run.
        with self.assertRaisesRegex(
                TypeError, "__await__.*returned non-iterator of type 'int'"):
            foo().send(None)
1025 self.assertEqual(foo().send(None), 1)
1026
1027
if __name__ == "__main__":
    # Running the file directly executes every TestCase defined above.
    unittest.main()