Merged revisions 77794 via svnmerge from
svn+ssh://pythondev@svn.python.org/python/trunk

........
  r77794 | jesse.noller | 2010-01-26 22:05:57 -0500 (Tue, 26 Jan 2010) | 1 line

  Issue #6963: Added maxtasksperchild argument to multiprocessing.Pool
........
diff --git a/Lib/multiprocessing/__init__.py b/Lib/multiprocessing/__init__.py
index 5a13742..e4af68b 100644
--- a/Lib/multiprocessing/__init__.py
+++ b/Lib/multiprocessing/__init__.py
@@ -218,12 +218,12 @@
     from multiprocessing.queues import JoinableQueue
     return JoinableQueue(maxsize)
 
-def Pool(processes=None, initializer=None, initargs=()):
+def Pool(processes=None, initializer=None, initargs=(), maxtasksperchild=None):
     '''
     Returns a process pool object
     '''
     from multiprocessing.pool import Pool
-    return Pool(processes, initializer, initargs)
+    return Pool(processes, initializer, initargs, maxtasksperchild)  # positional: order must match pool.Pool.__init__
 
 def RawValue(typecode_or_type, *args):
     '''
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index d3ecc9b..6271b86 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -42,7 +42,8 @@
 # Code run by worker processes
 #
 
-def worker(inqueue, outqueue, initializer=None, initargs=()):
+def worker(inqueue, outqueue, initializer=None, initargs=(), maxtasks=None):
+    assert maxtasks is None or (type(maxtasks) == int and maxtasks > 0)
     put = outqueue.put
     get = inqueue.get
     if hasattr(inqueue, '_writer'):
@@ -52,7 +53,8 @@
     if initializer is not None:
         initializer(*initargs)
 
-    while 1:
+    completed = 0
+    while maxtasks is None or (maxtasks and completed < maxtasks):
         try:
             task = get()
         except (EOFError, IOError):
@@ -69,6 +71,8 @@
         except Exception as e:
             result = (False, e)
         put((job, i, result))
+        completed += 1
+    debug('worker exiting after %d tasks' % completed)
 
 #
 # Class representing a process pool
@@ -80,11 +84,15 @@
     '''
     Process = Process
 
-    def __init__(self, processes=None, initializer=None, initargs=()):
+    def __init__(self, processes=None, initializer=None, initargs=(),
+                 maxtasksperchild=None):
         self._setup_queues()
         self._taskqueue = queue.Queue()
         self._cache = {}
         self._state = RUN
+        self._maxtasksperchild = maxtasksperchild
+        self._initializer = initializer
+        self._initargs = initargs
 
         if processes is None:
             try:
@@ -95,16 +103,18 @@
         if initializer is not None and not hasattr(initializer, '__call__'):
             raise TypeError('initializer must be a callable')
 
+        self._processes = processes
         self._pool = []
-        for i in range(processes):
-            w = self.Process(
-                target=worker,
-                args=(self._inqueue, self._outqueue, initializer, initargs)
-                )
-            self._pool.append(w)
-            w.name = w.name.replace('Process', 'PoolWorker')
-            w.daemon = True
-            w.start()
+        self._repopulate_pool()
+
+        self._worker_handler = threading.Thread(
+            target=Pool._handle_workers,
+            args=(self, )
+            )
+        self._worker_handler.daemon = True
+        self._worker_handler._state = RUN
+        self._worker_handler.start()
+
 
         self._task_handler = threading.Thread(
             target=Pool._handle_tasks,
@@ -125,10 +135,48 @@
         self._terminate = Finalize(
             self, self._terminate_pool,
             args=(self._taskqueue, self._inqueue, self._outqueue, self._pool,
-                  self._task_handler, self._result_handler, self._cache),
+                  self._worker_handler, self._task_handler,
+                  self._result_handler, self._cache),
             exitpriority=15
             )
 
+    def _join_exited_workers(self):
+        """Join and drop every worker in self._pool whose exitcode is set.
+        Any exit counts (maxtasksperchild expiry or otherwise).  Returns True
+        if at least one worker was removed, False otherwise."""
+        cleaned = False
+        for i in reversed(range(len(self._pool))):
+            worker = self._pool[i]
+            if worker.exitcode is not None:
+                # exited: reap it; walking backwards keeps `del` indices valid
+                debug('cleaning up worker %d' % i)
+                worker.join()
+                cleaned = True
+                del self._pool[i]  # in place: _terminate_pool holds this same list
+        return cleaned
+
+    def _repopulate_pool(self):
+        """Start workers until len(self._pool) reaches self._processes; each
+        runs worker() with the stored initializer, initargs, maxtasksperchild.
+        Used at construction and after _join_exited_workers() reaps workers."""
+        for i in range(self._processes - len(self._pool)):  # empty range if already full
+            w = self.Process(target=worker,
+                             args=(self._inqueue, self._outqueue,
+                                   self._initializer,
+                                   self._initargs, self._maxtasksperchild)
+                            )
+            self._pool.append(w)
+            w.name = w.name.replace('Process', 'PoolWorker')
+            w.daemon = True
+            w.start()
+            debug('added worker')
+
+    def _maintain_pool(self):
+        """Reap exited workers; start replacements only if any were reaped.
+        Polled roughly every 0.1 s by the worker-handler thread."""
+        if self._join_exited_workers():
+            self._repopulate_pool()
+
     def _setup_queues(self):
         from .queues import SimpleQueue
         self._inqueue = SimpleQueue()
@@ -218,6 +266,13 @@
         return result
 
     @staticmethod
+    def _handle_workers(pool):
+        while pool._worker_handler._state == RUN and pool._state == RUN:  # until close()/terminate()
+            pool._maintain_pool()  # reap exited workers, start replacements
+            time.sleep(0.1)  # poll period (s); NOTE(review): confirm pool.py imports time
+        debug('worker handler exiting')  # NOTE(review): no respawns after close(); can queued tasks stall once workers hit maxtasksperchild?
+
+    @staticmethod
     def _handle_tasks(taskqueue, put, outqueue, pool):
         thread = threading.current_thread()
 
@@ -332,16 +387,19 @@
         debug('closing pool')
         if self._state == RUN:
             self._state = CLOSE
+            self._worker_handler._state = CLOSE
             self._taskqueue.put(None)
 
     def terminate(self):
         debug('terminating pool')
         self._state = TERMINATE
+        self._worker_handler._state = TERMINATE  # ends _handle_workers loop; NOTE(review): confirm it stops before workers are killed
         self._terminate()
 
     def join(self):
         debug('joining pool')
         assert self._state in (CLOSE, TERMINATE)
+        self._worker_handler.join()
         self._task_handler.join()
         self._result_handler.join()
         for p in self._pool:
@@ -358,10 +416,11 @@
 
     @classmethod
     def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool,
-                        task_handler, result_handler, cache):
+                        worker_handler, task_handler, result_handler, cache):
         # this is guaranteed to only be called once
         debug('finalizing pool')
 
+        worker_handler._state = TERMINATE
         task_handler._state = TERMINATE
         taskqueue.put(None)                 # sentinel
 
@@ -373,10 +432,12 @@
         result_handler._state = TERMINATE
         outqueue.put(None)                  # sentinel
 
+        # Terminate workers which haven't already finished.
         if pool and hasattr(pool[0], 'terminate'):
             debug('terminating workers')
             for p in pool:
-                p.terminate()
+                if p.exitcode is None:
+                    p.terminate()
 
         debug('joining task handler')
         task_handler.join(1e100)
@@ -388,6 +449,11 @@
             debug('joining pool workers')
             for p in pool:
                 p.join()
+            for w in pool:
+                if w.exitcode is None:
+                    # worker has not yet exited
+                    debug('cleaning up worker %d' % w.pid)
+                    w.join()
 
 #
 # Class whose instances are returned by `Pool.apply_async()`