bpo-35378: Fix multiprocessing.Pool references (GH-11627)
Changes in this commit:
1. Keep a _strong_ reference to the Pool from its associated iterators and result objects
2. Rework PR #8450 to eliminate a cycle in the Pool.
There is no test in this commit because any automated test of this behaviour would need to delete the pool before joining it, in order to check that the pool object is garbage collected and does not hang. Doing so could leak threads and processes (see https://bugs.python.org/issue35413).
diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py
index bfb2769..18a56f8 100644
--- a/Lib/multiprocessing/pool.py
+++ b/Lib/multiprocessing/pool.py
@@ -151,8 +151,9 @@
'''
_wrap_exception = True
- def Process(self, *args, **kwds):
- return self._ctx.Process(*args, **kwds)
+ @staticmethod
+ def Process(ctx, *args, **kwds):
+ return ctx.Process(*args, **kwds)
def __init__(self, processes=None, initializer=None, initargs=(),
maxtasksperchild=None, context=None):
@@ -190,7 +191,10 @@
self._worker_handler = threading.Thread(
target=Pool._handle_workers,
- args=(self, )
+ args=(self._cache, self._taskqueue, self._ctx, self.Process,
+ self._processes, self._pool, self._inqueue, self._outqueue,
+ self._initializer, self._initargs, self._maxtasksperchild,
+ self._wrap_exception)
)
self._worker_handler.daemon = True
self._worker_handler._state = RUN
@@ -236,43 +240,61 @@
f'state={self._state} '
f'pool_size={len(self._pool)}>')
- def _join_exited_workers(self):
+ @staticmethod
+ def _join_exited_workers(pool):
"""Cleanup after any worker processes which have exited due to reaching
their specified lifetime. Returns True if any workers were cleaned up.
"""
cleaned = False
- for i in reversed(range(len(self._pool))):
- worker = self._pool[i]
+ for i in reversed(range(len(pool))):
+ worker = pool[i]
if worker.exitcode is not None:
# worker exited
util.debug('cleaning up worker %d' % i)
worker.join()
cleaned = True
- del self._pool[i]
+ del pool[i]
return cleaned
def _repopulate_pool(self):
+ return self._repopulate_pool_static(self._ctx, self.Process,
+ self._processes,
+ self._pool, self._inqueue,
+ self._outqueue, self._initializer,
+ self._initargs,
+ self._maxtasksperchild,
+ self._wrap_exception)
+
+ @staticmethod
+ def _repopulate_pool_static(ctx, Process, processes, pool, inqueue,
+ outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception):
"""Bring the number of pool processes up to the specified number,
for use after reaping workers which have exited.
"""
- for i in range(self._processes - len(self._pool)):
- w = self.Process(target=worker,
- args=(self._inqueue, self._outqueue,
- self._initializer,
- self._initargs, self._maxtasksperchild,
- self._wrap_exception)
- )
+ for i in range(processes - len(pool)):
+ w = Process(ctx, target=worker,
+ args=(inqueue, outqueue,
+ initializer,
+ initargs, maxtasksperchild,
+ wrap_exception))
w.name = w.name.replace('Process', 'PoolWorker')
w.daemon = True
w.start()
- self._pool.append(w)
+ pool.append(w)
util.debug('added worker')
- def _maintain_pool(self):
+ @staticmethod
+ def _maintain_pool(ctx, Process, processes, pool, inqueue, outqueue,
+ initializer, initargs, maxtasksperchild,
+ wrap_exception):
"""Clean up any exited workers and start replacements for them.
"""
- if self._join_exited_workers():
- self._repopulate_pool()
+ if Pool._join_exited_workers(pool):
+ Pool._repopulate_pool_static(ctx, Process, processes, pool,
+ inqueue, outqueue, initializer,
+ initargs, maxtasksperchild,
+ wrap_exception)
def _setup_queues(self):
self._inqueue = self._ctx.SimpleQueue()
@@ -331,7 +353,7 @@
'''
self._check_running()
if chunksize == 1:
- result = IMapIterator(self._cache)
+ result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
@@ -344,7 +366,7 @@
"Chunksize must be 1+, not {0:n}".format(
chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = IMapIterator(self._cache)
+ result = IMapIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
@@ -360,7 +382,7 @@
'''
self._check_running()
if chunksize == 1:
- result = IMapUnorderedIterator(self._cache)
+ result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job, func, iterable),
@@ -372,7 +394,7 @@
raise ValueError(
"Chunksize must be 1+, not {0!r}".format(chunksize))
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = IMapUnorderedIterator(self._cache)
+ result = IMapUnorderedIterator(self)
self._taskqueue.put(
(
self._guarded_task_generation(result._job,
@@ -388,7 +410,7 @@
Asynchronous version of `apply()` method.
'''
self._check_running()
- result = ApplyResult(self._cache, callback, error_callback)
+ result = ApplyResult(self, callback, error_callback)
self._taskqueue.put(([(result._job, 0, func, args, kwds)], None))
return result
@@ -417,7 +439,7 @@
chunksize = 0
task_batches = Pool._get_tasks(func, iterable, chunksize)
- result = MapResult(self._cache, chunksize, len(iterable), callback,
+ result = MapResult(self, chunksize, len(iterable), callback,
error_callback=error_callback)
self._taskqueue.put(
(
@@ -430,16 +452,20 @@
return result
@staticmethod
- def _handle_workers(pool):
+ def _handle_workers(cache, taskqueue, ctx, Process, processes, pool,
+ inqueue, outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception):
thread = threading.current_thread()
# Keep maintaining workers until the cache gets drained, unless the pool
# is terminated.
- while thread._state == RUN or (pool._cache and thread._state != TERMINATE):
- pool._maintain_pool()
+ while thread._state == RUN or (cache and thread._state != TERMINATE):
+ Pool._maintain_pool(ctx, Process, processes, pool, inqueue,
+ outqueue, initializer, initargs,
+ maxtasksperchild, wrap_exception)
time.sleep(0.1)
# send sentinel to stop workers
- pool._taskqueue.put(None)
+ taskqueue.put(None)
util.debug('worker handler exiting')
@staticmethod
@@ -656,13 +682,14 @@
class ApplyResult(object):
- def __init__(self, cache, callback, error_callback):
+ def __init__(self, pool, callback, error_callback):
+ self._pool = pool
self._event = threading.Event()
self._job = next(job_counter)
- self._cache = cache
+ self._cache = pool._cache
self._callback = callback
self._error_callback = error_callback
- cache[self._job] = self
+ self._cache[self._job] = self
def ready(self):
return self._event.is_set()
@@ -692,6 +719,7 @@
self._error_callback(self._value)
self._event.set()
del self._cache[self._job]
+ self._pool = None
AsyncResult = ApplyResult # create alias -- see #17805
@@ -701,8 +729,8 @@
class MapResult(ApplyResult):
- def __init__(self, cache, chunksize, length, callback, error_callback):
- ApplyResult.__init__(self, cache, callback,
+ def __init__(self, pool, chunksize, length, callback, error_callback):
+ ApplyResult.__init__(self, pool, callback,
error_callback=error_callback)
self._success = True
self._value = [None] * length
@@ -710,7 +738,7 @@
if chunksize <= 0:
self._number_left = 0
self._event.set()
- del cache[self._job]
+ del self._cache[self._job]
else:
self._number_left = length//chunksize + bool(length % chunksize)
@@ -724,6 +752,7 @@
self._callback(self._value)
del self._cache[self._job]
self._event.set()
+ self._pool = None
else:
if not success and self._success:
# only store first exception
@@ -735,6 +764,7 @@
self._error_callback(self._value)
del self._cache[self._job]
self._event.set()
+ self._pool = None
#
# Class whose instances are returned by `Pool.imap()`
@@ -742,15 +772,16 @@
class IMapIterator(object):
- def __init__(self, cache):
+ def __init__(self, pool):
+ self._pool = pool
self._cond = threading.Condition(threading.Lock())
self._job = next(job_counter)
- self._cache = cache
+ self._cache = pool._cache
self._items = collections.deque()
self._index = 0
self._length = None
self._unsorted = {}
- cache[self._job] = self
+ self._cache[self._job] = self
def __iter__(self):
return self
@@ -761,12 +792,14 @@
item = self._items.popleft()
except IndexError:
if self._index == self._length:
+ self._pool = None
raise StopIteration from None
self._cond.wait(timeout)
try:
item = self._items.popleft()
except IndexError:
if self._index == self._length:
+ self._pool = None
raise StopIteration from None
raise TimeoutError from None
@@ -792,6 +825,7 @@
if self._index == self._length:
del self._cache[self._job]
+ self._pool = None
def _set_length(self, length):
with self._cond:
@@ -799,6 +833,7 @@
if self._index == self._length:
self._cond.notify()
del self._cache[self._job]
+ self._pool = None
#
# Class whose instances are returned by `Pool.imap_unordered()`
@@ -813,6 +848,7 @@
self._cond.notify()
if self._index == self._length:
del self._cache[self._job]
+ self._pool = None
#
#
@@ -822,7 +858,7 @@
_wrap_exception = False
@staticmethod
- def Process(*args, **kwds):
+ def Process(ctx, *args, **kwds):
from .dummy import Process
return Process(*args, **kwds)