| import contextlib |
| import os |
| import threading |
| from textwrap import dedent |
| import unittest |
| import time |
| |
| import _xxsubinterpreters as _interpreters |
| from test.support import interpreters |
| |
| |
def _captured_script(script):
    """Wrap *script* so that everything it prints goes into a fresh pipe.

    Returns a ``(wrapped, rfile)`` pair:

    * *wrapped* is source text that opens the pipe's write end by file
      descriptor, redirects stdout to it and then runs *script*.
    * *rfile* is a text-mode file object over the pipe's read end.

    The write end is closed when the wrapped code finishes, so reading
    *rfile* afterwards stops at EOF.  The caller must close *rfile*.
    """
    # Function-scope import: the module itself only imports ``dedent``.
    from textwrap import indent

    r, w = os.pipe()
    header = dedent(f"""
        import contextlib
        with open({w}, 'w') as spipe:
            with contextlib.redirect_stdout(spipe):
        """)
    # The script sits two blocks deep in the wrapper (two ``with``
    # statements).  Indenting every line explicitly keeps multi-line scripts
    # valid without a magic-width literal that has to match the template.
    body = indent(script, ' ' * 8)
    wrapped = f'{header}{body}\n'
    return wrapped, open(r)
| |
| |
def clean_up_interpreters():
    """Close every interpreter except the main one (id 0).

    A ``RuntimeError`` from ``close()`` is ignored; it is taken to mean the
    interpreter was already destroyed.
    """
    for candidate in interpreters.list_all():
        if candidate.id != 0:  # id 0 is the main interpreter
            with contextlib.suppress(RuntimeError):
                candidate.close()
| |
| |
def _run_output(interp, request, shared=None):
    """Run *request* in *interp* and return the text it wrote to stdout.

    The script is wrapped by ``_captured_script`` so its output lands in a
    pipe, which is read back once the run has finished.  *shared* is
    accepted but not used by this helper.
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interp.run(wrapped)
        captured = reader.read()
    return captured
| |
| |
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy in a background thread for the ``with`` body.

    A worker thread calls ``interp.run()`` with a script that blocks reading
    from a pipe.  On exit the pipe's write end is written to and closed,
    which lets the script finish, and the worker thread is joined.

    The release happens in a ``finally`` clause, so a failing assertion
    inside the ``with`` body cannot leave the thread blocked.
    """
    r, w = os.pipe()

    def run():
        interp.run(dedent(f"""
            # wait for "signal"
            with open({r}) as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Closing the write end delivers EOF to the blocked ``read()``.
        with open(w, 'w') as spipe:
            spipe.write('done')
        t.join()
| |
| |
class TestBase(unittest.TestCase):
    """Shared base class: closes leftover subinterpreters after each test."""

    def tearDown(self):
        # Only the main interpreter should be left for the next test.
        clean_up_interpreters()
| |
| |
class CreateTests(TestBase):
    """Tests that interpreters made by ``interpreters.create()`` are listed."""

    def test_in_main(self):
        interp = interpreters.create()
        lst = interpreters.list_all()
        # Index 0 is expected to be the main interpreter.
        self.assertEqual(interp.id, lst[1].id)

    def test_in_thread(self):
        # The interpreter must be created *inside* the worker thread; the
        # result is handed back through the enclosing scope.
        lock = threading.Lock()
        interp = None

        def f():
            nonlocal interp
            interp = interpreters.create()
            # Wait until the main thread has finished starting us.
            lock.acquire()
            lock.release()

        t = threading.Thread(target=f)
        with lock:
            t.start()
        t.join()

        self.assertIsNotNone(interp)
        listed_ids = [each.id for each in interpreters.list_all()]
        self.assertIn(interp.id, listed_ids)

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(interp)
            """))
        # NOTE(review): interp2 is the printed text, not an Interpreter, so
        # this only checks how many interpreters exist.
        interp2 = out.strip()

        self.assertEqual(len(set(interpreters.list_all())), len({main, interp, interp2}))

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp_lst = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp in interp_lst:
            interp.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(len(set(interpreters.list_all())), len(before | {interp}))

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        # Now destroy 2 of them.
        interp1.close()
        interp2.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(len(set(interpreters.list_all())), len(before | {interp3, interp}))
| |
| |
class GetCurrentTests(TestBase):
    """Tests for ``interpreters.get_current()``."""

    def test_main(self):
        main_interp_id = _interpreters.get_main()
        cur_interp_id = interpreters.get_current().id
        self.assertEqual(cur_interp_id, main_interp_id)

    def test_subinterpreter(self):
        main = _interpreters.get_main()
        interp = interpreters.create()
        # Print the numeric id, not the object: comparing the printed text
        # of an Interpreter with an id could never fail.
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(cur.id)
            """))
        cur = int(out)
        self.assertNotEqual(cur, main)
        self.assertEqual(interp.id, cur)
| |
| |
class ListAllTests(TestBase):
    """Tests for the contents and order of ``interpreters.list_all()``."""

    def test_initial(self):
        everything = interpreters.list_all()
        self.assertEqual(1, len(everything))

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        listed = [each.id for each in interpreters.list_all()]

        self.assertEqual(listed, [main.id, first.id, second.id])

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        listed = [each.id for each in interpreters.list_all()]

        self.assertEqual(listed, [main.id, second.id])
| |
| |
class TestInterpreterId(TestBase):
    """Tests for the ``id`` attribute of ``Interpreter`` objects."""

    def test_in_main(self):
        current = interpreters.get_current()
        self.assertEqual(0, current.id)

    def test_with_custom_num(self):
        wrapper = interpreters.Interpreter(1)
        self.assertEqual(1, wrapper.id)

    def test_for_readonly_property(self):
        # Assigning to ``id`` must be rejected.
        wrapper = interpreters.Interpreter(1)
        with self.assertRaises(AttributeError):
            wrapper.id = 2
| |
| |
class TestInterpreterIsRunning(TestBase):
    """Tests for ``Interpreter.is_running()``."""

    def test_main(self):
        current = interpreters.get_current()
        self.assertTrue(current.is_running())

    def test_subinterpreter(self):
        sub = interpreters.create()
        self.assertFalse(sub.is_running())

        # Running only while the helper keeps a script blocked in it.
        with _running(sub):
            self.assertTrue(sub.is_running())
        self.assertFalse(sub.is_running())

    def test_from_subinterpreter(self):
        # The interpreter asks about itself while its own script is running.
        sub = interpreters.create()
        printed = _run_output(sub, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({sub.id}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        sub = interpreters.create()
        sub.close()
        with self.assertRaises(RuntimeError):
            sub.is_running()
| |
| |
class TestInterpreterDestroy(TestBase):
    """Tests for ``Interpreter.close()``.

    Outcomes observed in worker threads are recorded and asserted from the
    main thread, because an exception raised inside a thread does not fail
    the test.
    """

    def test_basic(self):
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(4, len(interpreters.list_all()))
        interp2.close()
        self.assertEqual(3, len(interpreters.list_all()))

    def test_all(self):
        before = set(interpreters.list_all())
        interps = set()
        for _ in range(3):
            interp = interpreters.create()
            interps.add(interp)
        self.assertEqual(len(set(interpreters.list_all())), len(before | interps))
        for interp in interps:
            interp.close()
        self.assertEqual(len(set(interpreters.list_all())), len(before))

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        # Same check from another thread; the outcome is reported back so
        # that it can be asserted here.
        refused = []

        def f():
            try:
                main.close()
            except RuntimeError:
                refused.append(True)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(refused, [True])

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        # Inside the script, "main" is the subinterpreter's own current
        # interpreter, which tries to close itself.
        script = dedent("""
            from test.support import interpreters
            try:
                main = interpreters.get_current()
                main.close()
            except RuntimeError:
                pass
            """)

        interp.run(script)
        self.assertEqual(len(set(interpreters.list_all())), len({main, interp}))

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            interp2 = interpreters.create()
            interp2.close()
            """)
        interp1.run(script)

        self.assertEqual(len(set(interpreters.list_all())), len({main, interp1}))

    def test_from_other_thread(self):
        interp = interpreters.create()
        errors = []

        def f():
            # Recorded rather than raised so the main thread can assert it.
            try:
                interp.close()
            except Exception as exc:
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(errors, [])

    def test_still_running(self):
        # The unpacking also requires that only the main interpreter exists
        # when the test starts.
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())
| |
| |
class TestInterpreterRun(TestBase):
    """Tests for ``Interpreter.run()``; each test gets a fresh interpreter."""

    # Template for a script that writes a value to a file.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.interp = interpreters.create()
        self._fs = None

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        # Created on first access.
        # NOTE(review): FSFixture is not defined in the visible part of this
        # file -- confirm it exists before relying on this property.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, reader = _captured_script('print("it worked!", end="")')
        with reader:
            self.interp.run(script)
            printed = reader.read()

        self.assertEqual(printed, 'it worked!')

    def test_in_thread(self):
        script, reader = _captured_script('print("it worked!", end="")')
        with reader:
            worker = threading.Thread(target=self.interp.run, args=(script,))
            worker.start()
            worker.join()
            printed = reader.read()

        self.assertEqual(printed, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        # The script writes the marker only if os.fork() raises RuntimeError.
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            script = dedent(f"""
                import os
                try:
                    os.fork()
                except RuntimeError:
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                """)
            self.interp.run(script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.interp):
            with self.assertRaises(RuntimeError):
                self.interp.run('print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(b'print("spam")')
| |
| |
class TestIsShareable(TestBase):
    """Tests for ``interpreters.is_shareable()``."""

    def test_default_shareables(self):
        shareables = (
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        )
        for candidate in shareables:
            with self.subTest(candidate):
                self.assertTrue(interpreters.is_shareable(candidate))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        not_shareables = (
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            100.0,
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        )
        for candidate in not_shareables:
            with self.subTest(repr(candidate)):
                self.assertFalse(interpreters.is_shareable(candidate))
| |
| |
class TestChannel(TestBase):
    """Tests for creating and listing channels."""

    def test_create_cid(self):
        # create_channel() returns a (receiving end, sending end) pair.
        recv_end, send_end = interpreters.create_channel()
        self.assertIsInstance(recv_end, interpreters.RecvChannel)
        self.assertIsInstance(send_end, interpreters.SendChannel)

    def test_sequential_ids(self):
        before = interpreters.list_all_channels()
        channels1 = interpreters.create_channel()
        channels2 = interpreters.create_channel()
        channels3 = interpreters.create_channel()
        after = interpreters.list_all_channels()

        # Compare counts: new entries listed vs. channel pairs created.
        added = set(after) - set(before)
        created = {channels1, channels2, channels3}
        self.assertEqual(len(added), len(created))
| |
| |
| class TestSendRecv(TestBase): |
| |
| def test_send_recv_main(self): |
| r, s = interpreters.create_channel() |
| orig = b'spam' |
| s.send(orig) |
| obj = r.recv() |
| |
| self.assertEqual(obj, orig) |
| self.assertIsNot(obj, orig) |
| |
| def test_send_recv_same_interpreter(self): |
| interp = interpreters.create() |
| out = _run_output(interp, dedent(""" |
| from test.support import interpreters |
| r, s = interpreters.create_channel() |
| orig = b'spam' |
| s.send(orig) |
| obj = r.recv() |
| assert obj is not orig |
| assert obj == orig |
| """)) |
| |
| def test_send_recv_different_threads(self): |
| r, s = interpreters.create_channel() |
| |
| def f(): |
| while True: |
| try: |
| obj = r.recv() |
| break |
| except interpreters.ChannelEmptyError: |
| time.sleep(0.1) |
| s.send(obj) |
| t = threading.Thread(target=f) |
| t.start() |
| |
| s.send(b'spam') |
| t.join() |
| obj = r.recv() |
| |
| self.assertEqual(obj, b'spam') |
| |
| def test_send_recv_nowait_main(self): |
| r, s = interpreters.create_channel() |
| orig = b'spam' |
| s.send(orig) |
| obj = r.recv_nowait() |
| |
| self.assertEqual(obj, orig) |
| self.assertIsNot(obj, orig) |
| |
| def test_send_recv_nowait_same_interpreter(self): |
| interp = interpreters.create() |
| out = _run_output(interp, dedent(""" |
| from test.support import interpreters |
| r, s = interpreters.create_channel() |
| orig = b'spam' |
| s.send(orig) |
| obj = r.recv_nowait() |
| assert obj is not orig |
| assert obj == orig |
| """)) |
| |
| r, s = interpreters.create_channel() |
| |
| def f(): |
| while True: |
| try: |
| obj = r.recv_nowait() |
| break |
| except _interpreters.ChannelEmptyError: |
| time.sleep(0.1) |
| s.send(obj) |