blob: 8d72ca20021486071f8b645b07295e613d514b0d [file] [log] [blame]
Eric Snow7f8bfc92018-01-29 18:23:44 -07001import contextlib
2import os
3import pickle
4from textwrap import dedent, indent
5import threading
6import unittest
7
8from test import support
9from test.support import script_helper
10
# The private subinterpreter extension module exercised by every test below.
# NOTE(review): test.support.import_module is expected to raise
# unittest.SkipTest when the module cannot be imported, skipping this whole
# file on builds without _xxsubinterpreters — confirm for this branch.
interpreters = support.import_module('_xxsubinterpreters')
12
13
14def _captured_script(script):
15 r, w = os.pipe()
16 indented = script.replace('\n', '\n ')
17 wrapped = dedent(f"""
18 import contextlib
19 with open({w}, 'w') as chan:
20 with contextlib.redirect_stdout(chan):
21 {indented}
22 """)
23 return wrapped, open(r)
24
25
def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *shared* is forwarded unchanged to ``interpreters.run_string()``.
    """
    source, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, source, shared)
        output = reader.read()
    return output
31
32
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy in a background thread for the ``with`` body.

    A worker thread runs a script in *interp* that blocks reading a pipe
    until EOF.  On exit the pipe's write end is written to and closed, which
    lets the script finish, and the worker is joined.

    The release happens in a ``finally`` clause.  Without it, an exception
    raised in the ``with`` body (e.g. a failed assertion) would skip the
    release.  The non-daemon worker would then stay blocked forever and hang
    the test process at exit.

    NOTE(review): nothing waits for the worker to actually enter
    run_string() before the body starts, so an immediate is_running() check
    in the body can race with thread start-up.
    """
    r, w = os.pipe()

    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()
    try:
        yield
    finally:
        # Closing the write end gives chan.read() in *interp* its EOF.
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()
51
52
class IsShareableTests(unittest.TestCase):
    """Which objects interpreters.is_shareable() accepts and rejects."""

    def test_default_shareables(self):
        accepted = (
            # singletons
            None,
            # builtin objects
            b'spam',
        )
        for candidate in accepted:
            with self.subTest(candidate):
                self.assertTrue(interpreters.is_shareable(candidate))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name

            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        rejected = (
            # singletons
            True,
            False,
            NotImplemented,
            ...,
            # builtin types and objects
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
            # user-defined types and objects
            Cheese,
            Cheese('Wensleydale'),
            SubBytes(b'spam'),
        )
        for candidate in rejected:
            with self.subTest(candidate):
                self.assertFalse(interpreters.is_shareable(candidate))
100
101
class TestBase(unittest.TestCase):
    """Base class that cleans up subinterpreters and channels after a test."""

    def tearDown(self):
        # Destroy every subinterpreter, skipping id 0 (the main interpreter).
        for interp_id in interpreters.list_all():
            if interp_id == 0:
                continue
            with contextlib.suppress(RuntimeError):  # already destroyed
                interpreters.destroy(interp_id)

        # Then destroy every channel that is still registered.
        for channel_id in interpreters.channel_list_all():
            with contextlib.suppress(interpreters.ChannelNotFoundError):
                interpreters.channel_destroy(channel_id)
118
119
class ListAllTests(TestBase):
    """interpreters.list_all() reports live interpreters in creation order."""

    def test_initial(self):
        main_id = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main_id])

    def test_after_creating(self):
        main_id = interpreters.get_main()
        first_id = interpreters.create()
        second_id = interpreters.create()
        self.assertEqual(interpreters.list_all(),
                         [main_id, first_id, second_id])

    def test_after_destroying(self):
        main_id = interpreters.get_main()
        first_id = interpreters.create()
        second_id = interpreters.create()
        interpreters.destroy(first_id)
        self.assertEqual(interpreters.list_all(), [main_id, second_id])
141
142
class GetCurrentTests(TestBase):
    """interpreters.get_current() identifies the calling interpreter."""

    def test_main(self):
        main_id = interpreters.get_main()
        self.assertEqual(interpreters.get_current(), main_id)

    def test_subinterpreter(self):
        main_id = interpreters.get_main()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_current())
            """))
        reported = int(printed.strip())
        # At this point list_all() holds exactly [main, interp].
        _, expected = interpreters.list_all()
        self.assertEqual(reported, expected)
        self.assertNotEqual(reported, main_id)
161
162
class GetMainTests(TestBase):
    """interpreters.get_main() returns the same id from any interpreter."""

    def test_from_main(self):
        [only_id] = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), only_id)

    def test_from_subinterpreter(self):
        [only_id] = interpreters.list_all()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(_interpreters.get_main())
            """))
        self.assertEqual(int(printed.strip()), only_id)
179
180
class IsRunningTests(TestBase):
    """Tests for interpreters.is_running()."""

    def test_main(self):
        # The main interpreter is the one executing this test.
        main_id = interpreters.get_main()
        self.assertTrue(interpreters.is_running(main_id))

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            busy = interpreters.is_running(interp)
            self.assertTrue(busy)
        idle = interpreters.is_running(interp)
        self.assertFalse(idle)

    def test_from_subinterpreter(self):
        interp = interpreters.create()
        printed = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({interp}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(RuntimeError, interpreters.is_running, -1)
219
220
class CreateTests(TestBase):
    """Tests for interpreters.create() from several calling contexts."""

    def test_in_main(self):
        new_id = interpreters.create()

        self.assertIn(new_id, interpreters.list_all())

    @unittest.skip('enable this test when working on pystate.c')
    def test_unique_id(self):
        seen = set()
        for _ in range(100):
            new_id = interpreters.create()
            interpreters.destroy(new_id)
            seen.add(new_id)

        self.assertEqual(len(seen), 100)

    def test_in_thread(self):
        gate = threading.Lock()
        created = None

        def create_in_thread():
            nonlocal created
            created = interpreters.create()
            # Block until the main thread has released the gate.
            with gate:
                pass

        worker = threading.Thread(target=create_in_thread)
        with gate:
            worker.start()
        worker.join()
        self.assertIn(created, interpreters.list_all())

    def test_in_subinterpreter(self):
        main_id, = interpreters.list_all()
        outer = interpreters.create()
        printed = _run_output(outer, dedent("""
            import _xxsubinterpreters as _interpreters
            id = _interpreters.create()
            print(id)
            """))
        inner = int(printed.strip())

        self.assertEqual(set(interpreters.list_all()), {main_id, outer, inner})

    def test_in_threaded_subinterpreter(self):
        main_id, = interpreters.list_all()
        outer = interpreters.create()
        inner = None

        def create_nested():
            nonlocal inner
            printed = _run_output(outer, dedent("""
                import _xxsubinterpreters as _interpreters
                id = _interpreters.create()
                print(id)
                """))
            inner = int(printed.strip())

        worker = threading.Thread(target=create_nested)
        worker.start()
        worker.join()

        self.assertEqual(set(interpreters.list_all()), {main_id, outer, inner})

    def test_after_destroy_all(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters, then destroy all of them.
        created = [interpreters.create() for _ in range(3)]
        for interp_id in created:
            interpreters.destroy(interp_id)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()), before | {latest})

    def test_after_destroy_some(self):
        before = set(interpreters.list_all())
        # Create 3 subinterpreters and destroy all but the middle one.
        first, second, third = (interpreters.create() for _ in range(3))
        interpreters.destroy(first)
        interpreters.destroy(third)
        # Finally, create another.
        latest = interpreters.create()
        self.assertEqual(set(interpreters.list_all()),
                         before | {latest, second})
310
311
class DestroyTests(TestBase):
    """Tests for interpreters.destroy()."""

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            ids.add(interpreters.create())
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for interp in ids:
            interpreters.destroy(interp)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        # Exceptions raised in a worker thread (including a failed
        # assertRaises) do not propagate through join().  Record the
        # outcome in the thread and check it from the test's own thread.
        raised = []

        def f():
            try:
                interpreters.destroy(main)
            except RuntimeError:
                raised.append(True)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        self.assertEqual(raised, [True])

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(interp)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(-1)

    def test_from_current(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        # An interpreter may not destroy itself.  The script tolerates the
        # RuntimeError, and the interpreter must still exist afterwards.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({interp})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(interp, script)
        self.assertEqual(set(interpreters.list_all()), {main, interp})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent("""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({})
            """).format(id2)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp = interpreters.create()

        def f():
            interpreters.destroy(interp)

        t = threading.Thread(target=f)
        t.start()
        t.join()
        # destroy() completed in the worker (any exception there would be
        # lost), so verify its effect here.
        self.assertNotIn(interp, interpreters.list_all())

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
404
405
class RunStringTests(TestBase):
    """Tests for interpreters.run_string()."""

    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.id = interpreters.create()
        self._fs = None

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    # NOTE(review): FSFixture is not defined or imported in this file, so
    # accessing this property raises NameError.  No test below uses it (nor
    # SCRIPT/FILENAME); kept only so the class attributes stay unchanged.
    @property
    def fs(self):
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_create_thread(self):
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')

            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0:  # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    assert(spid == pid)
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        # Find the lowest id that is not currently in use.
        interp_id = 0
        while interp_id in interpreters.list_all():
            interp_id += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(interp_id, 'print("spam")')

    def test_error_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        """Assert the body raises RunFailedError describing *exctype*.

        With *msg*, the full error text must be "<exctype>: <msg>";
        otherwise only the part before the first ':' is compared.
        """
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        if msg is None:
            self.assertEqual(str(caught.exception).split(':')[0],
                             str(exctype))
        else:
            self.assertEqual(str(caught.exception),
                             "{}: {}".format(exctype, msg))

    def test_invalid_syntax(self):
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        r, w = os.pipe()

        shared = {
            'spam': b'ham',
            'eggs': b'-1',
            'cheddar': None,
        }
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        shared = {'spam': b'ham'}
        script = dedent(f"""
            ns2 = dict(vars())
            del ns2['__builtins__']
            """)
        interpreters.run_string(self.id, script, shared)

        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        r, w = os.pipe()

        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        ns.pop('__builtins__')
        ns.pop('__loader__')
        self.assertEqual(ns, {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
        })

    # The spawned script previously passed an unassigned name (`id`, i.e.
    # the builtin function) to run_string().  The thread died immediately
    # with TypeError, so the test passed without exercising anything.
    # With a real interpreter, the non-daemon thread sleeps for ~11 days and
    # blocks the child's exit, so proc.wait() would never return.
    @unittest.skip('blocking forever')
    def test_still_running_at_exit(self):
        script = dedent(f"""
        from textwrap import dedent
        import threading
        import _xxsubinterpreters as _interpreters
        id = _interpreters.create()
        def f():
            _interpreters.run_string(id, dedent('''
                import time
                # Give plenty of time for the main interpreter to finish.
                time.sleep(1_000_000)
                '''))

        t = threading.Thread(target=f)
        t.start()
        """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)
723
724
class ChannelIDTests(TestBase):
    """Tests for the ChannelID objects produced by interpreters._channel_id()."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        expectations = [
            ({'send': True}, 'send'),
            ({'send': True, 'recv': False}, 'send'),
            ({'recv': True}, 'recv'),
            ({'recv': True, 'send': False}, 'recv'),
            ({'send': True, 'recv': True}, 'both'),
        ]
        for kwargs, expected_end in expectations:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(cid.end, expected_end)

    def test_coerce_id(self):
        class Int(str):
            def __init__(self, value):
                self._value = value

            def __int__(self):
                return self._value

        for raw in ('10', 10.0, Int(10)):
            cid = interpreters._channel_id(raw, force=True)
            self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for cid in [-1, 'spam']:
            with self.subTest(cid):
                self.assertRaises(ValueError, interpreters._channel_id, cid)
        self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
        self.assertRaises(TypeError, interpreters._channel_id, object())

    def test_bad_kwargs(self):
        self.assertRaises(ValueError, interpreters._channel_id,
                          10, send=False, recv=False)

    def test_does_not_exist(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(int(cid) + 1)  # unforced

    def test_repr(self):
        expectations = [
            ({}, 'ChannelID(10)'),
            ({'send': True}, 'ChannelID(10, send=True)'),
            ({'recv': True}, 'ChannelID(10, recv=True)'),
            ({'send': True, 'recv': True}, 'ChannelID(10)'),
        ]
        for kwargs, expected_repr in expectations:
            cid = interpreters._channel_id(10, force=True, **kwargs)
            self.assertEqual(repr(cid), expected_repr)

    def test_equality(self):
        cid1 = interpreters.channel_create()
        cid2 = interpreters._channel_id(int(cid1))
        cid3 = interpreters.channel_create()

        # Exercise __eq__ explicitly.
        self.assertTrue(cid1 == cid1)
        self.assertTrue(cid1 == cid2)
        self.assertTrue(cid1 == int(cid1))
        self.assertFalse(cid1 == cid3)

        # Exercise __ne__ explicitly.
        self.assertFalse(cid1 != cid1)
        self.assertFalse(cid1 != cid2)
        self.assertTrue(cid1 != cid3)
810
811
class ChannelTests(TestBase):
    """Tests for channels: ids, drop, close, and send/recv."""

    def test_sequential_ids(self):
        before = interpreters.channel_list_all()
        id1 = interpreters.channel_create()
        id2 = interpreters.channel_create()
        id3 = interpreters.channel_create()
        after = interpreters.channel_list_all()

        self.assertEqual(id2, int(id1) + 1)
        self.assertEqual(id3, int(id2) + 1)
        self.assertEqual(set(after) - set(before), {id1, id2, id3})

    def test_ids_global(self):
        id1 = interpreters.create()
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid1 = int(out.strip())

        id2 = interpreters.create()
        out = _run_output(id2, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            print(int(cid))
            """))
        cid2 = int(out.strip())

        self.assertEqual(cid2, cid1 + 1)

    ####################

    def test_drop_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        out = _run_output(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            obj = _interpreters.channel_recv({int(cid)})
            _interpreters.channel_drop_interpreter({int(cid)})
            print(repr(obj))
            """))
        interpreters.run_string(id1, dedent(f"""
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        self.assertEqual(out.strip(), "b'spam'")

    def test_drop_no_kwargs(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_drop_interpreter(cid, send=True, recv=True)

    def test_drop_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_drop_interpreter({int(cid)})
            """))
        obj = interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        self.assertEqual(obj, b'spam')

    def test_drop_close_if_unassociated(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            _interpreters.channel_drop_interpreter({int(cid)})
            """))

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_drop_partially(self):
        # XXX Is partial close too weird/confusing?
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, None)
        interpreters.channel_recv(cid)
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_drop_interpreter(cid, send=True)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_drop_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_drop_interpreter(cid, send=True, recv=True)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_close_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_multiple_users(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        id2 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        interpreters.run_string(id2, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_recv({int(cid)})
            """))
        interpreters.channel_close(cid)
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id1, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))
        with self.assertRaises(interpreters.RunFailedError) as cm:
            interpreters.run_string(id2, dedent(f"""
                _interpreters.channel_send({int(cid)}, b'spam')
                """))
        self.assertIn('ChannelClosedError', str(cm.exception))

    def test_close_multiple_times(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_with_unused_items(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'ham')
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_never_used(self):
        cid = interpreters.channel_create()
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'spam')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    def test_close_by_unassociated_interp(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interp = interpreters.create()
        interpreters.run_string(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_close({int(cid)})
            """))
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_close(cid)

    def test_close_used_multiple_times_by_single_user(self):
        cid = interpreters.channel_create()
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_send(cid, b'spam')
        interpreters.channel_recv(cid)
        interpreters.channel_close(cid)

        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_send(cid, b'eggs')
        with self.assertRaises(interpreters.ChannelClosedError):
            interpreters.channel_recv(cid)

    ####################

    def test_send_recv_main(self):
        cid = interpreters.channel_create()
        orig = b'spam'
        interpreters.channel_send(cid, orig)
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, orig)
        self.assertIsNot(obj, orig)

    def test_send_recv_same_interpreter(self):
        id1 = interpreters.create()
        # Report the checks back instead of using `assert` inside the
        # subinterpreter: assert statements are stripped under -O, which
        # would leave this test checking nothing.
        out = _run_output(id1, dedent("""
            import _xxsubinterpreters as _interpreters
            cid = _interpreters.channel_create()
            orig = b'spam'
            _interpreters.channel_send(cid, orig)
            obj = _interpreters.channel_recv(cid)
            print(obj == orig, obj is not orig)
            """))

        self.assertEqual(out.strip(), 'True True')

    def test_send_recv_different_interpreters(self):
        cid = interpreters.channel_create()
        id1 = interpreters.create()
        interpreters.run_string(id1, dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.channel_send({int(cid)}, b'spam')
            """))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')

    def test_send_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_send(10, b'spam')

    def test_recv_not_found(self):
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters.channel_recv(10)

    def test_recv_empty(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelEmptyError):
            interpreters.channel_recv(cid)

    def test_run_string_arg(self):
        cid = interpreters.channel_create()
        interp = interpreters.create()

        # Pass the send end of the channel into the subinterpreter as `cid`.
        out = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(cid.end)
            _interpreters.channel_send(cid, b'spam')
            """),
            dict(cid=cid.send))
        obj = interpreters.channel_recv(cid)

        self.assertEqual(obj, b'spam')
        self.assertEqual(out.strip(), 'send')
1117
1118
# Run this module's tests when the file is executed directly.
if __name__ == '__main__':
    unittest.main()