blob: d5ce6257456766123080f16379eb21e249a82ff6 [file] [log] [blame]
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01001import errno
2import os
Charles-François Natalie241ac92013-09-05 20:46:49 +02003import selectors
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01004import signal
5import socket
6import struct
7import sys
8import threading
9
10from . import connection
11from . import process
Davin Potts54586472016-09-09 18:03:10 -050012from .context import reduction
Richard Oudkerk7d2d43c2013-08-22 11:38:57 +010013from . import semaphore_tracker
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010014from . import spawn
15from . import util
16
# Public API of this module; the names are bound methods of the module-level
# ForkServer instance created at the bottom of the file.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
#
#

# Upper bound on the number of file descriptors exchanged with the forkserver
# for one child request (the 4 bookkeeping fds plus the caller's fds).
MAXFDS_TO_SEND = 256
# Fixed-size native 'Q' (unsigned long long) encoding used by
# read_unsigned/write_unsigned to send pids and exit codes over pipes.
UNSIGNED_STRUCT = struct.Struct('Q') # large enough for pid_t
26
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010027#
Richard Oudkerkb1694cf2013-10-16 16:41:56 +010028# Forkserver class
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010029#
30
class ForkServer(object):
    '''Start and talk to a single forkserver process.

    The forkserver is a helper interpreter listening on an AF_UNIX socket;
    each request sent to it results in a forked child process.
    '''

    def __init__(self):
        # AF_UNIX address of the forkserver's listening socket, once started.
        self._forkserver_address = None
        # Write end of the "alive" pipe; non-None means the server is running.
        self._forkserver_alive_fd = None
        # fds handed to this process by the forkserver (None if not a child).
        self._inherited_fds = None
        # Serializes ensure_running() so only one server gets spawned.
        self._lock = threading.Lock()
        # Module names the forkserver tries to import before forking.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of *modules_names* is not a str.
        '''
        # Validate the *new* value (the check previously inspected the old
        # list, so bad input was never rejected).
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w). The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if too many fds are passed.
        '''
        self.ensure_running()
        # 4 bookkeeping fds are sent in front of the caller's fds.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Don't leak the first pipe if the second cannot be created.
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          semaphore_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                # On failure the caller never receives the parent ends.
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child ends were sent to the forkserver (or the send
                # failed); either way this process no longer needs them.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process. Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_alive_fd is not None:
                return

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Only these keys are forwarded as keyword args to main().
                desired_keys = {'main_path', 'sys_path'}
                prep = spawn.get_preparation_data('ignore')
                data = {k: v for k, v in prep.items() if k in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # The server process now owns the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
130
131#
132#
133#
Richard Oudkerk84ed9a62013-08-14 15:35:41 +0100134
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    listener_fd -- fd of the AF_UNIX listening socket created by the parent.
    alive_r     -- read end of the "alive" pipe; EOF on it ends the server.
    preload     -- module names to import before serving requests.
    main_path   -- path imported as __main__ when '__main__' is in preload.
    sys_path    -- accepted as a keyword but not used by this function.

    Forks one child per accepted connection; each child runs _serve_one().
    Raises SystemExit once alive_r reports EOF.
    '''
    if preload:
        if main_path is not None and '__main__' in preload:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for name in preload:
            try:
                __import__(name)
            except ImportError:
                pass

    util._close_stdin()

    # Ignoring SIGCHLD means forked children never need to be reaped;
    # restoring the default SIGINT action avoids KeyboardInterrupt tracebacks.
    # The previous handlers are kept so each child can restore them.
    new_handlers = {
        signal.SIGCHLD: signal.SIG_IGN,
        signal.SIGINT: signal.SIG_DFL,
    }
    old_handlers = {}
    for signum, action in new_handlers.items():
        old_handlers[signum] = signal.signal(signum, action)

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener:
        with selectors.DefaultSelector() as selector:
            _forkserver._forkserver_address = listener.getsockname()

            selector.register(listener, selectors.EVENT_READ)
            selector.register(alive_r, selectors.EVENT_READ)

            while True:
                try:
                    ready = []
                    while not ready:
                        ready = [key.fileobj
                                 for key, _ in selector.select()]

                    if alive_r in ready:
                        # EOF because no more client processes left
                        assert os.read(alive_r, 1) == b''
                        raise SystemExit

                    assert listener in ready
                    with listener.accept()[0] as conn:
                        # The child's real exit code travels over the status
                        # pipe (see _serve_one); the process exit status is
                        # always 1 here.
                        exit_status = 1
                        if os.fork() == 0:
                            try:
                                _serve_one(conn, listener, alive_r,
                                           old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                os._exit(exit_status)

                except OSError as exc:
                    if exc.errno != errno.ECONNABORTED:
                        raise
195
def _serve_one(s, listener, alive_r, handlers):
    '''Handle one request inside a freshly forked forkserver child.

    s        -- accepted connection carrying the fds sent by the client.
    listener -- the server's listening socket (closed here; not needed).
    alive_r  -- read end of the "alive" pipe (closed here; not needed).
    handlers -- mapping of signal number to the handler to reinstate.

    Writes this process's pid, then the code returned by spawn._main(), to
    the status pipe received from the client.
    '''
    # Drop what only the server needs and reinstate the saved handlers.
    listener.close()
    os.close(alive_r)
    for signum in handlers:
        signal.signal(signum, handlers[signum])

    # Collect the descriptors the client sent over the connection.
    received = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(received) <= MAXFDS_TO_SEND
    child_r, child_w, alive_fd, tracker_fd, *inherited = received
    _forkserver._forkserver_alive_fd = alive_fd
    _forkserver._inherited_fds = inherited
    semaphore_tracker._semaphore_tracker._fd = tracker_fd

    # Tell the client which pid it is talking to.
    write_unsigned(child_w, os.getpid())

    # A forked child would otherwise share the server's random state.
    if 'random' in sys.modules:
        import random
        random.seed()

    # Run the process object read from the data pipe, then report its code.
    exit_code = spawn._main(child_r)
    write_unsigned(child_w, exit_code)
224
225#
226# Read and write unsigned numbers
227#
228
def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from file descriptor *fd*.

    Keeps reading until enough bytes arrive; raises EOFError if the
    descriptor reaches end-of-file first.
    '''
    needed = UNSIGNED_STRUCT.size
    pieces = []
    have = 0
    while have < needed:
        piece = os.read(fd, needed - have)
        if not piece:
            raise EOFError('unexpected EOF')
        pieces.append(piece)
        have += len(piece)
    return UNSIGNED_STRUCT.unpack(b''.join(pieces))[0]
238
def write_unsigned(fd, n):
    '''Write integer *n* to file descriptor *fd* using UNSIGNED_STRUCT.

    Loops over partial writes; raises RuntimeError if os.write reports that
    zero bytes were written.
    '''
    remaining = UNSIGNED_STRUCT.pack(n)
    while remaining:
        written = os.write(fd, remaining)
        if not written:
            raise RuntimeError('should not get here')
        remaining = remaining[written:]
Richard Oudkerkb1694cf2013-10-16 16:41:56 +0100246
247#
248#
249#
250
# Module-level singleton; the public functions named in __all__ are simply
# bound methods of this one instance.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload