blob: 3451a4c8759d8b1eb688731ab73f86df856ea9bc [file] [log] [blame]
Joannah Nanjekyebae872f2020-06-10 00:53:23 -03001import contextlib
2import os
3import threading
4from textwrap import dedent
5import unittest
6import time
7
8import _xxsubinterpreters as _interpreters
9from test.support import interpreters
10
11
def _captured_script(script):
    """Wrap *script* so that everything it prints goes into a fresh pipe.

    Returns a ``(wrapped_source, reader)`` pair: *wrapped_source* opens the
    pipe's write end and runs *script* under ``contextlib.redirect_stdout``;
    *reader* is the pipe's read end opened as a text file.
    """
    read_fd, write_fd = os.pipe()
    # The ``{body}`` placeholder below sits 16 columns deep.  Every
    # continuation line of *script* is pushed to that same column so that,
    # once dedent() strips the template's common 8-space margin, the whole
    # script is nested inside the redirect_stdout() block.
    pad = ' ' * 16
    body = script.replace('\n', '\n' + pad)
    template = f"""
        import contextlib
        with open({write_fd}, 'w') as spipe:
            with contextlib.redirect_stdout(spipe):
                {body}
        """
    reader = open(read_fd)
    return dedent(template), reader
22
23
def clean_up_interpreters():
    """Close every interpreter except the main one (id 0).

    Interpreters whose close() raises RuntimeError are assumed to be gone
    already and are skipped.
    """
    for interp in interpreters.list_all():
        if interp.id == 0:  # never close the main interpreter
            continue
        with contextlib.suppress(RuntimeError):  # already destroyed
            interp.close()
32
33
def _run_output(interp, request, shared=None):
    """Run the source *request* in *interp* and return what it printed.

    *shared* is accepted for call compatibility but is ignored here.
    """
    script, rpipe = _captured_script(request)
    with rpipe:
        interp.run(script)
        # The wrapped script closes the pipe's write end when it finishes,
        # so this read stops at EOF.
        output = rpipe.read()
    return output
39
40
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy running a blocking script for the ``with`` body.

    A helper thread runs a script in *interp* that blocks reading from a
    pipe.  Leaving the block - normally or via an exception - writes to the
    pipe so the script can finish, then joins the helper thread.

    NOTE(review): nothing waits for the script to actually start before the
    body runs, so an immediate ``is_running()`` check inside the block could
    presumably race with thread startup - confirm if such tests flake.
    """
    r, w = os.pipe()

    def run():
        interp.run(dedent(f"""
            # wait for "signal"
            with open({r}) as rpipe:
                rpipe.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Release the blocked script even when the body raised (e.g. a
        # failed assertion); without this the helper thread would wait on
        # the pipe forever and hang the process at exit.
        with open(w, 'w') as spipe:
            spipe.write('done')
        t.join()
59
60
class TestBase(unittest.TestCase):
    """Base class whose tests leave only the main interpreter behind."""

    def tearDown(self):
        # Close any subinterpreters the test created so that the next test
        # starts from a clean slate.
        clean_up_interpreters()
65
66
class CreateTests(TestBase):
    """Tests for interpreters.create()."""

    @staticmethod
    def _all_ids():
        # list_all() builds new Interpreter wrappers on every call, so
        # compare interpreters by integer id rather than by wrapper object.
        return {int(interp.id) for interp in interpreters.list_all()}

    def test_in_main(self):
        interp = interpreters.create()
        lst = interpreters.list_all()
        self.assertEqual(interp.id, lst[1].id)

    def test_in_thread(self):
        lock = threading.Lock()
        interp = None

        def f():
            nonlocal interp
            # Create the interpreter from the worker thread itself.
            interp = interpreters.create()
            # Wait until the main thread has released the lock.
            lock.acquire()
            lock.release()

        t = threading.Thread(target=f)
        with lock:
            t.start()
        t.join()

        # An exception raised in the thread does not fail this test by
        # itself, so check from the main thread that create() succeeded.
        self.assertIsNotNone(interp)
        self.assertIn(int(interp.id), self._all_ids())

    def test_in_subinterpreter(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            interp = interpreters.create()
            print(int(interp.id))
            """))
        # Parse the id printed by the subinterpreter; comparing the raw
        # printed text against interpreter objects would prove nothing.
        interp2_id = int(out.strip())

        self.assertEqual(self._all_ids(),
                         {int(main.id), int(interp.id), interp2_id})

    def test_after_destroy_all(self):
        before = self._all_ids()
        # Create 3 subinterpreters.
        interp_lst = [interpreters.create() for _ in range(3)]
        # Now destroy them.
        for interp in interp_lst:
            interp.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(self._all_ids(), before | {int(interp.id)})

    def test_after_destroy_some(self):
        before = self._all_ids()
        # Create 3 subinterpreters.
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        # Now destroy 2 of them.
        interp1.close()
        interp2.close()
        # Finally, create another.
        interp = interpreters.create()
        self.assertEqual(self._all_ids(),
                         before | {int(interp3.id), int(interp.id)})
129
130
class GetCurrentTests(TestBase):
    """Tests for interpreters.get_current()."""

    def test_main(self):
        main_interp_id = _interpreters.get_main()
        cur_interp_id = interpreters.get_current().id
        self.assertEqual(cur_interp_id, main_interp_id)

    def test_subinterpreter(self):
        main = _interpreters.get_main()
        interp = interpreters.create()
        out = _run_output(interp, dedent("""
            from test.support import interpreters
            cur = interpreters.get_current()
            print(int(cur.id))
            """))
        # Compare integer ids: the printed text is a str, and a str never
        # equals an id object, which made the previous check vacuous.
        cur = int(out.strip())
        self.assertNotEqual(cur, int(main))
        self.assertEqual(cur, int(interp.id))
148
149
class ListAllTests(TestBase):
    """Tests for interpreters.list_all()."""

    def test_initial(self):
        # Only the main interpreter exists when a test starts.
        self.assertEqual(1, len(interpreters.list_all()))

    def test_after_creating(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()

        listed = [interp.id for interp in interpreters.list_all()]
        self.assertEqual(listed, [main.id, first.id, second.id])

    def test_after_destroying(self):
        main = interpreters.get_current()
        first = interpreters.create()
        second = interpreters.create()
        first.close()

        listed = [interp.id for interp in interpreters.list_all()]
        self.assertEqual(listed, [main.id, second.id])
178
179
class TestInterpreterId(TestBase):
    """Tests for the Interpreter.id attribute."""

    def test_in_main(self):
        self.assertEqual(0, interpreters.get_current().id)

    def test_with_custom_num(self):
        self.assertEqual(1, interpreters.Interpreter(1).id)

    def test_for_readonly_property(self):
        interp = interpreters.Interpreter(1)
        # ``id`` is expected to be read-only.
        with self.assertRaises(AttributeError):
            interp.id = 2
194
195
class TestInterpreterIsRunning(TestBase):
    """Tests for Interpreter.is_running()."""

    def test_main(self):
        # The main interpreter is the one executing this very test.
        self.assertTrue(interpreters.get_current().is_running())

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interp.is_running())

        with _running(interp):
            self.assertTrue(interp.is_running())
        self.assertFalse(interp.is_running())

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp.id}):
                print(True)
            else:
                print(False)
            """)
        out = _run_output(interp, script)
        # While executing the script the subinterpreter reports itself running.
        self.assertEqual(out.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        # Querying a closed interpreter must raise.
        with self.assertRaises(RuntimeError):
            interp.is_running()
226
227
class TestInterpreterDestroy(TestBase):
    """Tests for Interpreter.close()."""

    @staticmethod
    def _all_ids():
        # list_all() builds new Interpreter wrappers on every call, so
        # compare interpreters by integer id rather than by wrapper object.
        return {int(interp.id) for interp in interpreters.list_all()}

    def test_basic(self):
        interp1 = interpreters.create()
        interp2 = interpreters.create()
        interp3 = interpreters.create()
        self.assertEqual(4, len(interpreters.list_all()))
        interp2.close()
        self.assertEqual(3, len(interpreters.list_all()))
        self.assertNotIn(int(interp2.id), self._all_ids())

    def test_all(self):
        before = self._all_ids()
        interps = set()
        for _ in range(3):
            interp = interpreters.create()
            interps.add(interp)
        self.assertEqual(self._all_ids(),
                         before | {int(interp.id) for interp in interps})
        for interp in interps:
            interp.close()
        self.assertEqual(self._all_ids(), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            main.close()

        # A failing assertion inside a worker thread is only reported by
        # threading and does not fail the test, so record the outcome in
        # the thread and check it from the main thread instead.
        raised = []

        def f():
            try:
                main.close()
            except RuntimeError:
                raised.append(True)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(raised, [True])

    def test_already_destroyed(self):
        interp = interpreters.create()
        interp.close()
        with self.assertRaises(RuntimeError):
            interp.close()

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            try:
                main = interpreters.get_current()
                main.close()
            except RuntimeError:
                pass
            """)

        interp.run(script)
        # Closing itself from inside must not have removed the interpreter.
        self.assertEqual(self._all_ids(), {int(main.id), int(interp.id)})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        interp1 = interpreters.create()
        script = dedent("""
            from test.support import interpreters
            interp2 = interpreters.create()
            interp2.close()
            """)
        interp1.run(script)

        self.assertEqual(self._all_ids(), {int(main.id), int(interp1.id)})

    def test_from_other_thread(self):
        interp = interpreters.create()
        errors = []

        def f():
            try:
                interp.close()
            except Exception as exc:  # re-checked below in the main thread
                errors.append(exc)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(errors, [])
        self.assertNotIn(int(interp.id), self._all_ids())

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interp.close()
            self.assertTrue(interp.is_running())
311
312
class TestInterpreterRun(TestBase):
    """Tests for Interpreter.run()."""

    # NOTE(review): SCRIPT and FILENAME are not referenced by any test in
    # this class.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.interp = interpreters.create()
        self._fs = None

    def tearDown(self):
        # Always run the base cleanup (closing subinterpreters), even if
        # closing the filesystem fixture raises.
        try:
            if self._fs is not None:
                self._fs.close()
        finally:
            super().tearDown()

    @property
    def fs(self):
        # NOTE(review): FSFixture is not defined in the visible part of this
        # file; confirm it exists elsewhere, otherwise this raises NameError.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            self.interp.run(script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        script, file = _captured_script('print("it worked!", end="")')
        errors = []
        with file:
            def f():
                try:
                    self.interp.run(script)
                except Exception as exc:  # re-checked in the main thread
                    errors.append(exc)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            # If run() failed, the script may never have opened (and so
            # never closed) the pipe's write end, and reading would block
            # forever; report the thread's failure before reading.
            self.assertEqual(errors, [])
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            # fork() is expected to be refused inside a subinterpreter; the
            # script records that refusal by writing *expected* to the file.
            script = dedent(f"""
                import os
                try:
                    os.fork()
                except RuntimeError:
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                """)
            self.interp.run(script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.interp):
            with self.assertRaises(RuntimeError):
                self.interp.run('print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            self.interp.run(b'print("spam")')
391 self.interp.run(b'print("spam")')
392
393
class TestIsShareable(TestBase):
    """Tests for interpreters.is_shareable()."""

    def test_default_shareables(self):
        shareables = (
            # singletons
            None,
            # builtin objects
            b'spam',
            'spam',
            10,
            -10,
        )
        for obj in shareables:
            with self.subTest(obj):
                shareable = interpreters.is_shareable(obj)
                self.assertTrue(shareable)

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        not_shareables = (
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            100.0,
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        )
        for obj in not_shareables:
            with self.subTest(repr(obj)):
                shareable = interpreters.is_shareable(obj)
                self.assertFalse(shareable)
441 interpreters.is_shareable(obj))
442
443
class TestChannel(TestBase):
    """Tests for interpreters.create_channel() and list_all_channels()."""

    def test_create_cid(self):
        recv_end, send_end = interpreters.create_channel()
        self.assertIsInstance(recv_end, interpreters.RecvChannel)
        self.assertIsInstance(send_end, interpreters.SendChannel)

    def test_sequential_ids(self):
        before = interpreters.list_all_channels()
        created = {interpreters.create_channel() for _ in range(3)}
        after = interpreters.list_all_channels()

        # Exactly the newly created channels should have appeared.
        appeared = set(after) - set(before)
        self.assertEqual(len(appeared), len(created))
459 len({channels1, channels2, channels3}))
460
461
462class TestSendRecv(TestBase):
463
464 def test_send_recv_main(self):
465 r, s = interpreters.create_channel()
466 orig = b'spam'
467 s.send(orig)
468 obj = r.recv()
469
470 self.assertEqual(obj, orig)
471 self.assertIsNot(obj, orig)
472
473 def test_send_recv_same_interpreter(self):
474 interp = interpreters.create()
475 out = _run_output(interp, dedent("""
476 from test.support import interpreters
477 r, s = interpreters.create_channel()
478 orig = b'spam'
479 s.send(orig)
480 obj = r.recv()
481 assert obj is not orig
482 assert obj == orig
483 """))
484
485 def test_send_recv_different_threads(self):
486 r, s = interpreters.create_channel()
487
488 def f():
489 while True:
490 try:
491 obj = r.recv()
492 break
493 except interpreters.ChannelEmptyError:
494 time.sleep(0.1)
495 s.send(obj)
496 t = threading.Thread(target=f)
497 t.start()
498
499 s.send(b'spam')
500 t.join()
501 obj = r.recv()
502
503 self.assertEqual(obj, b'spam')
504
505 def test_send_recv_nowait_main(self):
506 r, s = interpreters.create_channel()
507 orig = b'spam'
508 s.send(orig)
509 obj = r.recv_nowait()
510
511 self.assertEqual(obj, orig)
512 self.assertIsNot(obj, orig)
513
514 def test_send_recv_nowait_same_interpreter(self):
515 interp = interpreters.create()
516 out = _run_output(interp, dedent("""
517 from test.support import interpreters
518 r, s = interpreters.create_channel()
519 orig = b'spam'
520 s.send(orig)
521 obj = r.recv_nowait()
522 assert obj is not orig
523 assert obj == orig
524 """))
525
526 r, s = interpreters.create_channel()
527
528 def f():
529 while True:
530 try:
531 obj = r.recv_nowait()
532 break
533 except _interpreters.ChannelEmptyError:
534 time.sleep(0.1)
535 s.send(obj)