Alex Deymo | e5e21f7 | 2013-07-15 16:57:48 -0700 | [diff] [blame] | 1 | # Copyright (c) 2013 The Chromium OS Authors. All rights reserved. |
| 2 | # Use of this source code is governed by a BSD-style license that can be |
| 3 | # found in the LICENSE file. |
| 4 | |
| 5 | from multiprocessing import Queue, queues |
| 6 | |
| 7 | |
class QueueBarrierTimeout(Exception):
    """Raised when a QueueBarrier wait does not finish within its timeout."""
    pass
| 10 | |
| 11 | |
class QueueBarrier(object):
    """This class implements a simple barrier to synchronize processes. The
    barrier relies on the fact that there is a single "master" process and |n|
    different "slave" processes, which keeps the implementation simple. Also,
    given this hierarchy, the slaves and the master can exchange a token while
    passing through the barrier.

    The so called "master" shall call master_barrier() while each "slave" shall
    call the slave_barrier() method.

    If the same group of |n| slaves and the same master are participating in
    the barrier, it is safe to reuse the barrier several times with the same
    group of processes, as long as no call has raised QueueBarrierTimeout (see
    the method docstrings for why a timeout leaves the queues out of step).
    """

    def __init__(self, n):
        """Initializes the barrier with |n| slave processes and a master.

        @param n: The number of slave processes.
        """
        self.n_ = n
        # Slaves -> master: each slave puts one token here when it arrives.
        self.queue_master_ = Queue()
        # Master -> slaves: the master puts one token per slave to release them.
        self.queue_slave_ = Queue()

    def master_barrier(self, token=None, timeout=None):
        """Makes the master wait until all the "n" slaves have reached this
        point, then releases all of them.

        @param token: A value passed to every slave.
        @param timeout: The overall timeout, in seconds, to wait for all the
            slaves to arrive. A None value will block forever.

        @return: The list of tokens received from the slaves, in the order they
            were read from the queue.
        @raise QueueBarrierTimeout: if not every slave arrived before the
            timeout expired. Tokens already collected by this call are
            discarded and those slaves are not released, so the barrier should
            not be reused by the same group after a timeout.
        """
        import time
        # Prefer a monotonic clock (Python 3.3+) so wall-clock adjustments do
        # not stretch or shrink the wait; fall back to time.time() otherwise.
        now = getattr(time, 'monotonic', time.time)
        # A single deadline makes |timeout| bound the whole wait. Passing the
        # full timeout to every get() could block for up to n * timeout.
        deadline = None if timeout is None else now() + timeout

        # Wait for all the slaves.
        result = []
        try:
            for _ in range(self.n_):
                if deadline is None:
                    remaining = None
                else:
                    # Clamp at 0: a non-blocking get() still picks up tokens
                    # that are already queued.
                    remaining = max(0, deadline - now())
                result.append(self.queue_master_.get(timeout=remaining))
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()

        # Release all the blocked slaves.
        for _ in range(self.n_):
            self.queue_slave_.put(token)
        return result

    def slave_barrier(self, token=None, timeout=None):
        """Makes a slave wait until all the "n" slaves and the master have
        reached this point.

        @param token: A value passed to the master.
        @param timeout: The timeout, in seconds, to wait for the master to
            release this slave. A None value will block forever.

        @return: The token the master passed to master_barrier().
        @raise QueueBarrierTimeout: if the master did not release this slave
            before the timeout expired. This slave's token has already been
            sent to the master at that point, so a later master_barrier() call
            may still read it.
        """
        # Announce arrival first. The master only releases slaves after it has
        # collected one token per slave.
        self.queue_master_.put(token)
        try:
            return self.queue_slave_.get(timeout=timeout)
        except queues.Empty:
            # Timeout expired
            raise QueueBarrierTimeout()