# Copyright (C) 2008 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15"""Common SSH management logic."""
16
17import functools
18import os
19import re
20import signal
21import subprocess
22import sys
23import tempfile
24try:
25 import threading as _threading
26except ImportError:
27 import dummy_threading as _threading
28import time
29
30import platform_utils
31from repo_trace import Trace
32
33
# Path of the git_ssh helper; filled in on first call to proxy().
_ssh_proxy_path = None
# ControlPath template for ssh master sockets; filled in by sock().
_ssh_sock_path = None
# ssh client processes registered via add_client() and stopped by
# _terminate_clients().
_ssh_clients = []
37
38
def _run_ssh_version():
    """Run `ssh -V` and return everything it prints, decoded to text."""
    # stderr is folded into the captured output so the banner is returned
    # whichever stream ssh writes it to.
    banner = subprocess.check_output(['ssh', '-V'], stderr=subprocess.STDOUT)
    return banner.decode()
42
43
def _parse_ssh_version(ver_str=None):
    """Parse an `ssh -V` banner into a tuple of version numbers.

    Args:
        ver_str: The banner text to parse.  When None, the banner is obtained
            by calling _run_ssh_version().

    Returns:
        A tuple of ints such as (8, 2) or (6, 6, 1); a trailing "pN" suffix
        is ignored.  An empty tuple is returned when the text does not start
        with a recognisable OpenSSH version.
    """
    if ver_str is None:
        ver_str = _run_ssh_version()
    # The version may be followed by whitespace, by a comma (as in
    # "OpenSSH_9.0p1, LibreSSL 3.3.6"), or by the end of the string.  The
    # numeric part must be dot-separated digit runs so int() never sees an
    # empty component.
    m = re.match(
        r'^OpenSSH_([0-9]+(?:\.[0-9]+)*)(?:p[0-9]+)?(?:[\s,]|$)', ver_str)
    if not m:
        return ()
    return tuple(int(part) for part in m.group(1).split('.'))
53
54
@functools.lru_cache(maxsize=None)
def version():
    """Return the ssh version as a tuple of ints.

    The result is memoised with lru_cache, so a successful lookup is not
    repeated on later calls.

    Returns:
        Whatever _parse_ssh_version() returns, e.g. (8, 2); an empty tuple
        if the banner could not be parsed.

    Exits:
        Prints a "fatal:" message to stderr and calls sys.exit(1) when ssh
        cannot be started or exits with a non-zero status.
    """
    try:
        return _parse_ssh_version()
    except OSError as e:
        # subprocess raises OSError (e.g. FileNotFoundError) when the ssh
        # binary is missing or cannot be executed.
        print('fatal: unable to run ssh: %s' % e, file=sys.stderr)
        sys.exit(1)
    except subprocess.CalledProcessError:
        print('fatal: unable to detect ssh version', file=sys.stderr)
        sys.exit(1)
63
64
def proxy():
    """Return the path of the `git_ssh` helper located beside this module.

    The path is computed on the first call and remembered in
    _ssh_proxy_path for later calls.
    """
    global _ssh_proxy_path
    if _ssh_proxy_path is not None:
        return _ssh_proxy_path
    module_dir = os.path.dirname(__file__)
    _ssh_proxy_path = os.path.join(module_dir, 'git_ssh')
    return _ssh_proxy_path
72
73
def add_client(p):
    """Record process `p` in the module-wide list of ssh clients."""
    _ssh_clients.append(p)
76
77
def remove_client(p):
    """Drop process `p` from the list of ssh clients, if it is present."""
    try:
        _ssh_clients.remove(p)
    except ValueError:
        # `p` is not in the list; there is nothing to remove.
        return
83
84
def _terminate_clients():
    """Send SIGTERM to every registered ssh client, wait for it, then reset
    the client list to empty."""
    global _ssh_clients
    for client in _ssh_clients:
        try:
            os.kill(client.pid, signal.SIGTERM)
            client.wait()
        except OSError:
            # Best effort: a client that cannot be signalled is skipped.
            continue
    _ssh_clients = []
94
95
# Popen objects of the ssh control masters started by _open_ssh().
_master_processes = []
# Keys ("host" or "host:port") for which a master is believed to be running.
_master_keys = set()
# Set to False by _open_ssh() when launching a master fails, so that later
# calls do not retry.
_ssh_master = True
# Lock held by _open_ssh() while it checks for / starts a master.  Created by
# init() and reset to None by close().
_master_keys_lock = None
100
101
def init():
    """Set up ssh master handling; call once when repo starts.

    For now the only work done here is creating the lock stored in
    _master_keys_lock.  Calling this while a lock already exists trips the
    assertion below.
    """
    global _master_keys_lock
    assert _master_keys_lock is None, "Should only call init once"
    _master_keys_lock = _threading.Lock()
110
111
def _open_ssh(host, port=None):
    """Make sure an ssh control master is running for `host` (and `port`).

    Args:
        host: Destination handed to ssh, e.g. "example.com" or
            "git@example.com".
        port: Optional port; passed to `ssh -p` after str() conversion.

    Returns:
        True if a master for this destination is already known, was found by
        `ssh -O check`, or was started here and was still alive one second
        later.  False on win32/cygwin, when GIT_SSH is set, when an earlier
        launch failed, when launching fails now, or when the new master
        exited within that second.
    """
    global _ssh_master

    # Bail before grabbing the lock if we already know that we aren't going to
    # try creating new masters below.
    if sys.platform in ('win32', 'cygwin'):
        return False

    # Identifies this destination in _master_keys and in the warning below.
    key = host if port is None else '%s:%s' % (host, port)

    # The lock is needed to prevent opening multiple masters for the same
    # host when we're running "repo sync -jN" (for N > 1) _and_ the manifest
    # <remote fetch="ssh://xyz"> specifies a different host from the one that
    # was passed to repo init.
    with _master_keys_lock:
        # If we already think the master is running, return right away.
        if key in _master_keys:
            return True

        if not _ssh_master or 'GIT_SSH' in os.environ:
            # Failed earlier (or GIT_SSH is set), so don't try.
            return False

        # We will make two calls to ssh; this is the common part of both.
        port_args = [] if port is None else ['-p', str(port)]
        command_base = (['ssh'] + port_args +
                        ['-o', 'ControlPath %s' % sock(), host])

        # Since the key wasn't in _master_keys, we think that the master
        # isn't running... but before actually starting one, we double-check.
        # This can be important because we can't tell that 'git@myhost.com'
        # is the same as 'myhost.com' where "User git" is set up in the
        # user's ~/.ssh/config file.
        check_command = command_base + ['-O', 'check']
        try:
            Trace(': %s', ' '.join(check_command))
            check_process = subprocess.Popen(check_command,
                                             stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE)
            check_process.communicate()  # Read the output, but ignore it.
            if check_process.returncode == 0:
                # The double-check found that a master _was_ in fact running.
                _master_keys.add(key)
                return True
        except Exception:
            # Ignore exceptions.  We fall back to the normal command below
            # and report any problem there.
            pass

        command = ['ssh', '-M', '-N'] + command_base[1:]
        try:
            Trace(': %s', ' '.join(command))
            p = subprocess.Popen(command)
        except Exception as e:
            _ssh_master = False
            print('\nwarn: cannot enable ssh control master for %s\n%s'
                  % (key, e), file=sys.stderr)
            return False

        # Wait briefly, then see whether ssh has already exited.
        time.sleep(1)
        if p.poll() is not None:
            return False

        _master_processes.append(p)
        _master_keys.add(key)
        return True
191
192
def close():
    """Stop all ssh clients and masters and discard the master bookkeeping.

    After terminating the processes, the directory holding the control
    socket path (if one was ever created) is removed on a best-effort basis
    and the lock created by init() is dropped.
    """
    global _master_keys_lock

    _terminate_clients()

    for master in _master_processes:
        try:
            os.kill(master.pid, signal.SIGTERM)
            master.wait()
        except OSError:
            # Best effort: a master that cannot be signalled is skipped.
            continue
    _master_processes.clear()
    _master_keys.clear()

    # Only look up an existing socket path; never create one during shutdown.
    sock_path = sock(create=False)
    if sock_path:
        try:
            platform_utils.rmdir(os.path.dirname(sock_path))
        except OSError:
            pass

    # We're done with the lock, so we can delete it.
    _master_keys_lock = None
216
217
# scp-like syntax "[user@]host:path"; group 1 is "[user@]host".
URI_SCP = re.compile(r'^([^@:]*@?[^:/]{1,}):')
# URL syntax "scheme://authority/path"; group 1 is the scheme and group 2 the
# authority ("[user@]host[:port]").
URI_ALL = re.compile(r'^([a-z][a-z+-]*)://([^@/]*@?[^/]*)/')


def _split_host_port(netloc):
    """Split "[userinfo@]host[:port]" into (destination, port or None).

    Only a colon after the last '@' is considered, so colons inside the
    userinfo are left alone.  A colon that belongs to a bracketed IPv6
    literal (the text after it contains ']') is not treated as a port
    separator.  The destination keeps any "userinfo@" prefix and brackets
    unchanged.
    """
    userinfo, at, hostport = netloc.rpartition('@')
    host, colon, port = hostport.rpartition(':')
    if not colon or ']' in port:
        host, port = hostport, None
    return userinfo + at + host, port


def preconnect(url):
    """Try to open an ssh control master for `url`.

    Args:
        url: A remote location, either URL style ("ssh://host:port/path") or
            scp style ("user@host:path").

    Returns:
        The result of _open_ssh() for ssh-style locations (the schemes ssh,
        git+ssh and ssh+git, or scp-like syntax); False for any other URL
        scheme or for text matching neither pattern.
    """
    m = URI_ALL.match(url)
    if m:
        # Check the scheme first so non-ssh URLs are rejected without trying
        # to interpret their authority part.
        if m.group(1) not in ('ssh', 'git+ssh', 'ssh+git'):
            return False
        host, port = _split_host_port(m.group(2))
        return _open_ssh(host, port)

    m = URI_SCP.match(url)
    if m:
        return _open_ssh(m.group(1))

    return False
241
def sock(create=True):
    """Return the ControlPath template used for ssh master sockets.

    Args:
        create: When False and no path has been set up yet, return None
            instead of creating one.

    Returns:
        A path of the form "<new temp dir>/master-<tokens>", created on first
        use and remembered in _ssh_sock_path, or None as described above.
    """
    global _ssh_sock_path
    if _ssh_sock_path is not None:
        return _ssh_sock_path
    if not create:
        return None

    # Use /tmp when it exists; otherwise fall back to the platform default.
    base_dir = '/tmp'
    if not os.path.exists(base_dir):
        base_dir = tempfile.gettempdir()

    # ssh older than 6.7 gets explicit tokens; otherwise %C is used, which is
    # a hash of %l%h%p%r.
    tokens = '%r@%h:%p' if version() < (6, 7) else '%C'

    sock_dir = tempfile.mkdtemp(suffix='', prefix='ssh-', dir=base_dir)
    _ssh_sock_path = os.path.join(sock_dir, 'master-' + tokens)
    return _ssh_sock_path
257 return _ssh_sock_path