blob: 387517ebdca8e1a08f3363ec0498826ec5c19013 [file] [log] [blame]
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01001import errno
2import os
Charles-François Natalie241ac92013-09-05 20:46:49 +02003import selectors
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01004import signal
5import socket
6import struct
7import sys
8import threading
9
10from . import connection
11from . import process
12from . import reduction
Richard Oudkerk7d2d43c2013-08-22 11:38:57 +010013from . import semaphore_tracker
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010014from . import spawn
15from . import util
16
# Public API: module-level aliases for the singleton ForkServer's methods.
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]
19
20#
21#
22#
23
# Upper bound on how many fds may be passed to the fork server per request.
MAXFDS_TO_SEND = 256
# Native-order unsigned long long used on the status pipe for the child's
# pid and exit code; wide enough for a pid_t.
UNSIGNED_STRUCT = struct.Struct('Q')
26
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010027#
Richard Oudkerkb1694cf2013-10-16 16:41:56 +010028# Forkserver class
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010029#
30
class ForkServer(object):
    '''Client-side handle on a fork server process.

    The fork server is started lazily by ensure_running(); afterwards
    connect_to_new_process() asks it to fork a new child.
    '''

    def __init__(self):
        # Path of the server's AF_UNIX listening socket (None until started).
        self._forkserver_address = None
        # Write end of the "alive" pipe; the server exits once every holder
        # of this end has closed it.
        self._forkserver_alive_fd = None
        # Fds handed to a process started by the fork server (None otherwise).
        self._inherited_fds = None
        # Serializes ensure_running() so only one server gets launched.
        self._lock = threading.Lock()
        # Module names the server tries to import before it starts forking.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any entry of *modules_names* is not a str; in
        that case the previous setting is left unchanged.
        '''
        # Validate the *new* names.  Checking self._preload_modules here
        # would let non-string entries through and reject the next call.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds*, together with the four bookkeeping fds
        sent ahead of them, would reach MAXFDS_TO_SEND.
        '''
        self.ensure_running()
        # Four fds (two pipe ends, the alive fd and the semaphore tracker fd)
        # always precede the caller's fds.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Do not leak the first pipe if the second cannot be created.
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                # Built inside the try so that a failure in getfd() also
                # releases the pipes.
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          semaphore_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # This process never uses the child ends itself; drop them
                # whether or not the send succeeded.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_alive_fd is not None:
                return

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Forward only the entries main() accepts as keyword args.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {key: value for key, value in data.items()
                        if key in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Owner read/write only.
                os.chmod(address, 0o600)
                listener.listen(100)

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the server needs the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
130
131#
132#
133#
Richard Oudkerk84ed9a62013-08-14 15:35:41 +0100134
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Serve fork requests until no client process remains.

    listener_fd is the listening AF_UNIX socket and alive_r the read end of
    the "alive" pipe, both inherited from the process that spawned us.
    preload lists modules to import before serving; main_path is used to
    import __main__ when it is among them (sys_path is accepted but not used
    here).  Returns by raising SystemExit once alive_r reports EOF.
    '''
    # Import the requested modules up front so forked children start with
    # them already loaded.
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    # The server never reads stdin; replace it with /dev/null.
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # Ignoring SIGCHLD means exited children need not be reaped here; keep
    # the old disposition so each child can restore it.
    old_handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
            selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                ready = []
                while not ready:
                    ready = [key.fileobj for key, _ in selector.select()]

                if alive_r in ready:
                    # EOF: every client process has closed its write end.
                    assert os.read(alive_r, 1) == b''
                    raise SystemExit

                assert listener in ready
                with listener.accept()[0] as conn:
                    exit_code = 1
                    if os.fork() == 0:
                        # Child: handle this request, then never return
                        # into the server loop.
                        try:
                            _serve_one(conn, listener, alive_r, old_handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(exit_code)

            except InterruptedError:
                pass
            except OSError as exc:
                if exc.errno != errno.ECONNABORTED:
                    raise
196
def _serve_one(s, listener, alive_r, handler):
    '''Run one requested process inside a freshly forked child.

    s is the accepted connection carrying the client's fds, listener and
    alive_r are the server's own descriptors (closed here), and handler is
    the SIGCHLD disposition to restore.  The child's pid and then its exit
    code are written to the status pipe received from the client.
    '''
    # Drop what belongs to the server loop and restore SIGCHLD.
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # Ask for one fd more than the limit so an oversized batch trips the
    # assertion below.
    fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(fds) <= MAXFDS_TO_SEND
    child_r, child_w, alive_fd, stfd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    _forkserver._inherited_fds = inherited
    semaphore_tracker._semaphore_tracker._fd = stfd

    # The client first learns the pid of the new process.
    write_unsigned(child_w, os.getpid())

    # The fork copied the server's random state; reseed so children differ.
    if 'random' in sys.modules:
        import random
        random.seed()

    # Run the process object received over child_r, then report its exit code.
    write_unsigned(child_w, spawn._main(child_r))
224
225#
226# Read and write unsigned numbers
227#
228
def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from *fd*.

    Short reads are accumulated and interrupted reads retried; EOFError is
    raised if the stream ends before a full value has arrived.
    '''
    needed = UNSIGNED_STRUCT.size
    buf = b''
    while len(buf) < needed:
        try:
            chunk = os.read(fd, needed - len(buf))
        except InterruptedError:
            continue
        if not chunk:
            raise EOFError('unexpected EOF')
        buf += chunk
    return UNSIGNED_STRUCT.unpack(buf)[0]
244
def write_unsigned(fd, n):
    '''Write *n* to *fd* as one UNSIGNED_STRUCT-encoded integer.

    Partial writes are continued and interrupted writes retried until every
    byte has been written.
    '''
    pending = UNSIGNED_STRUCT.pack(n)
    while pending:
        try:
            written = os.write(fd, pending)
        except InterruptedError:
            continue
        if written == 0:
            raise RuntimeError('should not get here')
        pending = pending[written:]
Richard Oudkerkb1694cf2013-10-16 16:41:56 +0100258
259#
260#
261#
262
263_forkserver = ForkServer()
264ensure_running = _forkserver.ensure_running
265get_inherited_fds = _forkserver.get_inherited_fds
266connect_to_new_process = _forkserver.connect_to_new_process
267set_forkserver_preload = _forkserver.set_forkserver_preload