blob: f9c2259914c2002757c37f93859ab166dd1886a5 [file] [log] [blame]
"""
Various tests for synchronization primitives (locks, recursive locks,
events, condition variables and semaphores).
"""
4
5import sys
6import time
7from _thread import start_new_thread, get_ident
8import threading
9import unittest
10
11from test import support
12
13
14def _wait():
15 # A crude wait/yield function not relying on synchronization primitives.
16 time.sleep(0.01)
17
class Bunch(object):
    """
    A group of threads that all run the same function.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads, each calling `f()` once.

        Each thread appends its identifier to `self.started` before calling
        `f`, and to `self.finished` once `f` has returned or raised.  If
        `wait_before_exit` is true, the threads then keep polling until
        do_finish() is called instead of terminating right away.
        """
        self.f = f
        self.n = n
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit

        def task():
            ident = get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                self.finished.append(ident)
                # Linger (polling with short sleeps) until released.
                while not self._can_exit:
                    _wait()

        for _ in range(n):
            start_new_thread(task, ())

    def _block_until(self, done):
        # Poll with _wait() rather than a synchronization primitive, since
        # the primitives themselves are what the tests exercise.
        while not done():
            _wait()

    def wait_for_started(self):
        """Return once all `n` threads have entered their task."""
        self._block_until(lambda: len(self.started) >= self.n)

    def wait_for_finished(self):
        """Return once all `n` threads have finished calling `f`."""
        self._block_until(lambda: len(self.finished) >= self.n)

    def do_finish(self):
        """Let threads created with `wait_before_exit=True` terminate."""
        self._can_exit = True
55
56
class BaseTestCase(unittest.TestCase):
    """
    Base fixture for the synchronization tests in this module.

    Before each test, the value returned by support.threading_setup() is
    recorded; after the test it is handed back to
    support.threading_cleanup(), and support.reap_children() is called.
    """

    def setUp(self):
        saved_state = support.threading_setup()
        self._threads = saved_state

    def tearDown(self):
        saved_state = self._threads
        support.threading_cleanup(*saved_state)
        support.reap_children()
64
65
class BaseLockTests(BaseTestCase):
    """
    Tests shared by recursive and non-recursive locks.

    Concrete subclasses are expected to set `locktype`, a factory that
    returns a fresh lock.
    """

    def test_constructor(self):
        instance = self.locktype()
        del instance

    def test_acquire_destroy(self):
        # Destroying a lock while it is held must not crash.
        instance = self.locktype()
        instance.acquire()
        del instance

    def test_acquire_release(self):
        instance = self.locktype()
        instance.acquire()
        instance.release()
        del instance

    def test_try_acquire(self):
        instance = self.locktype()
        self.assertTrue(instance.acquire(False))
        instance.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while we hold it.
        instance = self.locktype()
        instance.acquire()
        outcomes = []

        def try_grab():
            outcomes.append(instance.acquire(False))

        Bunch(try_grab, 1).wait_for_finished()
        self.assertFalse(outcomes[0])
        instance.release()

    def test_acquire_contended(self):
        # Threads blocking on a held lock must not finish before it is
        # released, and must all finish afterwards.
        mutex = self.locktype()
        mutex.acquire()
        workers = 5

        def grab_and_drop():
            mutex.acquire()
            mutex.release()

        bunch = Bunch(grab_and_drop, workers)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(bunch.finished), 0)
        mutex.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), workers)

    def test_with(self):
        mutex = self.locktype()

        def grab_and_drop():
            mutex.acquire()
            mutex.release()

        def assert_unlocked():
            # A helper thread can only complete if the lock is free.
            Bunch(grab_and_drop, 1).wait_for_finished()

        def use_with(err=None):
            with mutex:
                if err is not None:
                    raise err

        use_with()
        assert_unlocked()
        # The lock must also be released when the body raises.
        self.assertRaises(TypeError, use_with, TypeError)
        assert_unlocked()
132
133
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks
    (which can be acquired and released from different threads).
    """
    def test_reacquire(self):
        # A second acquire() by the same thread blocks until someone else
        # releases the lock.
        mutex = self.locktype()
        checkpoints = []

        def double_acquire():
            mutex.acquire()
            checkpoints.append(None)
            mutex.acquire()
            checkpoints.append(None)

        start_new_thread(double_acquire, ())
        while not checkpoints:
            _wait()
        _wait()
        self.assertEqual(len(checkpoints), 1)
        mutex.release()
        while len(checkpoints) == 1:
            _wait()
        self.assertEqual(len(checkpoints), 2)

    def test_different_thread(self):
        # A lock held by this thread may be released by another thread.
        mutex = self.locktype()
        mutex.acquire()

        def releaser():
            mutex.release()

        Bunch(releaser, 1).wait_for_finished()
        mutex.acquire()
        mutex.release()
168
169
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    def test_reacquire(self):
        # The owning thread may acquire the lock several times; it must
        # release it the same number of times.
        lock = self.locktype()
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock
        lock = self.locktype()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.acquire()
        lock.release()
        lock.acquire()
        lock.release()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    def test_different_thread(self):
        # Cannot release from a different thread
        lock = self.locktype()
        def f():
            lock.acquire()
        b = Bunch(f, 1, True)
        try:
            # Wait until the worker really owns the lock.  Otherwise
            # release() could raise merely because the lock is still
            # unacquired, and the test would pass without checking
            # ownership at all.
            b.wait_for_finished()
            self.assertRaises(RuntimeError, lock.release)
        finally:
            b.do_finish()

    def test__is_owned(self):
        # _is_owned() is true only in the thread holding the lock, for as
        # long as at least one recursive acquisition is outstanding.
        lock = self.locktype()
        self.assertFalse(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        lock.acquire()
        self.assertTrue(lock._is_owned())
        result = []
        def f():
            result.append(lock._is_owned())
        Bunch(f, 1).wait_for_finished()
        self.assertFalse(result[0])
        lock.release()
        self.assertTrue(lock._is_owned())
        lock.release()
        self.assertFalse(lock._is_owned())
222
223
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    Concrete subclasses are expected to set `eventtype`, a factory that
    returns a fresh event.
    """

    def test_is_set(self):
        # set() and clear() are idempotent.
        evt = self.eventtype()
        self.assertFalse(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.set()
        self.assertTrue(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())
        evt.clear()
        self.assertFalse(evt.is_set())

    def _check_notify(self, evt):
        # All threads get notified
        N = 5
        results1 = []
        results2 = []
        def f():
            results1.append(evt.wait())
            results2.append(evt.wait())
        b = Bunch(f, N)
        b.wait_for_started()
        _wait()
        # Nobody may get past wait() before the event is set.
        self.assertEqual(len(results1), 0)
        evt.set()
        b.wait_for_finished()
        self.assertEqual(results1, [True] * N)
        self.assertEqual(results2, [True] * N)

    def test_notify(self):
        evt = self.eventtype()
        self._check_notify(evt)
        # Another time, after an explicit clear()
        evt.set()
        evt.clear()
        self._check_notify(evt)

    def test_timeout(self):
        evt = self.eventtype()
        results1 = []
        results2 = []
        N = 5
        def f():
            results1.append(evt.wait(0.0))
            t1 = time.time()
            r = evt.wait(0.2)
            t2 = time.time()
            results2.append((r, t2 - t1))
        # While the event is clear, waits time out and return False, and the
        # timed wait lasts at least its timeout.
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [False] * N)
        # Guard against the loop below passing vacuously if a worker died
        # before recording its result.
        self.assertEqual(len(results2), N)
        for r, dt in results2:
            self.assertFalse(r)
            self.assertTrue(dt >= 0.2, dt)
        # The event is set: every wait returns True.
        # (f() closes over these names, so it sees the rebound lists.)
        results1 = []
        results2 = []
        evt.set()
        Bunch(f, N).wait_for_finished()
        self.assertEqual(results1, [True] * N)
        self.assertEqual(len(results2), N)
        for r, _ in results2:
            self.assertTrue(r)
290
291
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    Concrete subclasses are expected to set `condtype`, a factory that
    accepts an optional underlying lock.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit Lock, acquiring the condition acquires that lock.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # Each of the N threads waits twice and records, after each wakeup,
        # the phase number current at that time.
        N = 5
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            cond.acquire()
            cond.wait()
            cond.release()
            results1.append(phase_num)
            cond.acquire()
            cond.wait()
            cond.release()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        _wait()
        self.assertEqual(results1, [])
        # Notify 3 threads at first.  phase_num is updated while cond is
        # still held, so notified threads (which must re-acquire cond
        # before wait() returns) record the new phase.
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [1] * 3)
        self.assertEqual(results2, [])
        # Notify 5 threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3)
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < 5:
            _wait()
        self.assertEqual(results1, [1] * 3 + [2] * 2)
        self.assertEqual(results2, [2] * 3 + [3] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        # With nobody notifying, every timed wait lasts at least its timeout.
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.time()
            cond.wait(0.2)
            t2 = time.time()
            cond.release()
            results.append(t2 - t1)
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt in results:
            self.assertTrue(dt >= 0.2, dt)
395
396
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    Concrete subclasses are expected to set `semtype`, a factory taking an
    optional initial value.
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        self.assertRaises(ValueError, self.semtype, value=-1)
        self.assertRaises(ValueError, self.semtype, value=-sys.maxsize)

    def test_acquire(self):
        sem = self.semtype(1)
        sem.acquire()
        sem.release()
        sem = self.semtype(2)
        sem.acquire()
        sem.acquire()
        sem.release()
        sem.release()

    def test_acquire_destroy(self):
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        sem = self.semtype(7)
        sem.acquire()
        # 6 slots remain; N threads each try to acquire twice (2 * N = 20
        # acquisitions in total), recording the phase in which they got in.
        N = 10
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            sem.acquire()
            results1.append(phase_num)
            sem.acquire()
            results2.append(phase_num)
        b = Bunch(f, N)
        b.wait_for_started()
        while len(results1) + len(results2) < 6:
            _wait()
        self.assertEqual(results1 + results2, [0] * 6)
        phase_num = 1
        for _ in range(7):
            sem.release()
        while len(results1) + len(results2) < 13:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
        phase_num = 2
        for _ in range(6):
            sem.release()
        while len(results1) + len(results2) < 19:
            _wait()
        self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        b.wait_for_finished()

    def test_try_acquire(self):
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        sem = self.semtype(4)
        sem.acquire()
        results = []
        def f():
            results.append(sem.acquire(False))
            results.append(sem.acquire(False))
        Bunch(f, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(results), [False] * 7 + [True] * 3)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()
        def f():
            sem.acquire()
            sem.release()
        b = Bunch(f, 1)
        b.wait_for_started()
        _wait()
        self.assertFalse(b.finished)
        sem.release()
        b.wait_for_finished()

    def test_with(self):
        sem = self.semtype(2)
        def _with(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err is not None:
                        raise err
        _with()
        self.assertTrue(sem.acquire(False))
        sem.release()
        # Both `with` blocks must release even when the body raises.
        self.assertRaises(TypeError, _with, TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
507
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to unbounded semaphores.
    """

    def test_release_unacquired(self):
        # An extra release() is allowed and raises the value (1 -> 2), so two
        # subsequent acquires succeed.
        counter = self.semtype(1)
        counter.release()
        for _ in range(2):
            counter.acquire()
        counter.release()
520
521
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests specific to bounded semaphores.
    """

    def test_release_unacquired(self):
        # Releasing beyond the initial value raises ValueError.
        bounded = self.semtype()
        with self.assertRaises(ValueError):
            bounded.release()
        bounded.acquire()
        bounded.release()
        with self.assertRaises(ValueError):
            bounded.release()