blob: 397d3599312bec380e5100a529ab701574b9d157 [file] [log] [blame]
Eric Snow7f8bfc92018-01-29 18:23:44 -07001import contextlib
2import os
3import pickle
4from textwrap import dedent, indent
5import threading
Eric Snowf53d9f22018-02-20 16:30:17 -07006import time
Eric Snow7f8bfc92018-01-29 18:23:44 -07007import unittest
8
9from test import support
10from test.support import script_helper
11
12interpreters = support.import_module('_xxsubinterpreters')
13
14
15def _captured_script(script):
16 r, w = os.pipe()
17 indented = script.replace('\n', '\n ')
18 wrapped = dedent(f"""
19 import contextlib
20 with open({w}, 'w') as chan:
21 with contextlib.redirect_stdout(chan):
22 {indented}
23 """)
24 return wrapped, open(r)
25
26
def _run_output(interp, request, shared=None):
    """Run *request* in interpreter *interp* and return what it printed.

    *shared* is passed through unchanged to ``interpreters.run_string()``.
    """
    wrapped, reader = _captured_script(request)
    with reader:
        interpreters.run_string(interp, wrapped, shared)
        output = reader.read()
    return output
32
33
@contextlib.contextmanager
def _running(interp):
    """Keep *interp* busy running a script while the ``with`` block runs.

    A background thread runs a script in *interp* that blocks reading from
    a pipe.  On exit the pipe is written to, which lets the script finish,
    and the thread is joined.  The signal is sent from a ``finally`` clause
    so an exception (e.g. a failed assertion) inside the ``with`` block
    cannot leave the thread, and the interpreter, blocked forever.
    """
    r, w = os.pipe()
    def run():
        interpreters.run_string(interp, dedent(f"""
            # wait for "signal"
            with open({r}) as chan:
                chan.read()
            """))

    t = threading.Thread(target=run)
    t.start()

    try:
        yield
    finally:
        # Always release the waiting script, even if the body raised.
        with open(w, 'w') as chan:
            chan.write('done')
        t.join()
52
53
class IsShareableTests(unittest.TestCase):
    """Checks for interpreters.is_shareable()."""

    def test_default_shareables(self):
        # A singleton and a builtin object that are shareable by default.
        for obj in (None, b'spam'):
            with self.subTest(obj):
                self.assertTrue(interpreters.is_shareable(obj))

    def test_not_shareable(self):
        class Cheese:
            def __init__(self, name):
                self.name = name
            def __str__(self):
                return self.name

        class SubBytes(bytes):
            """A subclass of a shareable type."""

        singletons = [True, False, NotImplemented, ...]
        builtin_values = [
            type,
            object,
            object(),
            Exception(),
            42,
            100.0,
            'spam',
        ]
        user_defined = [Cheese, Cheese('Wensleydale'), SubBytes(b'spam')]

        for obj in singletons + builtin_values + user_defined:
            with self.subTest(obj):
                self.assertFalse(interpreters.is_shareable(obj))
101
102
class TestBase(unittest.TestCase):
    """Base class that cleans up interpreters and channels after each test.

    tearDown() destroys every interpreter except ID 0 (the main one) and
    every channel still listed, ignoring any that are already gone.
    """

    def tearDown(self):
        for interp_id in interpreters.list_all():
            if interp_id == 0:  # main
                continue
            with contextlib.suppress(RuntimeError):  # already destroyed
                interpreters.destroy(interp_id)

        for cid in interpreters.channel_list_all():
            with contextlib.suppress(interpreters.ChannelNotFoundError):
                interpreters.channel_destroy(cid)  # may already be gone
119
120
class ListAllTests(TestBase):
    """Checks for interpreters.list_all()."""

    def test_initial(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.list_all(), [main])

    def test_after_creating(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        # IDs are listed in creation order, main first.
        self.assertEqual(interpreters.list_all(), [main, first, second])

    def test_after_destroying(self):
        main = interpreters.get_main()
        first, second = interpreters.create(), interpreters.create()
        interpreters.destroy(first)
        self.assertEqual(interpreters.list_all(), [main, second])
142
143
class GetCurrentTests(TestBase):
    """Checks for interpreters.get_current()."""

    def test_main(self):
        main = interpreters.get_main()
        self.assertEqual(interpreters.get_current(), main)

    def test_subinterpreter(self):
        main = interpreters.get_main()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_current()))
            """))
        reported = int(printed.strip())

        # The subinterpreter reports its own ID, not main's.
        _, expected = interpreters.list_all()
        self.assertEqual(reported, expected)
        self.assertNotEqual(reported, main)
162
163
class GetMainTests(TestBase):
    """Checks for interpreters.get_main()."""

    def test_from_main(self):
        expected, = interpreters.list_all()
        self.assertEqual(interpreters.get_main(), expected)

    def test_from_subinterpreter(self):
        expected, = interpreters.list_all()
        interp = interpreters.create()
        printed = _run_output(interp, dedent("""
            import _xxsubinterpreters as _interpreters
            print(int(_interpreters.get_main()))
            """))
        self.assertEqual(int(printed.strip()), expected)
180
181
class IsRunningTests(TestBase):
    """Checks for interpreters.is_running()."""

    def test_main(self):
        main = interpreters.get_main()
        self.assertTrue(interpreters.is_running(main))

    def test_subinterpreter(self):
        interp = interpreters.create()
        self.assertFalse(interpreters.is_running(interp))

        with _running(interp):
            self.assertTrue(interpreters.is_running(interp))
        self.assertFalse(interpreters.is_running(interp))

    def test_from_subinterpreter(self):
        # An interpreter asking about itself is running by definition.
        interp = interpreters.create()
        printed = _run_output(interp, dedent(f"""
            import _xxsubinterpreters as _interpreters
            if _interpreters.is_running({int(interp)}):
                print(True)
            else:
                print(False)
            """))
        self.assertEqual(printed.strip(), 'True')

    def test_already_destroyed(self):
        interp = interpreters.create()
        interpreters.destroy(interp)
        self.assertRaises(RuntimeError, interpreters.is_running, interp)

    def test_does_not_exist(self):
        self.assertRaises(RuntimeError, interpreters.is_running, 1_000_000)

    def test_bad_id(self):
        self.assertRaises(RuntimeError, interpreters.is_running, -1)
220
221
Eric Snow4c6955e2018-02-16 18:53:40 -0700222class InterpreterIDTests(TestBase):
223
224 def test_with_int(self):
225 id = interpreters.InterpreterID(10, force=True)
226
227 self.assertEqual(int(id), 10)
228
229 def test_coerce_id(self):
230 id = interpreters.InterpreterID('10', force=True)
231 self.assertEqual(int(id), 10)
232
233 id = interpreters.InterpreterID(10.0, force=True)
234 self.assertEqual(int(id), 10)
235
236 class Int(str):
237 def __init__(self, value):
238 self._value = value
239 def __int__(self):
240 return self._value
241
242 id = interpreters.InterpreterID(Int(10), force=True)
243 self.assertEqual(int(id), 10)
244
245 def test_bad_id(self):
246 for id in [-1, 'spam']:
247 with self.subTest(id):
248 with self.assertRaises(ValueError):
249 interpreters.InterpreterID(id)
250 with self.assertRaises(OverflowError):
251 interpreters.InterpreterID(2**64)
252 with self.assertRaises(TypeError):
253 interpreters.InterpreterID(object())
254
255 def test_does_not_exist(self):
256 id = interpreters.channel_create()
257 with self.assertRaises(RuntimeError):
258 interpreters.InterpreterID(int(id) + 1) # unforced
259
260 def test_repr(self):
261 id = interpreters.InterpreterID(10, force=True)
262 self.assertEqual(repr(id), 'InterpreterID(10)')
263
264 def test_equality(self):
265 id1 = interpreters.create()
266 id2 = interpreters.InterpreterID(int(id1))
267 id3 = interpreters.create()
268
269 self.assertTrue(id1 == id1)
270 self.assertTrue(id1 == id2)
271 self.assertTrue(id1 == int(id1))
272 self.assertFalse(id1 == id3)
273
274 self.assertFalse(id1 != id1)
275 self.assertFalse(id1 != id2)
276 self.assertTrue(id1 != id3)
277
278
Eric Snow7f8bfc92018-01-29 18:23:44 -0700279class CreateTests(TestBase):
280
281 def test_in_main(self):
282 id = interpreters.create()
283
284 self.assertIn(id, interpreters.list_all())
285
286 @unittest.skip('enable this test when working on pystate.c')
287 def test_unique_id(self):
288 seen = set()
289 for _ in range(100):
290 id = interpreters.create()
291 interpreters.destroy(id)
292 seen.add(id)
293
294 self.assertEqual(len(seen), 100)
295
296 def test_in_thread(self):
297 lock = threading.Lock()
298 id = None
299 def f():
300 nonlocal id
301 id = interpreters.create()
302 lock.acquire()
303 lock.release()
304
305 t = threading.Thread(target=f)
306 with lock:
307 t.start()
308 t.join()
309 self.assertIn(id, interpreters.list_all())
310
311 def test_in_subinterpreter(self):
312 main, = interpreters.list_all()
313 id1 = interpreters.create()
314 out = _run_output(id1, dedent("""
315 import _xxsubinterpreters as _interpreters
316 id = _interpreters.create()
Eric Snow4c6955e2018-02-16 18:53:40 -0700317 print(int(id))
Eric Snow7f8bfc92018-01-29 18:23:44 -0700318 """))
319 id2 = int(out.strip())
320
321 self.assertEqual(set(interpreters.list_all()), {main, id1, id2})
322
323 def test_in_threaded_subinterpreter(self):
324 main, = interpreters.list_all()
325 id1 = interpreters.create()
326 id2 = None
327 def f():
328 nonlocal id2
329 out = _run_output(id1, dedent("""
330 import _xxsubinterpreters as _interpreters
331 id = _interpreters.create()
Eric Snow4c6955e2018-02-16 18:53:40 -0700332 print(int(id))
Eric Snow7f8bfc92018-01-29 18:23:44 -0700333 """))
334 id2 = int(out.strip())
335
336 t = threading.Thread(target=f)
337 t.start()
338 t.join()
339
340 self.assertEqual(set(interpreters.list_all()), {main, id1, id2})
341
342 def test_after_destroy_all(self):
343 before = set(interpreters.list_all())
344 # Create 3 subinterpreters.
345 ids = []
346 for _ in range(3):
347 id = interpreters.create()
348 ids.append(id)
349 # Now destroy them.
350 for id in ids:
351 interpreters.destroy(id)
352 # Finally, create another.
353 id = interpreters.create()
354 self.assertEqual(set(interpreters.list_all()), before | {id})
355
356 def test_after_destroy_some(self):
357 before = set(interpreters.list_all())
358 # Create 3 subinterpreters.
359 id1 = interpreters.create()
360 id2 = interpreters.create()
361 id3 = interpreters.create()
362 # Now destroy 2 of them.
363 interpreters.destroy(id1)
364 interpreters.destroy(id3)
365 # Finally, create another.
366 id = interpreters.create()
367 self.assertEqual(set(interpreters.list_all()), before | {id, id2})
368
369
class DestroyTests(TestBase):
    """Checks for interpreters.destroy()."""

    def _run_in_thread(self, func):
        """Run *func* in a new thread and re-raise anything it raised.

        An exception raised by a plain ``threading.Thread`` target
        (including a failed assertion) is only reported, never propagated
        to the thread that calls join().  Collecting it here makes the
        test fail as it should.
        """
        errors = []

        def target():
            try:
                func()
            except BaseException as exc:
                errors.append(exc)

        t = threading.Thread(target=target)
        t.start()
        t.join()
        if errors:
            raise errors[0]

    def test_one(self):
        id1 = interpreters.create()
        id2 = interpreters.create()
        id3 = interpreters.create()
        self.assertIn(id2, interpreters.list_all())
        interpreters.destroy(id2)
        self.assertNotIn(id2, interpreters.list_all())
        self.assertIn(id1, interpreters.list_all())
        self.assertIn(id3, interpreters.list_all())

    def test_all(self):
        before = set(interpreters.list_all())
        ids = set()
        for _ in range(3):
            ids.add(interpreters.create())
        self.assertEqual(set(interpreters.list_all()), before | ids)
        for interp_id in ids:
            interpreters.destroy(interp_id)
        self.assertEqual(set(interpreters.list_all()), before)

    def test_main(self):
        main, = interpreters.list_all()
        with self.assertRaises(RuntimeError):
            interpreters.destroy(main)

        def f():
            with self.assertRaises(RuntimeError):
                interpreters.destroy(main)

        # Same check from another thread; failures there must propagate.
        self._run_in_thread(f)

    def test_already_destroyed(self):
        interp_id = interpreters.create()
        interpreters.destroy(interp_id)
        with self.assertRaises(RuntimeError):
            interpreters.destroy(interp_id)

    def test_does_not_exist(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(1_000_000)

    def test_bad_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.destroy(-1)

    def test_from_current(self):
        main, = interpreters.list_all()
        interp_id = interpreters.create()
        # The interpreter tries to destroy itself; the RuntimeError is
        # swallowed in the script and the interpreter must survive.
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            try:
                _interpreters.destroy({int(interp_id)})
            except RuntimeError:
                pass
            """)

        interpreters.run_string(interp_id, script)
        self.assertEqual(set(interpreters.list_all()), {main, interp_id})

    def test_from_sibling(self):
        main, = interpreters.list_all()
        id1 = interpreters.create()
        id2 = interpreters.create()
        script = dedent(f"""
            import _xxsubinterpreters as _interpreters
            _interpreters.destroy({int(id2)})
            """)
        interpreters.run_string(id1, script)

        self.assertEqual(set(interpreters.list_all()), {main, id1})

    def test_from_other_thread(self):
        interp_id = interpreters.create()
        def f():
            interpreters.destroy(interp_id)

        self._run_in_thread(f)
        self.assertNotIn(interp_id, interpreters.list_all())

    def test_still_running(self):
        main, = interpreters.list_all()
        interp = interpreters.create()
        with _running(interp):
            with self.assertRaises(RuntimeError):
                interpreters.destroy(interp)
            self.assertTrue(interpreters.is_running(interp))
462
463
class RunStringTests(TestBase):
    """Checks for interpreters.run_string().

    setUp() creates a fresh subinterpreter (``self.id``) for each test;
    TestBase.tearDown() destroys all subinterpreters afterwards.
    """

    # NOTE(review): SCRIPT and FILENAME are not referenced by any test in
    # this class; confirm whether they are still needed.
    SCRIPT = dedent("""
        with open('{}', 'w') as out:
            out.write('{}')
        """)
    FILENAME = 'spam'

    def setUp(self):
        super().setUp()
        self.id = interpreters.create()
        self._fs = None  # created lazily by the `fs` property

    def tearDown(self):
        if self._fs is not None:
            self._fs.close()
        super().tearDown()

    @property
    def fs(self):
        # NOTE(review): FSFixture is not defined in the visible part of this
        # file and no test here uses `fs`; accessing it may raise NameError.
        # Confirm whether this is dead code.
        if self._fs is None:
            self._fs = FSFixture(self)
        return self._fs

    def test_success(self):
        # Whatever the script prints is captured through a pipe.
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_in_thread(self):
        # run_string() may be called from a thread other than main's.
        script, file = _captured_script('print("it worked!", end="")')
        with file:
            def f():
                interpreters.run_string(self.id, script)

            t = threading.Thread(target=f)
            t.start()
            t.join()
            out = file.read()

        self.assertEqual(out, 'it worked!')

    def test_create_thread(self):
        # The script itself starts and joins a thread inside the
        # subinterpreter.
        script, file = _captured_script("""
            import threading
            def f():
                print('it worked!', end='')

            t = threading.Thread(target=f)
            t.start()
            t.join()
            """)
        with file:
            interpreters.run_string(self.id, script)
            out = file.read()

        self.assertEqual(out, 'it worked!')

    @unittest.skipUnless(hasattr(os, 'fork'), "test needs os.fork()")
    def test_fork(self):
        # The subinterpreter forks.  The child writes `expected` to the temp
        # file and leaves via os._exit(1).  The parent polls waitpid() up to
        # 10 times, 0.1s apart, and its assert fails (making run_string()
        # raise) if the child has not exited by then.
        import tempfile
        with tempfile.NamedTemporaryFile('w+') as file:
            file.write('')
            file.flush()

            expected = 'spam spam spam spam spam'
            script = dedent(f"""
                # (inspired by Lib/test/test_fork.py)
                import os
                pid = os.fork()
                if pid == 0:  # child
                    with open('{file.name}', 'w') as out:
                        out.write('{expected}')
                    # Kill the unittest runner in the child process.
                    os._exit(1)
                else:
                    SHORT_SLEEP = 0.1
                    import time
                    for _ in range(10):
                        spid, status = os.waitpid(pid, os.WNOHANG)
                        if spid == pid:
                            break
                        time.sleep(SHORT_SLEEP)
                    assert(spid == pid)
                """)
            interpreters.run_string(self.id, script)

            file.seek(0)
            content = file.read()
            self.assertEqual(content, expected)

    def test_already_running(self):
        # A busy interpreter refuses a second run_string().
        with _running(self.id):
            with self.assertRaises(RuntimeError):
                interpreters.run_string(self.id, 'print("spam")')

    def test_does_not_exist(self):
        # Find the smallest non-negative ID that is not currently in use.
        id = 0
        while id in interpreters.list_all():
            id += 1
        with self.assertRaises(RuntimeError):
            interpreters.run_string(id, 'print("spam")')

    def test_error_id(self):
        with self.assertRaises(RuntimeError):
            interpreters.run_string(-1, 'print("spam")')

    def test_bad_id(self):
        with self.assertRaises(TypeError):
            interpreters.run_string('spam', 'print("spam")')

    def test_bad_script(self):
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, 10)

    def test_bytes_for_script(self):
        # The script must be str; bytes are rejected.
        with self.assertRaises(TypeError):
            interpreters.run_string(self.id, b'print("spam")')

    @contextlib.contextmanager
    def assert_run_failed(self, exctype, msg=None):
        """Assert that the with-block raises RunFailedError for *exctype*.

        With *msg* None, only the part of the error message before the
        first ':' is compared against ``str(exctype)``.  Otherwise the
        whole message must equal ``"<str(exctype)>: <msg>"``.
        """
        with self.assertRaises(interpreters.RunFailedError) as caught:
            yield
        if msg is None:
            self.assertEqual(str(caught.exception).split(':')[0],
                             str(exctype))
        else:
            self.assertEqual(str(caught.exception),
                             "{}: {}".format(exctype, msg))

    def test_invalid_syntax(self):
        with self.assert_run_failed(SyntaxError):
            # missing close paren
            interpreters.run_string(self.id, 'print("spam"')

    def test_failure(self):
        with self.assert_run_failed(Exception, 'spam'):
            interpreters.run_string(self.id, 'raise Exception("spam")')

    def test_SystemExit(self):
        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, 'raise SystemExit(42)')

    def test_sys_exit(self):
        with self.assert_run_failed(SystemExit):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit()
                """))

        with self.assert_run_failed(SystemExit, '42'):
            interpreters.run_string(self.id, dedent("""
                import sys
                sys.exit(42)
                """))

    def test_with_shared(self):
        # Entries of `shared` show up as globals in the script.  The script
        # sends its namespace back by pickling it through a pipe.
        r, w = os.pipe()

        shared = {
                'spam': b'ham',
                'eggs': b'-1',
                'cheddar': None,
                }
        script = dedent(f"""
            eggs = int(eggs)
            spam = 42
            result = spam + eggs

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['spam'], 42)
        self.assertEqual(ns['eggs'], -1)
        self.assertEqual(ns['result'], 41)
        self.assertIsNone(ns['cheddar'])

    def test_shared_overwrites(self):
        # A shared value replaces an existing global of the same name, and
        # the replacement is still visible in later run_string() calls.
        interpreters.run_string(self.id, dedent("""
            spam = 'eggs'
            ns1 = dict(vars())
            del ns1['__builtins__']
            """))

        shared = {'spam': b'ham'}
        script = dedent(f"""
            ns2 = dict(vars())
            del ns2['__builtins__']
        """)
        interpreters.run_string(self.id, script, shared)

        r, w = os.pipe()
        script = dedent(f"""
            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['ns1']['spam'], 'eggs')
        self.assertEqual(ns['ns2']['spam'], b'ham')
        self.assertEqual(ns['spam'], b'ham')

    def test_shared_overwrites_default_vars(self):
        # Even a default module global such as __name__ can be replaced.
        r, w = os.pipe()

        shared = {'__name__': b'not __main__'}
        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script, shared)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        self.assertEqual(ns['__name__'], b'not __main__')

    def test_main_reused(self):
        # Consecutive run_string() calls share one namespace: the second
        # run still sees `spam` from the first.
        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            spam = True

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            del ns, pickle, chan
            """))
        with open(r, 'rb') as chan:
            ns1 = pickle.load(chan)

        r, w = os.pipe()
        interpreters.run_string(self.id, dedent(f"""
            eggs = False

            ns = dict(vars())
            del ns['__builtins__']
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """))
        with open(r, 'rb') as chan:
            ns2 = pickle.load(chan)

        self.assertIn('spam', ns1)
        self.assertNotIn('eggs', ns1)
        self.assertIn('eggs', ns2)
        self.assertIn('spam', ns2)

    def test_execution_namespace_is_main(self):
        # The script runs in a namespace named '__main__' that holds only
        # the standard module globals plus what the script defines.
        r, w = os.pipe()

        script = dedent(f"""
            spam = 42

            ns = dict(vars())
            ns['__builtins__'] = str(ns['__builtins__'])
            import pickle
            with open({w}, 'wb') as chan:
                pickle.dump(ns, chan)
            """)
        interpreters.run_string(self.id, script)
        with open(r, 'rb') as chan:
            ns = pickle.load(chan)

        # __builtins__ (stringified above, presumably so it can be pickled)
        # and __loader__ are dropped before the exact comparison.
        ns.pop('__builtins__')
        ns.pop('__loader__')
        self.assertEqual(ns, {
            '__name__': '__main__',
            '__annotations__': {},
            '__doc__': None,
            '__package__': None,
            '__spec__': None,
            'spam': 42,
            })

    # XXX Fix this test!
    @unittest.skip('blocking forever')
    def test_still_running_at_exit(self):
        # A child Python process leaves a subinterpreter running in a thread
        # when it reaches the end of its script; it should still exit with 0.
        script = dedent(f"""
        from textwrap import dedent
        import threading
        import _xxsubinterpreters as _interpreters
        id = _interpreters.create()
        def f():
            _interpreters.run_string(id, dedent('''
                import time
                # Give plenty of time for the main interpreter to finish.
                time.sleep(1_000_000)
                '''))

        t = threading.Thread(target=f)
        t.start()
        """)
        with support.temp_dir() as dirname:
            filename = script_helper.make_script(dirname, 'interp', script)
            with script_helper.spawn_python(filename) as proc:
                retcode = proc.wait()

        self.assertEqual(retcode, 0)
784
785
class ChannelIDTests(TestBase):
    """Checks for channel ID objects created via interpreters._channel_id()."""

    def test_default_kwargs(self):
        cid = interpreters._channel_id(10, force=True)

        self.assertEqual(int(cid), 10)
        self.assertEqual(cid.end, 'both')

    def test_with_kwargs(self):
        expectations = [
            ({'send': True}, 'send'),
            ({'send': True, 'recv': False}, 'send'),
            ({'recv': True}, 'recv'),
            ({'recv': True, 'send': False}, 'recv'),
            ({'send': True, 'recv': True}, 'both'),
        ]
        for kwargs, end in expectations:
            cid = interpreters._channel_id(10, **kwargs, force=True)
            self.assertEqual(cid.end, end)

    def test_coerce_id(self):
        # A str, a float, and a str subclass that also defines __int__
        # all become channel ID 10.
        class Int(str):
            def __init__(self, value):
                self._value = value
            def __int__(self):
                return self._value

        for raw in ('10', 10.0, Int(10)):
            cid = interpreters._channel_id(raw, force=True)
            self.assertEqual(int(cid), 10)

    def test_bad_id(self):
        for bad in [-1, 'spam']:
            with self.subTest(bad):
                self.assertRaises(ValueError, interpreters._channel_id, bad)
        self.assertRaises(OverflowError, interpreters._channel_id, 2**64)
        self.assertRaises(TypeError, interpreters._channel_id, object())

    def test_bad_kwargs(self):
        # At least one end must be requested.
        self.assertRaises(ValueError, interpreters._channel_id,
                          10, send=False, recv=False)

    def test_does_not_exist(self):
        cid = interpreters.channel_create()
        with self.assertRaises(interpreters.ChannelNotFoundError):
            interpreters._channel_id(int(cid) + 1)  # unforced

    def test_repr(self):
        expectations = [
            ({}, 'ChannelID(10)'),
            ({'send': True}, 'ChannelID(10, send=True)'),
            ({'recv': True}, 'ChannelID(10, recv=True)'),
            ({'send': True, 'recv': True}, 'ChannelID(10)'),
        ]
        for kwargs, text in expectations:
            cid = interpreters._channel_id(10, **kwargs, force=True)
            self.assertEqual(repr(cid), text)

    def test_equality(self):
        original = interpreters.channel_create()
        same = interpreters._channel_id(int(original))
        other = interpreters.channel_create()

        self.assertTrue(original == original)
        self.assertTrue(original == same)
        self.assertTrue(original == int(original))
        self.assertFalse(original == other)

        self.assertFalse(original != original)
        self.assertFalse(original != same)
        self.assertTrue(original != other)
871
872
873class ChannelTests(TestBase):
874
875 def test_sequential_ids(self):
876 before = interpreters.channel_list_all()
877 id1 = interpreters.channel_create()
878 id2 = interpreters.channel_create()
879 id3 = interpreters.channel_create()
880 after = interpreters.channel_list_all()
881
882 self.assertEqual(id2, int(id1) + 1)
883 self.assertEqual(id3, int(id2) + 1)
884 self.assertEqual(set(after) - set(before), {id1, id2, id3})
885
886 def test_ids_global(self):
887 id1 = interpreters.create()
888 out = _run_output(id1, dedent("""
889 import _xxsubinterpreters as _interpreters
890 cid = _interpreters.channel_create()
891 print(int(cid))
892 """))
893 cid1 = int(out.strip())
894
895 id2 = interpreters.create()
896 out = _run_output(id2, dedent("""
897 import _xxsubinterpreters as _interpreters
898 cid = _interpreters.channel_create()
899 print(int(cid))
900 """))
901 cid2 = int(out.strip())
902
903 self.assertEqual(cid2, int(cid1) + 1)
904
905 ####################
906
907 def test_drop_single_user(self):
908 cid = interpreters.channel_create()
909 interpreters.channel_send(cid, b'spam')
910 interpreters.channel_recv(cid)
911 interpreters.channel_drop_interpreter(cid, send=True, recv=True)
912
913 with self.assertRaises(interpreters.ChannelClosedError):
914 interpreters.channel_send(cid, b'eggs')
915 with self.assertRaises(interpreters.ChannelClosedError):
916 interpreters.channel_recv(cid)
917
918 def test_drop_multiple_users(self):
919 cid = interpreters.channel_create()
920 id1 = interpreters.create()
921 id2 = interpreters.create()
922 interpreters.run_string(id1, dedent(f"""
923 import _xxsubinterpreters as _interpreters
924 _interpreters.channel_send({int(cid)}, b'spam')
925 """))
926 out = _run_output(id2, dedent(f"""
927 import _xxsubinterpreters as _interpreters
928 obj = _interpreters.channel_recv({int(cid)})
929 _interpreters.channel_drop_interpreter({int(cid)})
930 print(repr(obj))
931 """))
932 interpreters.run_string(id1, dedent(f"""
933 _interpreters.channel_drop_interpreter({int(cid)})
934 """))
935
936 self.assertEqual(out.strip(), "b'spam'")
937
938 def test_drop_no_kwargs(self):
939 cid = interpreters.channel_create()
940 interpreters.channel_send(cid, b'spam')
941 interpreters.channel_recv(cid)
942 interpreters.channel_drop_interpreter(cid)
943
944 with self.assertRaises(interpreters.ChannelClosedError):
945 interpreters.channel_send(cid, b'eggs')
946 with self.assertRaises(interpreters.ChannelClosedError):
947 interpreters.channel_recv(cid)
948
949 def test_drop_multiple_times(self):
950 cid = interpreters.channel_create()
951 interpreters.channel_send(cid, b'spam')
952 interpreters.channel_recv(cid)
953 interpreters.channel_drop_interpreter(cid, send=True, recv=True)
954
955 with self.assertRaises(interpreters.ChannelClosedError):
956 interpreters.channel_drop_interpreter(cid, send=True, recv=True)
957
958 def test_drop_with_unused_items(self):
959 cid = interpreters.channel_create()
960 interpreters.channel_send(cid, b'spam')
961 interpreters.channel_send(cid, b'ham')
962 interpreters.channel_drop_interpreter(cid, send=True, recv=True)
963
964 with self.assertRaises(interpreters.ChannelClosedError):
965 interpreters.channel_recv(cid)
966
967 def test_drop_never_used(self):
968 cid = interpreters.channel_create()
969 interpreters.channel_drop_interpreter(cid)
970
971 with self.assertRaises(interpreters.ChannelClosedError):
972 interpreters.channel_send(cid, b'spam')
973 with self.assertRaises(interpreters.ChannelClosedError):
974 interpreters.channel_recv(cid)
975
976 def test_drop_by_unassociated_interp(self):
977 cid = interpreters.channel_create()
978 interpreters.channel_send(cid, b'spam')
979 interp = interpreters.create()
980 interpreters.run_string(interp, dedent(f"""
981 import _xxsubinterpreters as _interpreters
982 _interpreters.channel_drop_interpreter({int(cid)})
983 """))
984 obj = interpreters.channel_recv(cid)
985 interpreters.channel_drop_interpreter(cid)
986
987 with self.assertRaises(interpreters.ChannelClosedError):
988 interpreters.channel_send(cid, b'eggs')
989 self.assertEqual(obj, b'spam')
990
991 def test_drop_close_if_unassociated(self):
992 cid = interpreters.channel_create()
993 interp = interpreters.create()
994 interpreters.run_string(interp, dedent(f"""
995 import _xxsubinterpreters as _interpreters
996 obj = _interpreters.channel_send({int(cid)}, b'spam')
997 _interpreters.channel_drop_interpreter({int(cid)})
998 """))
999
1000 with self.assertRaises(interpreters.ChannelClosedError):
1001 interpreters.channel_recv(cid)
1002
1003 def test_drop_partially(self):
1004 # XXX Is partial close too wierd/confusing?
1005 cid = interpreters.channel_create()
1006 interpreters.channel_send(cid, None)
1007 interpreters.channel_recv(cid)
1008 interpreters.channel_send(cid, b'spam')
1009 interpreters.channel_drop_interpreter(cid, send=True)
1010 obj = interpreters.channel_recv(cid)
1011
1012 self.assertEqual(obj, b'spam')
1013
1014 def test_drop_used_multiple_times_by_single_user(self):
1015 cid = interpreters.channel_create()
1016 interpreters.channel_send(cid, b'spam')
1017 interpreters.channel_send(cid, b'spam')
1018 interpreters.channel_send(cid, b'spam')
1019 interpreters.channel_recv(cid)
1020 interpreters.channel_drop_interpreter(cid, send=True, recv=True)
1021
1022 with self.assertRaises(interpreters.ChannelClosedError):
1023 interpreters.channel_send(cid, b'eggs')
1024 with self.assertRaises(interpreters.ChannelClosedError):
1025 interpreters.channel_recv(cid)
1026
1027 ####################
1028
1029 def test_close_single_user(self):
1030 cid = interpreters.channel_create()
1031 interpreters.channel_send(cid, b'spam')
1032 interpreters.channel_recv(cid)
1033 interpreters.channel_close(cid)
1034
1035 with self.assertRaises(interpreters.ChannelClosedError):
1036 interpreters.channel_send(cid, b'eggs')
1037 with self.assertRaises(interpreters.ChannelClosedError):
1038 interpreters.channel_recv(cid)
1039
1040 def test_close_multiple_users(self):
1041 cid = interpreters.channel_create()
1042 id1 = interpreters.create()
1043 id2 = interpreters.create()
1044 interpreters.run_string(id1, dedent(f"""
1045 import _xxsubinterpreters as _interpreters
1046 _interpreters.channel_send({int(cid)}, b'spam')
1047 """))
1048 interpreters.run_string(id2, dedent(f"""
1049 import _xxsubinterpreters as _interpreters
1050 _interpreters.channel_recv({int(cid)})
1051 """))
1052 interpreters.channel_close(cid)
1053 with self.assertRaises(interpreters.RunFailedError) as cm:
1054 interpreters.run_string(id1, dedent(f"""
1055 _interpreters.channel_send({int(cid)}, b'spam')
1056 """))
1057 self.assertIn('ChannelClosedError', str(cm.exception))
1058 with self.assertRaises(interpreters.RunFailedError) as cm:
1059 interpreters.run_string(id2, dedent(f"""
1060 _interpreters.channel_send({int(cid)}, b'spam')
1061 """))
1062 self.assertIn('ChannelClosedError', str(cm.exception))
1063
1064 def test_close_multiple_times(self):
1065 cid = interpreters.channel_create()
1066 interpreters.channel_send(cid, b'spam')
1067 interpreters.channel_recv(cid)
1068 interpreters.channel_close(cid)
1069
1070 with self.assertRaises(interpreters.ChannelClosedError):
1071 interpreters.channel_close(cid)
1072
1073 def test_close_with_unused_items(self):
1074 cid = interpreters.channel_create()
1075 interpreters.channel_send(cid, b'spam')
1076 interpreters.channel_send(cid, b'ham')
1077 interpreters.channel_close(cid)
1078
1079 with self.assertRaises(interpreters.ChannelClosedError):
1080 interpreters.channel_recv(cid)
1081
1082 def test_close_never_used(self):
1083 cid = interpreters.channel_create()
1084 interpreters.channel_close(cid)
1085
1086 with self.assertRaises(interpreters.ChannelClosedError):
1087 interpreters.channel_send(cid, b'spam')
1088 with self.assertRaises(interpreters.ChannelClosedError):
1089 interpreters.channel_recv(cid)
1090
1091 def test_close_by_unassociated_interp(self):
1092 cid = interpreters.channel_create()
1093 interpreters.channel_send(cid, b'spam')
1094 interp = interpreters.create()
1095 interpreters.run_string(interp, dedent(f"""
1096 import _xxsubinterpreters as _interpreters
1097 _interpreters.channel_close({int(cid)})
1098 """))
1099 with self.assertRaises(interpreters.ChannelClosedError):
1100 interpreters.channel_recv(cid)
1101 with self.assertRaises(interpreters.ChannelClosedError):
1102 interpreters.channel_close(cid)
1103
1104 def test_close_used_multiple_times_by_single_user(self):
1105 cid = interpreters.channel_create()
1106 interpreters.channel_send(cid, b'spam')
1107 interpreters.channel_send(cid, b'spam')
1108 interpreters.channel_send(cid, b'spam')
1109 interpreters.channel_recv(cid)
1110 interpreters.channel_close(cid)
1111
1112 with self.assertRaises(interpreters.ChannelClosedError):
1113 interpreters.channel_send(cid, b'eggs')
1114 with self.assertRaises(interpreters.ChannelClosedError):
1115 interpreters.channel_recv(cid)
1116
1117 ####################
1118
1119 def test_send_recv_main(self):
1120 cid = interpreters.channel_create()
1121 orig = b'spam'
1122 interpreters.channel_send(cid, orig)
1123 obj = interpreters.channel_recv(cid)
1124
1125 self.assertEqual(obj, orig)
1126 self.assertIsNot(obj, orig)
1127
1128 def test_send_recv_same_interpreter(self):
1129 id1 = interpreters.create()
1130 out = _run_output(id1, dedent("""
1131 import _xxsubinterpreters as _interpreters
1132 cid = _interpreters.channel_create()
1133 orig = b'spam'
1134 _interpreters.channel_send(cid, orig)
1135 obj = _interpreters.channel_recv(cid)
1136 assert obj is not orig
1137 assert obj == orig
1138 """))
1139
1140 def test_send_recv_different_interpreters(self):
1141 cid = interpreters.channel_create()
1142 id1 = interpreters.create()
1143 out = _run_output(id1, dedent(f"""
1144 import _xxsubinterpreters as _interpreters
1145 _interpreters.channel_send({int(cid)}, b'spam')
1146 """))
1147 obj = interpreters.channel_recv(cid)
1148
1149 self.assertEqual(obj, b'spam')
1150
Eric Snowf53d9f22018-02-20 16:30:17 -07001151 def test_send_recv_different_threads(self):
1152 cid = interpreters.channel_create()
1153
1154 def f():
1155 while True:
1156 try:
1157 obj = interpreters.channel_recv(cid)
1158 break
1159 except interpreters.ChannelEmptyError:
1160 time.sleep(0.1)
1161 interpreters.channel_send(cid, obj)
1162 t = threading.Thread(target=f)
1163 t.start()
1164
1165 interpreters.channel_send(cid, b'spam')
1166 t.join()
1167 obj = interpreters.channel_recv(cid)
1168
1169 self.assertEqual(obj, b'spam')
1170
1171 def test_send_recv_different_interpreters_and_threads(self):
1172 cid = interpreters.channel_create()
1173 id1 = interpreters.create()
1174 out = None
1175
1176 def f():
1177 nonlocal out
1178 out = _run_output(id1, dedent(f"""
1179 import time
1180 import _xxsubinterpreters as _interpreters
1181 while True:
1182 try:
1183 obj = _interpreters.channel_recv({int(cid)})
1184 break
1185 except _interpreters.ChannelEmptyError:
1186 time.sleep(0.1)
1187 assert(obj == b'spam')
1188 _interpreters.channel_send({int(cid)}, b'eggs')
1189 """))
1190 t = threading.Thread(target=f)
1191 t.start()
1192
1193 interpreters.channel_send(cid, b'spam')
1194 t.join()
1195 obj = interpreters.channel_recv(cid)
1196
1197 self.assertEqual(obj, b'eggs')
1198
Eric Snow7f8bfc92018-01-29 18:23:44 -07001199 def test_send_not_found(self):
1200 with self.assertRaises(interpreters.ChannelNotFoundError):
1201 interpreters.channel_send(10, b'spam')
1202
1203 def test_recv_not_found(self):
1204 with self.assertRaises(interpreters.ChannelNotFoundError):
1205 interpreters.channel_recv(10)
1206
1207 def test_recv_empty(self):
1208 cid = interpreters.channel_create()
1209 with self.assertRaises(interpreters.ChannelEmptyError):
1210 interpreters.channel_recv(cid)
1211
1212 def test_run_string_arg(self):
1213 cid = interpreters.channel_create()
1214 interp = interpreters.create()
1215
1216 out = _run_output(interp, dedent("""
1217 import _xxsubinterpreters as _interpreters
1218 print(cid.end)
1219 _interpreters.channel_send(cid, b'spam')
1220 """),
1221 dict(cid=cid.send))
1222 obj = interpreters.channel_recv(cid)
1223
1224 self.assertEqual(obj, b'spam')
1225 self.assertEqual(out.strip(), 'send')
1226
1227
if __name__ == "__main__":
    # Run this module's tests when it is executed as a script.
    unittest.main()