blob: 094cc7a4597f9fe883c90549e83276cbb11ed778 [file] [log] [blame]
Antoine Pitrou557934f2009-11-06 22:41:14 +00001"""
2Various tests for synchronization primitives.
3"""
4
5import sys
6import time
Antoine Pitrou7c3e5772010-04-14 15:44:10 +00007from _thread import start_new_thread, get_ident, TIMEOUT_MAX
Antoine Pitrou557934f2009-11-06 22:41:14 +00008import threading
9import unittest
10
11from test import support
12
13
14def _wait():
15 # A crude wait/yield function not relying on synchronization primitives.
16 time.sleep(0.01)
17
class Bunch(object):
    """
    A bunch of raw (_thread) threads all running the same function.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads, each of which calls `f()` once.

        Every thread appends its identifier to `self.started` before calling
        `f`, and to `self.finished` once `f` has returned or raised.  If
        `wait_before_exit` is True, a thread keeps running after recording
        itself as finished until do_finish() is called.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit

        def task():
            ident = get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                self.finished.append(ident)
                # Linger (by polling, not via a lock) until released.
                while not self._can_exit:
                    _wait()

        for _ in range(n):
            start_new_thread(task, ())

    def _poll_until_count(self, record):
        # Poll until `record` holds at least one entry per thread.
        while len(record) < self.n:
            _wait()

    def wait_for_started(self):
        """Return once every thread has recorded itself in `started`."""
        self._poll_until_count(self.started)

    def wait_for_finished(self):
        """Return once every thread's call to `f` has completed."""
        self._poll_until_count(self.finished)

    def do_finish(self):
        """Let threads created with wait_before_exit=True terminate."""
        self._can_exit = True
55
56
class BaseTestCase(unittest.TestCase):
    """
    Common base for the synchronization-primitive test cases: wraps each
    test in test.support's thread bookkeeping and provides a lenient check
    for measured durations.
    """

    def setUp(self):
        # Taken before the test runs; handed back to threading_cleanup().
        self._threads = support.threading_setup()

    def tearDown(self):
        support.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """Check that the measured duration `actual` roughly matches `expected`.

        Waiting and time.time() can both be imprecise (especially under
        Windows), so an exact comparison would sometimes fail.  The duration
        must instead be at least 60% of `expected` and, to catch anything
        insane, strictly less than ten times `expected`.
        """
        floor = expected * 0.6
        ceiling = expected * 10.0
        self.assertGreaterEqual(actual, floor)
        self.assertLess(actual, ceiling)
72
Antoine Pitrou557934f2009-11-06 22:41:14 +000073
class BaseLockTests(BaseTestCase):
    """
    Tests for both recursive and non-recursive locks.

    Subclasses must provide a `locktype` attribute: a callable that is
    invoked with no arguments to create the lock under test.
    """

    def test_constructor(self):
        # Creating and dropping an unused lock.
        lock = self.locktype()
        del lock

    def test_acquire_destroy(self):
        # Dropping a lock that is still held.
        lock = self.locktype()
        lock.acquire()
        del lock

    def test_acquire_release(self):
        lock = self.locktype()
        lock.acquire()
        lock.release()
        del lock

    def test_try_acquire(self):
        # A non-blocking acquire of a free lock succeeds.
        lock = self.locktype()
        self.assertTrue(lock.acquire(False))
        lock.release()

    def test_try_acquire_contended(self):
        # While this thread holds the lock, a non-blocking acquire from
        # another thread must fail.
        lock = self.locktype()
        lock.acquire()
        result = []
        def f():
            result.append(lock.acquire(False))
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()

    def test_acquire_contended(self):
        # Threads blocking in acquire() only get through after the holder
        # releases the lock.
        lock = self.locktype()
        lock.acquire()
        N = 5
        def f():
            lock.acquire()
            lock.release()

        b = Bunch(f, N)
        b.wait_for_started()
        # Give the workers a moment to block in acquire(); none of them
        # should have finished yet.
        _wait()
        self.assertEqual(len(b.finished), 0)
        lock.release()
        b.wait_for_finished()
        self.assertEqual(len(b.finished), N)

    def test_with(self):
        # The lock works as a context manager and is released on exit,
        # including when the body raises.
        lock = self.locktype()
        def f():
            lock.acquire()
            lock.release()
        def _with(err=None):
            with lock:
                if err is not None:
                    raise err
        _with()
        # Check the lock is unacquired
        Bunch(f, 1).wait_for_finished()
        self.assertRaises(TypeError, _with, TypeError)
        # Check the lock is unacquired
        Bunch(f, 1).wait_for_finished()

    def test_thread_leak(self):
        # The lock shouldn't leak a Thread instance when used from a foreign
        # (non-threading) thread.
        lock = self.locktype()
        def f():
            lock.acquire()
            lock.release()
        n = len(threading.enumerate())
        # We run many threads in the hope that existing threads ids won't
        # be recycled.
        Bunch(f, 15).wait_for_finished()
        if len(threading.enumerate()) != n:
            # There is a small window during which a Thread instance's
            # target function has finished running, but the Thread is still
            # alive and registered.  Avoid spurious failures by waiting a
            # bit more (seen on a buildbot).
            time.sleep(0.4)
            self.assertEqual(n, len(threading.enumerate()))

    def test_timeout(self):
        lock = self.locktype()
        # Can't set timeout if not blocking
        self.assertRaises(ValueError, lock.acquire, 0, 1)
        # Invalid timeout values
        self.assertRaises(ValueError, lock.acquire, timeout=-100)
        self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
        self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1)
        # TIMEOUT_MAX is ok
        lock.acquire(timeout=TIMEOUT_MAX)
        lock.release()
        t1 = time.time()
        self.assertTrue(lock.acquire(timeout=5))
        t2 = time.time()
        # Just a sanity test that it didn't actually wait for the timeout.
        self.assertLess(t2 - t1, 5)
        # The lock is now held by this thread, so a timed acquire from
        # another thread must give up after roughly 0.5 seconds.
        results = []
        def f():
            t1 = time.time()
            results.append(lock.acquire(timeout=0.5))
            t2 = time.time()
            results.append(t2 - t1)
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(results[0])
        self.assertTimeout(results[1], 0.5)
185
Antoine Pitrou557934f2009-11-06 22:41:14 +0000186
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks
    (which can be acquired and released from different threads).
    """

    def test_reacquire(self):
        # A non-recursive lock blocks even its own holder on a second
        # acquire(), until somebody releases it.
        lock = self.locktype()
        steps = []

        def grab_twice():
            lock.acquire()
            steps.append(None)
            lock.acquire()
            steps.append(None)

        start_new_thread(grab_twice, ())
        while not steps:
            _wait()
        _wait()
        # The worker must still be stuck in its second acquire().
        self.assertEqual(len(steps), 1)
        lock.release()
        while len(steps) == 1:
            _wait()
        self.assertEqual(len(steps), 2)

    def test_different_thread(self):
        # A plain lock may be released by a thread that did not acquire it.
        lock = self.locktype()
        lock.acquire()

        def release_elsewhere():
            lock.release()

        Bunch(release_elsewhere, 1).wait_for_finished()
        # The foreign release freed the lock, so it can be taken again.
        lock.acquire()
        lock.release()

    def test_state_after_timeout(self):
        # Issue #11618: once an acquire() with a (non-zero) timeout has
        # timed out, the lock must still be in a proper state.
        lock = self.locktype()
        lock.acquire()
        timed_out_result = lock.acquire(timeout=0.01)
        self.assertFalse(timed_out_result)
        lock.release()
        self.assertFalse(lock.locked())
        self.assertTrue(lock.acquire(blocking=False))
231
Antoine Pitrou557934f2009-11-06 22:41:14 +0000232
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    def test_reacquire(self):
        # The owning thread may acquire the lock several times; it must
        # release it as many times.
        lock = self.locktype()
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock
        lock = self.locktype()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    def test_different_thread(self):
        # Cannot release from a different thread
        lock = self.locktype()
        def f():
            lock.acquire()
        # The helper keeps running (and so keeps owning the lock) until
        # do_finish() is called.
        b = Bunch(f, 1, True)
        try:
            # Wait until f() has returned, i.e. the helper really owns the
            # lock.  Otherwise release() could run against an unowned lock,
            # which raises RuntimeError too and would make this test pass
            # without exercising the cross-thread case at all.
            b.wait_for_finished()
            self.assertRaises(RuntimeError, lock.release)
        finally:
            b.do_finish()

    def test__is_owned(self):
        # _is_owned() is true only in the owning thread, for as long as at
        # least one acquire() is outstanding.
        lock = self.locktype()
        self.assertFalse(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        result = []
        def f():
            result.append(lock._is_owned())
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()
        self.assertTrue(lock._is_owned())
        lock.release()
        self.assertFalse(lock._is_owned())
285
286
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    Subclasses must provide an `eventtype` attribute: a callable that is
    invoked with no arguments to create the event under test.
    """

    def test_is_set(self):
        # set() and clear() are idempotent and reflected by is_set().
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # All threads get notified
        # Expects `evt` to be clear on entry; it is left set on return.
        N = 5
        results1 = []
        results2 = []
        def f():
            results1.append(evt.wait())
            # The event is still set, so this second wait() returns at once.
            results2.append(evt.wait())
        b = Bunch(f, N)
        b.wait_for_started()
        _wait()
        # Nobody may get past wait() before the event is set.
        self.assertEqual(len(results1), 0)
        evt.set()
        b.wait_for_finished()
        self.assertEqual(results1, [True] * N)
        self.assertEqual(results2, [True] * N)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Another time, after an explicit clear()
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        results1 = []
        results2 = []
        N = 5
        def f():
            # f reads results1/results2 through the enclosing scope, so it
            # sees the fresh lists bound further down.
            results1.append(evt.wait(0.0))
            t1 = time.time()
            r = evt.wait(0.5)
            t2 = time.time()
            results2.append((r, t2 - t1))
        Bunch(f, N).wait_for_finished()
        # The event is clear: every wait() times out and returns False.
        self.assertEqual(results1, [False] * N)
        for r, dt in results2:
            self.assertFalse(r)
            self.assertTimeout(dt, 0.5)
        # The event is set
        results1 = []
        results2 = []
        evt.set()
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [True] * N)
        for r, dt in results2:
            self.assertTrue(r)

    def test_set_and_clear(self):
        # Issue #13502: check that wait() returns true even when the event is
        # cleared before the waiting thread is woken up.
        evt = self.eventtype()
        results = []
        N = 5
        def f():
            results.append(evt.wait(1))
        b = Bunch(f, N)
        b.wait_for_started()
        # Give the workers time to block in wait() before the set/clear pair.
        time.sleep(0.5)
        evt.set()
        evt.clear()
        b.wait_for_finished()
        self.assertEqual(results, [True] * N)
369
Antoine Pitrou557934f2009-11-06 22:41:14 +0000370
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    Subclasses must provide a `condtype` attribute: a callable that creates
    a condition variable, either with no arguments or wrapping a given lock.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit Lock, acquiring the condition acquires that lock.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        # wait() without holding the underlying lock is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        # notify() without holding the underlying lock is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # This test is sensitive to timing.  A worker records itself in
        # `ready` while still holding the condition's lock, immediately
        # before calling wait().  A condition variable's wait() releases the
        # lock only once the caller is registered as a waiter.  So as soon as
        # the main thread manages to acquire the lock, every worker counted
        # in `ready` is really waiting.  Without this handshake, notify()
        # could run before some workers reach wait(): the notification is
        # lost and the test fails or hangs (bpo-8799, bpo-30727).
        # The test also assumes that no spurious wakeups occur.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results1.append((result, phase_num))
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results2.append((result, phase_num))
        b = Bunch(f, N)
        b.wait_for_started()
        # Make sure all workers have settled into their first wait().
        while len(ready) < 5:
            _wait()
        del ready[:]
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3)
        self.assertEqual(results2, [])
        # Make sure the 3 awakened workers have settled into their second
        # wait() before notifying again.
        while len(ready) < 3:
            _wait()
        # Notify 5 threads: 2 are in their first wait, 3 in their second
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3)
        # Make sure the last 2 workers have settled into their second wait().
        while len(ready) < 5:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < 5:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.time()
            result = cond.wait(0.5)
            t2 = time.time()
            cond.release()
            results.append((t2 - t1, result))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt, result in results:
            self.assertTimeout(dt, 0.5)
            # Note that conceptually (that's the condition variable protocol)
            # a wait() may succeed even if no one notifies us and before any
            # timeout occurs.  Spurious wakeups can occur.
            # This makes it hard to verify the result value.
            # In practice, this implementation has no spurious wakeups.
            self.assertFalse(result)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0
        success = []
        def f():
            with cond:
                result = cond.wait_for(lambda : state==4)
                self.assertTrue(result)
                self.assertEqual(state, 4)
                # Assertion failures in this raw worker thread would be lost,
                # so report success to the main thread explicitly.
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        for i in range(4):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        success = []
        def f():
            with cond:
                dt = time.time()
                result = cond.wait_for(lambda : state==4, timeout=0.1)
                dt = time.time() - dt
                self.assertFalse(result)
                self.assertTimeout(dt, 0.1)
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        # Only increment 3 times, so state == 4 is never reached.
        for i in range(3):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)
520
Antoine Pitrou557934f2009-11-06 22:41:14 +0000521
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    Subclasses must provide a `semtype` attribute: a callable that creates
    a semaphore, optionally given its initial value (positionally or as
    `value=`).
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        self.assertRaises(ValueError, self.semtype, value=-1)
        self.assertRaises(ValueError, self.semtype, value=-sys.maxsize)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        # Dropping a semaphore that has been acquired.
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        # N threads each acquire twice from a semaphore of value 7 that this
        # thread has already acquired once.  Units are handed out in phases
        # and every successful acquire records the phase it happened in.
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem.acquire()
            results1.append(phase_num)
            sem.acquire()
            results2.append(phase_num)
        # The counts awaited below (6 + 7 + 6, plus one final release) add up
        # to the 2 * N acquires made by these N threads.
        b = Bunch(f, N)
        b.wait_for_started()
        # Phase 0: only the 6 units left over from the initial 7.
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for _ in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for _ in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2),
                         [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        # 3 units remain, and 5 threads make 10 non-blocking attempts.
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3)

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        # A timeout makes no sense for a non-blocking acquire.
        self.assertRaises(ValueError, sem.acquire, False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        t = time.time()
        self.assertFalse(sem.acquire(timeout=0.5))
        dt = time.time() - t
        self.assertTimeout(dt, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        # The semaphore works as a context manager and gives its unit back
        # on exit, including when the body raises.
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
645
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # Without a bound, releasing more often than acquiring is allowed and
        # simply raises the counter: starting from 1, one extra release
        # makes two acquires in a row possible.
        semaphore = self.semtype(1)
        semaphore.release()
        for _ in range(2):
            semaphore.acquire()
        semaphore.release()
658
659
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # Releasing beyond the initial value is an error.
        bounded = self.semtype()
        with self.assertRaises(ValueError):
            bounded.release()
        bounded.acquire()
        bounded.release()
        with self.assertRaises(ValueError):
            bounded.release()
Kristján Valur Jónsson3be00032010-10-28 09:43:10 +0000672
673
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects.

    Subclasses must provide a `barriertype` attribute: a callable that
    creates a barrier from a number of parties, an optional action and an
    optional `timeout=`.
    """
    N = 5
    defaultTimeout = 2.0

    def setUp(self):
        # Chain up so that BaseTestCase's thread bookkeeping is not skipped.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Abort the barrier so no thread stays blocked in it, then let
        # BaseTestCase check for leftover threads.
        self.barrier.abort()
        super().tearDown()

    def run_threads(self, f):
        """Run `f` in self.N threads at once: N-1 helpers plus this thread.

        The helpers are raw _thread threads, so an exception raised inside
        them (including a failed assertion) would only be printed and the
        test would still pass.  Such exceptions are collected instead, and
        the first one is re-raised here once every helper is done.
        """
        errors = []
        def helper():
            try:
                f()
            except Exception as e:
                errors.append(e)
        b = Bunch(helper, self.N-1)
        f()
        b.wait_for_finished()
        if errors:
            raise errors[0]

    def multipass(self, results, n):
        # Pass the shared barrier `n` times in lockstep.  `results` is a pair
        # of lists; the length checks prove that no thread runs ahead of the
        # others between two waits.
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[],[]]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        self.test_barrier(10)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        # Each party gets a distinct index in range(N).
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = []
        def action():
            results.append(True)
        barrier = self.barriertype(self.N, action)
        def f():
            barrier.wait()
            # The action has run exactly once by the time anyone is released.
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                while self.barrier.n_waiting < self.N-1:
                    time.sleep(0.001)
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        barrier2 = self.barriertype(self.N)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(1.0)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, 0.5)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.3)
        def f():
            i = barrier.wait()
            if i == self.N // 2:
                # One thread is later than the default timeout of 0.3s.
                time.sleep(1.0)
            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
        self.run_threads(f)

    def test_single_thread(self):
        # A one-party barrier never blocks and can be reused.
        b = self.barriertype(1)
        b.wait()
        b.wait()