blob: 0a237079f15180d1cf93bad3a7fb7e60264f4da8 [file] [log] [blame]
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01001import errno
2import os
Charles-François Natalie241ac92013-09-05 20:46:49 +02003import selectors
Richard Oudkerk84ed9a62013-08-14 15:35:41 +01004import signal
5import socket
6import struct
7import sys
8import threading
9
10from . import connection
11from . import process
12from . import reduction
Richard Oudkerk7d2d43c2013-08-22 11:38:57 +010013from . import semaphore_tracker
Richard Oudkerk84ed9a62013-08-14 15:35:41 +010014from . import spawn
15from . import util
16
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]

#
# Protocol limits and module-level state
#
23
# Upper bound on the total number of descriptors carried by one fork
# request, counting the four bookkeeping fds that connect_to_new_process()
# sends ahead of the caller's own.
MAXFDS_TO_SEND = 256

# Fixed-width native-format integer used by read_unsigned()/write_unsigned()
# to ship pids and exit codes over pipes.
# NOTE(review): 'Q' is unsigned, so packing a negative exit code would raise
# struct.error -- confirm callers can never produce one.
UNSIGNED_STRUCT = struct.Struct('Q')     # large enough for pid_t

# Module state.  In a client process these describe the running forkserver;
# in a child forked by the server, _serve_one() overwrites the alive fd and
# fills in _inherited_fds.
_lock = threading.Lock()           # guards start-up in ensure_running()
_preload_modules = ['__main__']    # replaced by set_forkserver_preload()
_forkserver_address = None         # AF_UNIX address the forkserver listens on
_forkserver_alive_fd = None        # this process's write end of the alive pipe
_inherited_fds = None              # extra fds handed to a forkserver child
32
#
# Public functions
#
36
def set_forkserver_preload(modules_names):
    '''Set list of module names to try to load in forkserver process.

    *modules_names* may be any iterable of str; it is copied into a new list,
    so later mutation of the caller's object has no effect.  The list is
    only consulted by ensure_running() when it starts a forkserver, so this
    does not affect a forkserver that is already running.

    Raises TypeError if any entry is not a str, or if a single string is
    passed instead of an iterable of names (iterating a string would yield
    one-character "module names").
    '''
    global _preload_modules
    if isinstance(modules_names, str):
        raise TypeError('module_names must be a list of strings')
    names = list(modules_names)
    # Checked here with an explicit raise (not an assert) so invalid input is
    # reported even under ``python -O``; the names are later embedded in a
    # command line via %r.
    if not all(type(mod) is str for mod in names):
        raise TypeError('module_names must be a list of strings')
    _preload_modules = names
41
42
def get_inherited_fds():
    '''Return the extra fds this process received from the fork server.

    The value is a list only in a child created by the fork server (it is
    filled in while that child is being set up); in every other process it
    is None.
    '''
    inherited = _inherited_fds
    return inherited
49
50
def connect_to_new_process(fds):
    '''Request forkserver to create a child process.

    Returns a pair of fds (status_r, data_w). The calling process can read
    the child process's pid and (eventually) its returncode from status_r.
    The calling process should write to data_w the pickled preparation and
    process data.

    *fds* is a sequence of extra descriptors to hand to the child.  Four
    descriptors are always sent ahead of them (the child's ends of the two
    pipes, the forkserver "alive" fd and the semaphore tracker fd), so
    ValueError is raised when the total would reach MAXFDS_TO_SEND.

    On any failure every pipe descriptor created here is closed before the
    exception propagates.
    '''
    if len(fds) + 4 >= MAXFDS_TO_SEND:
        raise ValueError('too many fds')
    with socket.socket(socket.AF_UNIX) as client:
        client.connect(_forkserver_address)
        parent_r, child_w = os.pipe()
        try:
            child_r, parent_w = os.pipe()
        except BaseException:
            # Don't leak the first pipe if the second cannot be created
            # (e.g. the process has run out of descriptors).
            os.close(parent_r)
            os.close(child_w)
            raise
        try:
            # Built inside the try so a failure here is cleaned up too.
            allfds = [child_r, child_w, _forkserver_alive_fd,
                      semaphore_tracker._semaphore_tracker_fd]
            allfds += fds
            reduction.sendfds(client, allfds)
            return parent_r, parent_w
        except BaseException:
            os.close(parent_r)
            os.close(parent_w)
            raise
        finally:
            # Our copies of the child's pipe ends are closed whether or not
            # sending succeeded; only the parent ends are returned.
            os.close(child_r)
            os.close(child_w)
78
79
def ensure_running():
    '''Make sure that a fork server is running.

    This can be called from any process. Note that usually a child
    process will just reuse the forkserver started by its parent, so
    ensure_running() will do nothing.

    On first use this binds an AF_UNIX listening socket (mode 0o600),
    creates the "alive" pipe and spawns a new interpreter running main()
    with the listener fd and the pipe's read end.  The listening address
    and the pipe's write end are then stored in module globals, where
    connect_to_new_process() picks them up.

    Raises TypeError if the configured preload list contains anything
    other than str objects.
    '''
    global _forkserver_address, _forkserver_alive_fd
    with _lock:
        if _forkserver_alive_fd is not None:
            return

        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # the names are interpolated into a command line with %r below.
        if not all(type(mod) is str for mod in _preload_modules):
            raise TypeError('module_names must be a list of strings')

        cmd = ('from multiprocessing.forkserver import main; ' +
               'main(%d, %d, %r, **%r)')

        if _preload_modules:
            # Forward only the preparation data main() accepts as keywords.
            desired_keys = {'main_path', 'sys_path'}
            data = spawn.get_preparation_data('ignore')
            data = {key: value for key, value in data.items()
                    if key in desired_keys}
        else:
            data = {}

        with socket.socket(socket.AF_UNIX) as listener:
            address = connection.arbitrary_address('AF_UNIX')
            listener.bind(address)
            os.chmod(address, 0o600)
            listener.listen(100)

            # all client processes own the write end of the "alive" pipe;
            # when they all terminate the read end becomes ready.
            alive_r, alive_w = os.pipe()
            try:
                fds_to_pass = [listener.fileno(), alive_r]
                cmd %= (listener.fileno(), alive_r, _preload_modules, data)
                exe = spawn.get_executable()
                args = [exe] + util._args_from_interpreter_flags() + ['-c', cmd]
                util.spawnv_passfds(exe, args, fds_to_pass)
            except BaseException:
                os.close(alive_w)
                raise
            finally:
                # The spawned server holds its own copy of the read end.
                os.close(alive_r)
            _forkserver_address = address
            _forkserver_alive_fd = alive_w
Richard Oudkerk84ed9a62013-08-14 15:35:41 +0100125
126
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point of the interpreter spawned by ensure_running().

    *listener_fd* is the inherited AF_UNIX listening socket and *alive_r*
    the read end of the "alive" pipe.  *preload* names modules to import
    once, before any fork; ImportErrors are ignored.  *main_path* and
    *sys_path* come from the parent's preparation data: the parent's import
    path is installed before preloading, and ``__main__`` is re-imported
    from *main_path* when requested.

    Loops forever, forking one child per accepted connection, and leaves
    via SystemExit once the alive pipe reports EOF.
    '''
    global _forkserver_address

    if preload:
        if sys_path is not None:
            # Import with the parent's search path; otherwise modules found
            # only on that path fail here and the ImportError is silently
            # swallowed below.
            sys.path[:] = sys_path
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    # close sys.stdin
    if sys.stdin is not None:
        try:
            sys.stdin.close()
            sys.stdin = open(os.devnull)
        except (OSError, ValueError):
            pass

    # ignoring SIGCHLD means no need to reap zombie processes
    handler = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)

        while True:
            try:
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left.  The read
                    # is kept out of the assert so it is not elided by -O.
                    eof = os.read(alive_r, 1)
                    assert eof == b''
                    raise SystemExit

                assert listener in rfds
                with listener.accept()[0] as s:
                    # Exit status of the forked child; SIGCHLD is ignored
                    # here, and the real exit code is written to the status
                    # pipe by _serve_one().
                    code = 1
                    if os.fork() == 0:
                        try:
                            _serve_one(s, listener, alive_r, handler)
                        except Exception:
                            sys.excepthook(*sys.exc_info())
                            sys.stderr.flush()
                        finally:
                            os._exit(code)

            except InterruptedError:
                pass
            except OSError as e:
                if e.errno != errno.ECONNABORTED:
                    raise
189
190#
191# Code to bootstrap new process
192#
193
def _serve_one(s, listener, alive_r, handler):
    '''Set up and run one child process forked by the fork server.

    *s* is the accepted connection carrying the descriptors sent by
    connect_to_new_process(); *handler* is the SIGCHLD disposition that was
    in force before the server switched it to SIG_IGN.  The child reports
    its pid on the status pipe, runs the process read from the data pipe
    via spawn._main(), and finally reports the resulting exit code.
    '''
    global _inherited_fds, _forkserver_alive_fd

    # The server's listener and alive pipe are of no use to the child, and
    # user code must see the original SIGCHLD handling again.
    listener.close()
    os.close(alive_r)
    signal.signal(signal.SIGCHLD, handler)

    # Ask for one more than the limit so an oversized batch is detectable.
    received = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
    s.close()
    assert len(received) <= MAXFDS_TO_SEND
    data_r, status_w, _forkserver_alive_fd, tracker_fd = received[:4]
    _inherited_fds = list(received[4:])
    semaphore_tracker._semaphore_tracker_fd = tracker_fd

    # First record on the status pipe: our pid.
    write_unsigned(status_w, os.getpid())

    # The forked child starts with a copy of the server's random state, so
    # reseed when the module has been loaded.
    if 'random' in sys.modules:
        import random
        random.seed()

    exitcode = spawn._main(data_r)

    # Second record on the status pipe: the exit code.
    write_unsigned(status_w, exitcode)
222
223#
224# Read and write unsigned numbers
225#
226
def read_unsigned(fd):
    '''Read one UNSIGNED_STRUCT-encoded integer from descriptor *fd*.

    Keeps reading until the whole fixed-size record has arrived, retrying
    reads interrupted by a signal.  Raises EOFError if end-of-file is hit
    before the record is complete.
    '''
    pieces = []
    missing = UNSIGNED_STRUCT.size
    while missing > 0:
        try:
            piece = os.read(fd, missing)
        except InterruptedError:
            continue
        if not piece:
            raise EOFError('unexpected EOF')
        pieces.append(piece)
        missing -= len(piece)
    return UNSIGNED_STRUCT.unpack(b''.join(pieces))[0]
242
def write_unsigned(fd, n):
    '''Write *n* to descriptor *fd* as one UNSIGNED_STRUCT-encoded record.

    Handles partial writes and retries writes interrupted by a signal.
    Raises RuntimeError if the OS reports that zero bytes were written.
    '''
    payload = UNSIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        try:
            written = os.write(fd, payload[offset:])
        except InterruptedError:
            continue
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written