blob: 2b170443a3b63895d0bdebc153b84088b837f45e [file] [log] [blame]
Eric Snow7f8bfc92018-01-29 18:23:44 -07001import contextlib
2import os
3import pickle
4from textwrap import dedent, indent
5import threading
6import unittest
7
8from test import support
9from test.support import script_helper
10
11interpreters = support.import_module('_xxsubinterpreters')
12
13
14def _captured_script(script):
15 r, w = os.pipe()
16 indented = script.replace('\n', '\n ')
17 wrapped = dedent(f"""
18 import contextlib
19 with open({w}, 'w') as chan:
20 with contextlib.redirect_stdout(chan):
21 {indented}
22 """)
23 return wrapped, open(r)
24
25
def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *shared* is forwarded unchanged to ``interpreters.run_string()``.
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        printed = reader.read()
    return printed
31
32
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy running a script for the duration of the with-block.

    A helper thread runs a script in *interp* that blocks reading from a
    pipe.  When the with-block exits, the pipe is written to and closed,
    which lets that script finish, and the helper thread is joined.
    """
    r, w = os.pipe()

    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    # NOTE(review): nothing waits for run() to actually enter run_string()
    # before the with-body starts, so the body may briefly observe the
    # interpreter as not yet running -- potential flakiness, TODO confirm.
    try:
        yield
    finally:
        # Release the blocked script even when the with-body raised (e.g. a
        # failed assertion); otherwise the helper thread never finishes.
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()
51
52
class IsShareableTests(unittest.TestCase):
    """Check which objects interpreters.is_shareable() accepts."""

    def test_default_shareables(self):
        # A singleton and a builtin object that are shareable by default.
        for candidate in (None, b'spam'):
            with self.subTest(candidate):
                self.assertTrue(interpreters.is_shareable(candidate))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_objects = [
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
        ]
        user_defined = [Cheese, Cheese('Wensleydale'), SubBytes(b'spam')]
        for candidate in singletons + builtin_objects + user_defined:
            with self.subTest(candidate):
                self.assertFalse(interpreters.is_shareable(candidate))
100
101
class TestBase(unittest.TestCase):
    """Base class that removes leftover interpreters and channels after each test."""

    def tearDown(self):
        # Interpreter id 0 is the main interpreter and is never destroyed.
        leftovers = [i for i in interpreters.list_all() if i != 0]
        for interp_id in leftovers:
            with contextlib.suppress(RuntimeError):  # already destroyed
                interpreters.destroy(interp_id)

        for channel_id in interpreters.channel_list_all():
            with contextlib.suppress(interpreters.ChannelNotFoundError):
                interpreters.channel_destroy(channel_id)  # may be gone already
118
119
class ListAllTests(TestBase):
    """interpreters.list_all() reports live interpreters in creation order."""

    def test_initial(self):
        expected = [interpreters.get_main()]
        self.assertEqual(interpreters.list_all(), expected)

    def test_after_creating(self):
        main = interpreters.get_main()
        created = [interpreters.create(), interpreters.create()]
        self.assertEqual(interpreters.list_all(), [main] + created)

    def test_after_destroying(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        interpreters.destroy(first)
        self.assertEqual(interpreters.list_all(), [main, second])
141
142
class GetCurrentTests(TestBase):
    """interpreters.get_current() identifies the calling interpreter."""

    def test_main(self):
        expected = interpreters.get_main()
        self.assertEqual(interpreters.get_current(), expected)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_current())
            """))
        current = int(printed.strip())
        # Exactly two interpreters exist: main and the one created above.
        _main, expected = interpreters.list_all()
        self.assertEqual(current, expected)
        self.assertNotEqual(current, main)
161
162
class GetMainTests(TestBase):
    """interpreters.get_main() returns the same id from every interpreter."""

    def test_from_main(self):
        [only] = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), only)

    def test_from_subinterpreter(self):
        [only] = interpreters.list_all()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_main())
            """))
        self.assertEqual(int(printed.strip()), only)
179
180
class IsRunningTests(TestBase):
    """interpreters.is_running() for the main, sub, and invalid interpreters."""

    def test_main(self):
        self.assertTrue(interpreters.is_running(interpreters.get_main()))

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        printed = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(RuntimeError, interpreters.is_running, -1)
219
220
class CreateTests(TestBase):
    """interpreters.create() from the main thread, other threads and subinterpreters."""

    def test_in_main(self):
        new_id = interpreters.create()

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        lock = threading.Lock()
        created = None

        def create_then_wait():
            nonlocal created
            created = interpreters.create()
            # Block until the main thread has released the lock.
            with lock:
                pass

        worker = threading.Thread(target=create_then_wait)
        with lock:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        printed = _run_output(outer, dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(id)
            """))
        inner = int(printed.strip())

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_in_threaded_subinterpreter(self):
        main, = interpreters.list_all()
        outer = interpreters.create()
        inner = None

        def create_nested():
            nonlocal inner
            printed = _run_output(outer, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(id)
                """))
            inner = int(printed.strip())

        worker = threading.Thread(target=create_nested)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main, outer, inner})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy every one of them.
        doomed = [interpreters.create() for _ in range(3)]
        for interp_id in doomed:
            interpreters.destroy(interp_id)
        # Finally, create another.
        survivor = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {survivor})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        first, middle, last = (interpreters.create() for _ in range(3))
        # Destroy all but the middle one, then create another.
        interpreters.destroy(first)
        interpreters.destroy(last)
        newest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {newest, middle})
310
311
class DestroyTests(TestBase):
    """interpreters.destroy() from various threads and interpreters.

    Checks that run in a worker thread record their outcome in a list, and
    the main thread asserts on that list.  An exception or failed assertion
    raised inside a thread is only reported by the threading excepthook; it
    would not fail the test by itself.
    """

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            ids.add(interpreters.create())
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for interp_id in ids:
            interpreters.destroy(interp_id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        raised = []

        def f():
            try:
                interpreters.destroy(main)
            except RuntimeError as exc:
                raised.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        # Destroying the main interpreter must fail from another thread too.
        self.assertEqual(len(raised), 1)

    def test_already_destroyed(self):
        interp_id = interpreters.create()
        interpreters.destroy(interp_id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(interp_id)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(-1)

    def test_from_current(self):
        main, = interpreters.list_all()
        interp_id = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({})
            """).format(interp_id)

        # An interpreter cannot destroy itself while running.
        with self.assertRaises(RuntimeError):
            interpreters.run_string(interp_id, script)
        self.assertEqual(set(interpreters.list_all()), {main, interp_id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({})
            """).format(id2)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp_id = interpreters.create()
        errors = []

        def f():
            try:
                interpreters.destroy(interp_id)
            except Exception as exc:
                errors.append(exc)  # re-checked in the main thread below

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(errors, [])
        self.assertNotIn(interp_id, interpreters.list_all())

    def test_still_running(self):
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
402
403
class RunStringTests(TestBase):
    """interpreters.run_string(): output, threads, fork, errors and shared data.

    Each test gets a fresh subinterpreter in ``self.id`` (see setUp()).
    """

    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.id = interpreters.create()
        self._fs = None

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        # NOTE(review): FSFixture is not defined or imported anywhere in the
        # visible source, so accessing this property would raise NameError.
        # No visible test uses it (nor SCRIPT/FILENAME) -- TODO remove it or
        # provide the fixture.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_create_thread(self):
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')

            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # The parent checks the child was reaped with an explicit raise
            # rather than ``assert``: asserts are stripped under -O, which
            # would let this test read the file before the child wrote it.
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0:  # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    if spid != pid:
                        raise RuntimeError('forked child did not exit in time')
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        id = 0
        while id in interpreters.list_all():
            id += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(id, 'print("spam")')

    def test_error_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        """Assert the with-body raises RunFailedError for *exctype*.

        With *msg*, the full error text must be "<exctype>: <msg>";
        without it, only the part before the first colon is compared.
        """
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        if msg is None:
            self.assertEqual(str(caught.exception).split(':')[0],
                             str(exctype))
        else:
            self.assertEqual(str(caught.exception),
                             "{}: {}".format(exctype, msg))

    def test_invalid_syntax(self):
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        r, w = os.pipe()

        shared = {
            'spam': b'ham',
            'eggs': b'-1',
            'cheddar': None,
        }
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        shared = {'spam': b'ham'}
        script = dedent(f"""
            ns2 = dict(vars())
            del ns2['__builtins__']
            """)
        interpreters.run_string(self.id, script, shared)

        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        r, w = os.pipe()

        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        ns.pop('__builtins__')
        ns.pop('__loader__')
        self.assertEqual(ns, {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
        })

    def test_still_running_at_exit(self):
        # NOTE(review): the script never binds ``id``, so inside f() it is
        # the builtin id() function.  run_string() then presumably fails with
        # TypeError in the thread, and the subprocess exits 0 without any
        # subinterpreter having run, i.e. the test passes vacuously.  Adding
        # ``id = _interpreters.create()`` would likely make the sleeping
        # non-daemon thread block interpreter shutdown -- TODO rework.
        script = dedent(f"""
            from textwrap import dedent
            import threading
            import _xxsubinterpreters as _interpreters
            def f():
                _interpreters.run_string(id, dedent('''
                    import time
                    # Give plenty of time for the main interpreter to finish.
                    time.sleep(1_000_000)
                    '''))

            t = threading.Thread(target=f)
            t.start()
            """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)
721
722
class ChannelIDTests(TestBase):
    """Construction, coercion, repr and comparison of ChannelID objects."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        cases = [
            (dict(send=True), 'send'),
            (dict(send=True, recv=False), 'send'),
            (dict(recv=True), 'recv'),
            (dict(recv=True, send=False), 'recv'),
            (dict(send=True, recv=True), 'both'),
        ]
        for kwargs, expected_end in cases:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(cid.end, expected_end)

    def test_coerce_id(self):
        for raw in ('10', 10.0):
            cid = interpreters._channel_id(raw, force=True)
            self.assertEqual(int(cid), 10)

        class Int(str):
            def __init__(self, value):
                self._value = value

            def __int__(self):
                return self._value

        cid = interpreters._channel_id(Int(10), force=True)
        self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for cid in (-1, 2**64, "spam"):
            with self.subTest(cid):
                with self.assertRaises(ValueError):
                    interpreters._channel_id(cid)

        with self.assertRaises(TypeError):
            interpreters._channel_id(object())

    def test_bad_kwargs(self):
        self.assertRaises(ValueError, interpreters._channel_id,
                          10, send=False, recv=False)

    def test_does_not_exist(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(int(cid) + 1)  # unforced

    def test_repr(self):
        expectations = [
            (dict(), 'ChannelID(10)'),
            (dict(send=True), 'ChannelID(10, send=True)'),
            (dict(recv=True), 'ChannelID(10, recv=True)'),
            (dict(send=True, recv=True), 'ChannelID(10)'),
        ]
        for kwargs, expected in expectations:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(repr(cid), expected)

    def test_equality(self):
        first = interpreters.channel_create()
        same = interpreters._channel_id(int(first))
        other = interpreters.channel_create()

        # Exercise == and != explicitly rather than via assertEqual.
        self.assertTrue(first == first)
        self.assertTrue(first == same)
        self.assertTrue(first == int(first))
        self.assertFalse(first == other)

        self.assertFalse(first != first)
        self.assertFalse(first != same)
        self.assertTrue(first != other)
808
809
class ChannelTests(TestBase):
    """Channel creation, dropping, closing and sending/receiving.

    Scripts run in subinterpreters report results either on stdout (read
    with _run_output()) or by raising, which surfaces as RunFailedError.
    Results are checked in the main interpreter rather than with ``assert``
    inside a script, because asserts are stripped under ``-O``.
    """

    def test_sequential_ids(self):
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid2 = int(out.strip())

        self.assertEqual(cid2, int(cid1) + 1)

    ####################

    def test_drop_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        out = _run_output(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({int(cid)})
            _interpreters.channel_drop_interpreter({int(cid)})
            print(repr(obj))
            """))
        # Relies on id1's __main__ namespace (and its _interpreters import)
        # persisting between run_string() calls.
        interpreters.run_string(id1, dedent(f"""
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        self.assertEqual(out.strip(), "b'spam'")

    def test_drop_no_kwargs(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_drop_interpreter(cid, send=True, recv=True)

    def test_drop_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_drop_interpreter({int(cid)})
            """))
        obj = interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(obj, b'spam')

    def test_drop_close_if_unassociated(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_send({int(cid)}, b'spam')
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_partially(self):
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_drop_interpreter(cid, send=True)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_drop_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({int(cid)})
            """))
        interpreters.channel_close(cid)
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({int(cid)})
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        id1 = interpreters.create()
        # Report the checks on stdout instead of asserting inside the
        # script, where ``assert`` would be stripped under -O.
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            print(obj is not orig, obj == orig)
            """))

        # An equal copy comes back, not the original object.
        self.assertEqual(out.strip(), 'True True')

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()

        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')
1115
1116
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()