blob: 66e9770d2b59e2d36ecc9928ff7343351dd546b8 [file] [log] [blame]
Sergei Trofimov4e6afe92015-10-09 09:30:04 +01001# Copyright 2014-2015 ARM Limited
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15
16
17import os
18import stat
19import logging
20import subprocess
21import re
22import threading
23import tempfile
24import shutil
25
26import pxssh
27from pexpect import EOF, TIMEOUT, spawn
28
29from devlib.exception import HostError, TargetError, TimeoutError
30from devlib.utils.misc import which, strip_bash_colors, escape_single_quotes, check_output
31
32
# Paths to the host-side executables. They start out unset and are filled in
# by _check_env() the first time a connection is requested.
ssh = None
scp = None
sshpass = None  # only required for password-based auth (see _give_password)

logger = logging.getLogger('ssh')
38
39
def ssh_get_shell(host, username, password=None, keyfile=None, port=None, timeout=10, telnet=False):
    """
    Create a connection object and log it in to ``host`` as ``username``.

    A ``TelnetConnection`` is used when ``telnet`` is set, a plain ``pxssh.pxssh``
    otherwise. When ``keyfile`` is given it is passed to ``login()`` as ``ssh_key``
    (and ``password`` is not passed); otherwise ``password`` is passed. ``port`` and
    ``timeout`` (as ``login_timeout``) are forwarded to ``login()`` in both cases.

    Returns the logged-in connection object.

    Raises ``ValueError`` if ``keyfile`` is combined with ``telnet``, and
    ``TargetError`` if the login attempt ends with ``EOF``. The host environment is
    validated first via ``_check_env()``.
    """
    _check_env()
    if telnet and keyfile:
        raise ValueError('keyfile may not be used with a telnet connection.')

    conn = TelnetConnection() if telnet else pxssh.pxssh()
    login_kwargs = {'port': port, 'login_timeout': timeout}
    try:
        if keyfile:
            conn.login(host, username, ssh_key=keyfile, **login_kwargs)
        else:
            conn.login(host, username, password, **login_kwargs)
    except EOF:
        raise TargetError('Could not connect to {}; is the host name correct?'.format(host))
    return conn
56
57
class TelnetConnection(pxssh.pxssh):
    """
    A ``pxssh`` session whose ``login()`` spawns ``telnet`` instead of ``ssh``.
    """
    # pylint: disable=arguments-differ

    def login(self, server, username, password='', original_prompt=r'[#$]', login_timeout=10,
              auto_prompt_reset=True, sync_multiplier=1, port=None):
        """
        Spawn ``telnet -l <username> <server> [<port>]`` and log in with ``password``.

        :param server: host to connect to.
        :param username: user name handed to telnet via ``-l``.
        :param password: sent in response to the password prompt; ``None`` is
                         treated as an empty password.
        :param original_prompt: pattern matching the shell prompt expected right
                                after a successful login.
        :param login_timeout: timeout, in seconds, for each ``expect`` during login.
        :param auto_prompt_reset: if set, install pxssh's unique prompt after login.
        :param sync_multiplier: forwarded to ``sync_original_prompt()``.
        :param port: optional port, appended to the telnet command line when it
                     is not ``None``. Accepting it lets callers use the same
                     keyword arguments they pass to ``pxssh.pxssh.login()``.

        Returns ``True`` on success; raises ``pxssh.ExceptionPxssh`` if the login
        is rejected or the prompt cannot be synchronised/reset.
        """
        cmd = 'telnet -l {} {}'.format(username, server)
        if port is not None:
            cmd = '{} {}'.format(cmd, port)

        spawn._spawn(self, cmd)  # pylint: disable=protected-access
        i = self.expect('(?i)(?:password)', timeout=login_timeout)
        if i == 0:
            # sendline() needs a string, so fall back to an empty password.
            self.sendline(password if password is not None else '')
            i = self.expect([original_prompt, 'Login incorrect'], timeout=login_timeout)
        else:
            raise pxssh.ExceptionPxssh('could not log in: did not see a password prompt')

        # index 1 means 'Login incorrect' was matched rather than the prompt.
        if i:
            raise pxssh.ExceptionPxssh('could not log in: password was incorrect')

        if not self.sync_original_prompt(sync_multiplier):
            self.close()
            raise pxssh.ExceptionPxssh('could not synchronize with original prompt')

        if auto_prompt_reset:
            if not self.set_unique_prompt():
                self.close()
                message = 'could not set shell prompt (recieved: {}, expected: {}).'
                raise pxssh.ExceptionPxssh(message.format(self.before, self.PROMPT))
        return True
86
87
def check_keyfile(keyfile):
    """
    keyfile must have the right access permissions in order to be usable. If the
    specified file doesn't, create a temporary copy and set the right permissions
    for that.

    The permissions are considered right when they are exactly owner read/write
    (``0600``).

    Returns either the ``keyfile`` (if the permissions on it are correct) or the path
    to a temporary copy with the right permissions. The copy is placed in the system
    temporary directory under the same base name as ``keyfile``.
    """
    desired_mask = stat.S_IWUSR | stat.S_IRUSR
    # S_IMODE() keeps every permission bit. A narrower mask such as 0xFF drops
    # the owner read/write bits, so the comparison below could never succeed.
    actual_mask = stat.S_IMODE(os.stat(keyfile).st_mode)
    if actual_mask == desired_mask:  # permissions on keyfile are OK
        return keyfile

    tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile))
    shutil.copy(keyfile, tmp_file)
    os.chmod(tmp_file, desired_mask)
    return tmp_file
105
106
class SshConnection(object):
    """
    Connection to a target over ssh (or telnet).

    Commands run through ``execute()`` go through an interactive shell session
    obtained from ``ssh_get_shell()``; calls are serialised with a lock. File
    transfers (``push``/``pull``) and ``background()`` commands instead invoke the
    host's ``scp``/``ssh`` executables in a separate process.
    """

    # Text expected from sudo when it asks for a password.
    default_password_prompt = '[sudo] password'
    # How many times ^C is sent when trying to cancel a running command.
    max_cancel_attempts = 5

    @property
    def name(self):
        """The host this connection was created for."""
        return self.host

    def __init__(self,
                 host,
                 username,
                 password=None,
                 keyfile=None,
                 port=None,
                 timeout=10,
                 telnet=False,
                 password_prompt=None,
                 ):
        """
        Log in to ``host`` as ``username``.

        :param password: password used for login, for sudo prompts and (via
                         sshpass) for scp/background commands.
        :param keyfile: private key file; passed through ``check_keyfile()`` so a
                        copy with suitable permissions may be used instead.
        :param port: port to connect to; ``None`` leaves it unspecified.
        :param timeout: login timeout, in seconds.
        :param telnet: use telnet instead of ssh for the shell session.
        :param password_prompt: sudo password prompt to look for; defaults to
                                ``default_password_prompt``.
        """
        self.host = host
        self.username = username
        self.password = password
        self.keyfile = check_keyfile(keyfile) if keyfile else keyfile
        self.port = port
        self.lock = threading.Lock()
        if password_prompt is None:
            password_prompt = self.default_password_prompt
        self.password_prompt = password_prompt
        logger.debug('Logging in {}@{}'.format(username, host))
        self.conn = ssh_get_shell(host, username, password, self.keyfile, port, timeout, telnet)

    def push(self, source, dest, timeout=30):
        """Copy ``source`` on the host to ``dest`` on the target using scp."""
        dest = '{}@{}:{}'.format(self.username, self.host, dest)
        return self._scp(source, dest, timeout)

    def pull(self, source, dest, timeout=30):
        """Copy ``source`` on the target to ``dest`` on the host using scp."""
        source = '{}@{}:{}'.format(self.username, self.host, source)
        return self._scp(source, dest, timeout)

    def execute(self, command, timeout=None, check_exit_code=True, as_root=False, strip_colors=True):
        """
        Run ``command`` in the shell session and return its output.

        When ``check_exit_code`` is set, ``echo $?`` is issued afterwards and a
        ``TargetError`` is raised for a non-zero exit code. If the exit code cannot
        be parsed, a warning is logged and the output is returned anyway.
        """
        with self.lock:
            output = self._execute_and_wait_for_prompt(command, timeout, as_root, strip_colors)
            if check_exit_code:
                exit_code_text = self._execute_and_wait_for_prompt('echo $?', strip_colors=strip_colors, log=False)
                try:
                    exit_code = int(exit_code_text.split()[0])
                except (ValueError, IndexError):
                    logger.warning('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text))
                else:
                    if exit_code:
                        message = 'Got exit code {}\nfrom: {}\nOUTPUT: {}'
                        raise TargetError(message.format(exit_code, command, output))
            return output

    def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE):
        """
        Start ``command`` on the target through a separate ``ssh`` process and
        return the ``subprocess.Popen`` object without waiting for it.
        """
        port_string = '-p {}'.format(self.port) if self.port else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        # NOTE: ``command`` is interpolated as-is and the result is run through
        # the host shell, so callers are responsible for quoting it.
        command = '{} {} {} {}@{} {}'.format(ssh, keyfile_string, port_string, self.username, self.host, command)
        logger.debug(command)  # logged before the password is added
        if self.password:
            command = _give_password(self.password, command)
        return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True)

    def close(self):
        """Log out of the shell session."""
        logger.debug('Logging out {}@{}'.format(self.username, self.host))
        self.conn.logout()

    def cancel_running_command(self):
        """
        Send ^C up to ``max_cancel_attempts`` times; return ``True`` as soon as the
        prompt shows up, ``False`` if it never does.
        """
        # simulate impatiently hitting ^C until command prompt appears
        logger.debug('Sending ^C')
        for _ in range(self.max_cancel_attempts):
            self.conn.sendline(chr(3))
            if self.conn.prompt(0.1):
                return True
        return False

    def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True):
        """
        Send ``command`` to the shell, wait for the prompt and return the cleaned-up
        output. With ``as_root`` the command is wrapped in ``sudo -- sh -c '...'`` and
        the stored password is supplied if the sudo prompt appears.

        Raises ``TimeoutError`` (after attempting to cancel the command) if the
        prompt does not come back within ``timeout``.
        """
        self.conn.prompt(0.1)  # clear an existing prompt if there is one.
        if as_root:
            command = "sudo -- sh -c '{}'".format(escape_single_quotes(command))
        if log:
            logger.debug(command)
        self.conn.sendline(command)
        if as_root and self.password:
            # sudo may or may not ask for the password; only answer if it does.
            index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5)
            if index == 0:
                self.conn.sendline(self.password)
        timed_out = self._wait_for_prompt(timeout)
        # the regex removes line breaks potentially introduced when writing
        # command to shell.
        output = process_backspaces(self.conn.before)
        output = re.sub(r'\r([^\n])', r'\1', output)
        if '\r\n' in output:  # strip the echoed command
            output = output.split('\r\n', 1)[1]
        if timed_out:
            self.cancel_running_command()
            raise TimeoutError(command, output)
        if strip_colors:
            output = strip_bash_colors(output)
        return output

    def _wait_for_prompt(self, timeout=None):
        """
        Wait for the shell prompt. Returns ``True`` if ``timeout`` expired before
        the prompt was seen, ``False`` otherwise. A falsy ``timeout`` means wait
        indefinitely.
        """
        if timeout:
            return not self.conn.prompt(timeout)
        # cannot timeout; wait forever
        while not self.conn.prompt(1):
            pass
        return False

    def _scp(self, source, dest, timeout=30):
        """
        Recursively copy ``source`` to ``dest`` with the host's ``scp``.

        Errors from the copy are re-raised with the sshpass invocation (and hence
        the password) removed from the reported command.
        """
        # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely)
        # fails to connect to a device if port is explicitly specified using -P
        # option, even if it is the default port, 22. To minimize this problem,
        # only specify -P for scp if the port is *not* the default.
        port_string = '-P {}'.format(self.port) if (self.port and self.port != 22) else ''
        keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else ''
        command = '{} -r {} {} {} {}'.format(scp, keyfile_string, port_string, source, dest)
        pass_string = ''
        logger.debug(command)
        if self.password:
            command_with_password = _give_password(self.password, command)
            # _give_password() prepends the sshpass invocation; remember that
            # prefix so it can be scrubbed from any error raised below.
            pass_string = command_with_password[:-len(command)]
            command = command_with_password
        try:
            check_output(command, timeout=timeout, shell=True)
        except subprocess.CalledProcessError as e:
            raise subprocess.CalledProcessError(e.returncode, e.cmd.replace(pass_string, ''), e.output)
        except TimeoutError as e:
            raise TimeoutError(e.command.replace(pass_string, ''), e.output)
235
236
def _give_password(password, command):
    """
    Return ``command`` prefixed with an ``sshpass -p '<password>' `` invocation so
    that the password is supplied non-interactively.

    Raises ``HostError`` if sshpass was not found on the host.
    """
    if not sshpass:
        raise HostError('Must have sshpass installed on the host in order to use password-based auth.')
    # The password sits inside single quotes in a shell command line, so any
    # single quote in it must be written as '\'' to keep the quoting intact.
    quoted_password = password.replace("'", "'\\''")
    pass_string = "sshpass -p '{}' ".format(quoted_password)
    return pass_string + command
242
243
def _check_env():
    """
    Locate the ``ssh``, ``scp`` and ``sshpass`` executables on the host and store
    their paths in the module-level globals. The lookup is skipped once ``ssh``
    has been resolved.

    Raises ``HostError`` if either ``ssh`` or ``scp`` could not be found;
    ``sshpass`` is allowed to be missing.
    """
    global ssh, scp, sshpass  # pylint: disable=global-statement
    if ssh:
        return
    ssh = which('ssh')
    scp = which('scp')
    sshpass = which('sshpass')
    if not ssh or not scp:
        raise HostError('OpenSSH must be installed on the host.')
251 raise HostError('OpenSSH must be installed on the host.')
252
253
def process_backspaces(text):
    """
    Return ``text`` with backspace characters applied: each backspace removes the
    character before it. A backspace with nothing left to erase is dropped
    rather than kept in the output.
    """
    chars = []
    for c in text:
        if c != '\b':
            chars.append(c)
        elif chars:  # backspace with something to erase
            chars.pop()
    return ''.join(chars)