blob: b6d818e4d4c3a9de68f2e0251866a54b8991c9f9 [file] [log] [blame]
Antoine Pitrou557934f2009-11-06 22:41:14 +00001"""
2Various tests for synchronization primitives.
3"""
4
5import sys
6import time
Antoine Pitrou7c3e5772010-04-14 15:44:10 +00007from _thread import start_new_thread, get_ident, TIMEOUT_MAX
Antoine Pitrou557934f2009-11-06 22:41:14 +00008import threading
9import unittest
10
11from test import support
12
13
14def _wait():
15 # A crude wait/yield function not relying on synchronization primitives.
16 time.sleep(0.01)
17
class Bunch(object):
    """
    A group of `n` raw (_thread-level) threads that all run the same callable.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads, each calling `f()` once.

        Each thread records its identifier in `started` as soon as it runs and
        in `finished` once `f()` has returned or raised. If `wait_before_exit`
        is True, every thread then lingers until do_finish() is called.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit

        def task():
            ident = get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                self.finished.append(ident)
                # Park here until the owner allows the thread to exit.
                while not self._can_exit:
                    _wait()

        for _ in range(n):
            start_new_thread(task, ())

    def _wait_for_all(self, records):
        # Poll until `records` holds one entry per thread.
        while len(records) < self.n:
            _wait()

    def wait_for_started(self):
        """Block until every thread has begun running."""
        self._wait_for_all(self.started)

    def wait_for_finished(self):
        """Block until every thread's call to `f` has completed."""
        self._wait_for_all(self.finished)

    def do_finish(self):
        """Let threads created with wait_before_exit=True terminate."""
        self._can_exit = True
55
56
class BaseTestCase(unittest.TestCase):
    """
    Shared fixture: wraps every test with support.threading_setup() /
    support.threading_cleanup() and calls support.reap_children() afterwards.
    """

    def setUp(self):
        # Snapshot taken here is handed back to threading_cleanup() in tearDown.
        self._threads = support.threading_setup()

    def tearDown(self):
        support.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """Check that a measured duration `actual` fits a timeout `expected`.

        Waiting and time.time() can both be imprecise (especially under
        Windows), so the lower bound is relaxed to 60% of `expected`. The
        upper bound of ten times `expected` only catches something going
        badly wrong.
        """
        lower_bound = expected * 0.6
        upper_bound = expected * 10.0
        self.assertGreaterEqual(actual, lower_bound)
        self.assertLess(actual, upper_bound)
72
Antoine Pitrou557934f2009-11-06 22:41:14 +000073
class BaseLockTests(BaseTestCase):
    """
    Tests for both recursive and non-recursive locks.

    Concrete subclasses (not visible here) must provide ``locktype``, a
    zero-argument callable returning the lock object under test.
    """

    def test_constructor(self):
        # Creating and dropping an unused lock must not fail.
        lock = self.locktype()
        del lock

    def test_acquire_destroy(self):
        # Dropping a lock while it is still held must not fail.
        lock = self.locktype()
        lock.acquire()
        del lock

    def test_acquire_release(self):
        lock = self.locktype()
        lock.acquire()
        lock.release()
        del lock

    def test_try_acquire(self):
        # A non-blocking acquire of a free lock succeeds.
        lock = self.locktype()
        self.assertTrue(lock.acquire(False))
        lock.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while the main
        # thread holds the lock.
        lock = self.locktype()
        lock.acquire()
        result = []
        def f():
            result.append(lock.acquire(False))
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()

    def test_acquire_contended(self):
        # N workers block on the held lock and all complete once it is
        # released.
        lock = self.locktype()
        lock.acquire()
        N = 5
        def f():
            lock.acquire()
            lock.release()

        b = Bunch(f, N)
        b.wait_for_started()
        # Give the workers a moment to block; none may have finished yet.
        _wait()
        self.assertEqual(len(b.finished), 0)
        lock.release()
        b.wait_for_finished()
        self.assertEqual(len(b.finished), N)

    def test_with(self):
        # The context-manager protocol releases the lock on normal exit and
        # when an exception escapes the body.
        lock = self.locktype()
        def f():
            lock.acquire()
            lock.release()
        def _with(err=None):
            with lock:
                if err is not None:
                    raise err
        _with()
        # Check the lock is unacquired (otherwise f() would block forever)
        Bunch(f, 1).wait_for_finished()
        self.assertRaises(TypeError, _with, TypeError)
        # Check the lock is unacquired
        Bunch(f, 1).wait_for_finished()

    def test_thread_leak(self):
        # The lock shouldn't leak a Thread instance when used from a foreign
        # (non-threading) thread.
        lock = self.locktype()
        def f():
            lock.acquire()
            lock.release()
        n = len(threading.enumerate())
        # We run many threads in the hope that existing threads ids won't
        # be recycled.
        Bunch(f, 15).wait_for_finished()
        self.assertEqual(n, len(threading.enumerate()))

    def test_timeout(self):
        lock = self.locktype()
        # Can't set timeout if not blocking
        self.assertRaises(ValueError, lock.acquire, 0, 1)
        # Invalid timeout values
        self.assertRaises(ValueError, lock.acquire, timeout=-100)
        self.assertRaises(OverflowError, lock.acquire, timeout=1e100)
        self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1)
        # TIMEOUT_MAX is ok
        lock.acquire(timeout=TIMEOUT_MAX)
        lock.release()
        t1 = time.time()
        self.assertTrue(lock.acquire(timeout=5))
        t2 = time.time()
        # Just a sanity test that it didn't actually wait for the timeout.
        self.assertLess(t2 - t1, 5)
        # The lock is now held by the main thread, so a timed acquire from a
        # worker must fail after roughly the requested timeout.
        results = []
        def f():
            t1 = time.time()
            results.append(lock.acquire(timeout=0.5))
            t2 = time.time()
            results.append(t2 - t1)
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(results[0])
        self.assertTimeout(results[1], 0.5)
179
Antoine Pitrou557934f2009-11-06 22:41:14 +0000180
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks
    (which can be acquired and released from different threads).
    """
    def test_reacquire(self):
        # Lock needs to be released before re-acquiring.
        lock = self.locktype()
        phase = []
        def f():
            lock.acquire()
            phase.append(None)
            # A second acquire() by the same thread must block until some
            # other thread releases the lock.
            lock.acquire()
            phase.append(None)
        start_new_thread(f, ())
        # Wait for the worker's first acquire() to succeed...
        while len(phase) == 0:
            _wait()
        # ...then give it time to get stuck in its second acquire().
        _wait()
        self.assertEqual(len(phase), 1)
        # Releasing from the main thread unblocks the worker.
        lock.release()
        while len(phase) == 1:
            _wait()
        self.assertEqual(len(phase), 2)

    def test_different_thread(self):
        # Lock can be released from a different thread.
        lock = self.locktype()
        lock.acquire()
        def f():
            lock.release()
        b = Bunch(f, 1)
        b.wait_for_finished()
        # The worker's release() freed the lock, so this acquire() returns
        # instead of blocking forever.
        lock.acquire()
        lock.release()
215
216
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    def test_reacquire(self):
        # The owning thread may acquire the lock repeatedly; it stays held
        # until every acquire() is matched by a release().
        lock = self.locktype()
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock
        lock = self.locktype()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    def test_different_thread(self):
        # Cannot release from a different thread
        lock = self.locktype()
        def f():
            lock.acquire()
        b = Bunch(f, 1, True)
        try:
            # Bunch appends to `finished` once f() has returned, so after this
            # the worker really owns the lock (it keeps running until
            # do_finish()). Without this wait, release() could run before the
            # worker acquired the lock, and the RuntimeError would only prove
            # that an unacquired lock can't be released.
            b.wait_for_finished()
            self.assertRaises(RuntimeError, lock.release)
        finally:
            b.do_finish()

    def test__is_owned(self):
        # _is_owned() is true only in the owning thread, and only while at
        # least one acquire() is outstanding.
        lock = self.locktype()
        self.assertFalse(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        result = []
        def f():
            result.append(lock._is_owned())
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()
        self.assertTrue(lock._is_owned())
        lock.release()
        self.assertFalse(lock._is_owned())
269
270
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    Concrete subclasses (not visible here) must provide ``eventtype``, a
    zero-argument callable returning the event object under test.
    """

    def test_is_set(self):
        # set() and clear() are idempotent and reflected by is_set().
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # All threads get notified
        # Expects `evt` to be clear on entry. After set() the event stays set,
        # so each worker's second wait() returns immediately.
        N = 5
        results1 = []
        results2 = []
        def f():
            results1.append(evt.wait())
            results2.append(evt.wait())
        b = Bunch(f, N)
        b.wait_for_started()
        _wait()
        self.assertEqual(len(results1), 0)
        evt.set()
        b.wait_for_finished()
        self.assertEqual(results1, [True] * N)
        self.assertEqual(results2, [True] * N)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Another time, after an explicit clear()
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        results1 = []
        results2 = []
        N = 5
        def f():
            # wait(0.0) polls without blocking; wait(0.5) measures a real
            # timed wait.
            results1.append(evt.wait(0.0))
            t1 = time.time()
            r = evt.wait(0.5)
            t2 = time.time()
            results2.append((r, t2 - t1))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [False] * N)
        for r, dt in results2:
            self.assertFalse(r)
            self.assertTimeout(dt, 0.5)
        # The event is set
        results1 = []
        results2 = []
        evt.set()
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [True] * N)
        for r, dt in results2:
            self.assertTrue(r)
337
338
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    Concrete subclasses (not visible here) must provide ``condtype``, a
    callable returning the condition under test. It is called with no
    arguments or with an explicit lock (see test_acquire).
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # This test is sensitive to timing: the main thread must not notify
        # before the workers are really waiting, or the wakeup is lost. Each
        # worker appends to `ready` while still holding `cond`, right before
        # calling wait(). This relies on wait() registering the waiter before
        # it releases the lock. Given that, once the main thread has seen a
        # worker's `ready` entry and then acquired `cond`, that worker is
        # inside wait() and a subsequent notify() reaches it.
        # The test also assumes there are no spurious wakeups. That holds for
        # the current implementation but is not guaranteed for condition
        # variables in general.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results1.append((result, phase_num))
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results2.append((result, phase_num))
        b = Bunch(f, N)
        b.wait_for_started()
        # Make sure every worker is blocked in its first wait().
        while len(ready) < N:
            _wait()
        del ready[:]
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3)
        self.assertEqual(results2, [])
        # Make sure the 3 awakened workers are blocked in their second wait().
        while len(ready) < 3:
            _wait()
        # Notify 5 threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3)
        # Make sure every worker is blocked in its second wait().
        while len(ready) < N:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < 5:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.time()
            result = cond.wait(0.5)
            t2 = time.time()
            cond.release()
            results.append((t2 - t1, result))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt, result in results:
            self.assertTimeout(dt, 0.5)
            # Note that conceptually (that's the condition variable protocol)
            # a wait() may succeed even if no one notifies us and before any
            # timeout occurs. Spurious wakeups can occur.
            # This makes it hard to verify the result value.
            # In practice, this implementation has no spurious wakeups.
            self.assertFalse(result)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0
        results = []
        def f():
            with cond:
                result = cond.wait_for(lambda : state==4)
                # Record the outcome for the main thread. Bunch does not
                # propagate exceptions from its threads, so assertions made
                # here could never fail the test.
                results.append((result, state))
        b = Bunch(f, 1)
        b.wait_for_started()
        # Stop at exactly 4 so the predicate stays true once reached. Going
        # further could move `state` past 4 before the worker re-checks it,
        # leaving wait_for() (which has no timeout) blocked forever.
        for i in range(4):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(results, [(True, 4)])

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        success = []
        def f():
            with cond:
                dt = time.time()
                result = cond.wait_for(lambda : state==4, timeout=0.1)
                dt = time.time() - dt
                self.assertFalse(result)
                self.assertTimeout(dt, 0.1)
                # Only reached if the assertions above passed; checked in
                # the main thread below.
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        # Only increment 3 times, so state == 4 is never reached.
        for i in range(3):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)
488
Antoine Pitrou557934f2009-11-06 22:41:14 +0000489
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    Concrete subclasses (not visible here) must provide ``semtype``, called
    with an optional initial counter value (positionally or as ``value=``).
    """

    def test_constructor(self):
        # A negative initial value is rejected.
        self.assertRaises(ValueError, self.semtype, value = -1)
        self.assertRaises(ValueError, self.semtype, value = -sys.maxsize)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        # The main thread holds 1 of 7 slots, leaving 6 free. N workers each
        # acquire twice, so the releases below let them through in three
        # phases, and `phase_num` records which phase each acquire fell in.
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        sem_results = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem_results.append(sem.acquire())
            results1.append(phase_num)
            sem_results.append(sem.acquire())
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for i in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for i in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()
        # Every blocking acquire() must have reported success.
        self.assertEqual(sem_results, [True] * (6 + 7 + 6 + 1))

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        # 3 free slots shared by 5 workers making 2 non-blocking attempts
        # each: exactly 3 of the 10 attempts succeed.
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3 )

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        # A timeout is not allowed for a non-blocking acquire.
        self.assertRaises(ValueError, sem.acquire, False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        t = time.time()
        self.assertFalse(sem.acquire(timeout=0.5))
        dt = time.time() - t
        self.assertTimeout(dt, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        # The context-manager protocol releases the slot both on normal exit
        # and when an exception escapes the body.
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
613
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # An unbounded semaphore accepts a release() beyond its initial value.
        # The extra release raises the counter, so two blocking acquires
        # succeed afterwards.
        s = self.semtype(1)
        s.release()
        for _ in range(2):
            s.acquire()
        s.release()
626
627
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # A release() that would push the counter above its initial value
        # raises ValueError.
        s = self.semtype()
        with self.assertRaises(ValueError):
            s.release()
        s.acquire()
        s.release()
        with self.assertRaises(ValueError):
            s.release()
Kristján Valur Jónsson3be00032010-10-28 09:43:10 +0000640
641
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects.

    Concrete subclasses (not visible here) must provide ``barriertype``. It
    is called with the number of parties, optionally an action callback as
    the second positional argument, and an optional ``timeout`` keyword.
    """
    N = 5
    defaultTimeout = 2.0

    def setUp(self):
        # Chain to BaseTestCase so its per-test thread bookkeeping also
        # applies to barrier tests.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Abort first so no thread stays blocked on the shared barrier (as
        # test_abort shows, waiters then get BrokenBarrierError). Then run
        # BaseTestCase's cleanup.
        try:
            self.barrier.abort()
        finally:
            super().tearDown()

    def run_threads(self, f):
        # Run `f` in N-1 worker threads plus the main thread, so exactly N
        # parties reach the barrier.
        b = Bunch(f, self.N-1)
        f()
        b.wait_for_finished()

    def multipass(self, results, n):
        # Cross the barrier 2*n times, checking at each step that all
        # parties advanced together.
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[],[]]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        self.test_barrier(10)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        # Each party gets a distinct index in range(N).
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = []
        def action():
            results.append(True)
        barrier = self.barriertype(self.N, action)
        def f():
            barrier.wait()
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                while self.barrier.n_waiting < self.N-1:
                    time.sleep(0.001)
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)


    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        barrier2 = self.barriertype(self.N)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(1.0)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, 0.5)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.1)
        def f():
            i = barrier.wait()
            if i == self.N // 2:
                # One thread is later than the default timeout of 0.1s.
                time.sleep(1.0)
            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
        self.run_threads(f)

    def test_single_thread(self):
        # A one-party barrier never blocks and can be reused.
        b = self.barriertype(1)
        b.wait()
        b.wait()
832 b.wait()