Merge "Retry for checking ssh connection."
diff --git a/internal/lib/ssh.py b/internal/lib/ssh.py
index 92d94ab..590f1de 100755
--- a/internal/lib/ssh.py
+++ b/internal/lib/ssh.py
@@ -30,6 +30,7 @@
_SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
_SSH_CMD_MAX_RETRY = 4
_SSH_CMD_RETRY_SLEEP = 3
+_WAIT_FOR_SSH_MAX_TIMEOUT = 20  # Default WaitForSsh timeout, in seconds.
def _SshCall(cmd, timeout=None):
@@ -217,9 +218,8 @@
raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)
- @utils.TimeExecute(function_description="Waiting for SSH server")
- def WaitForSsh(self, timeout=20, max_retry=_SSH_CMD_MAX_RETRY):
- """Wait until the remote instance is ready to accept commands over SSH.
+ def CheckSshConnection(self, timeout):
+ """Run the remote 'uptime' command once to check the ssh connection.
Args:
timeout: Integer, the maximum time to wait for the command to respond.
@@ -229,18 +229,41 @@
"""
remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
remote_cmd.append("uptime")
- for _ in range(max_retry):
- if _SshCall(" ".join(remote_cmd), timeout) == 0:
- return
+
+ if _SshCall(" ".join(remote_cmd), timeout) == 0:
+ return
raise errors.DeviceConnectionError(
"Ssh isn't ready in the remote instance.")
+ @utils.TimeExecute(function_description="Waiting for SSH server")
+ def WaitForSsh(self, timeout=_WAIT_FOR_SSH_MAX_TIMEOUT,
+ sleep_for_retry=_SSH_CMD_RETRY_SLEEP,
+ max_retry=_SSH_CMD_MAX_RETRY):
+ """Wait until the remote instance is ready to accept commands over SSH.
+
+ Args:
+ timeout: Integer, the maximum time in seconds to wait for each
+ 'uptime' ssh check to respond.
+ sleep_for_retry: Integer, the sleep multiplier in seconds between retries.
+ max_retry: Integer, the maximum number of retries.
+
+ Raises:
+ errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
+ """
+ utils.RetryExceptionType(
+ exception_types=errors.DeviceConnectionError,
+ max_retries=max_retry,
+ functor=self.CheckSshConnection,
+ sleep_multiplier=sleep_for_retry,
+ retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
+ timeout=timeout)
+
def ScpPushFile(self, src_file, dst_file):
"""Scp push file to remote.
Args:
src_file: The source file path to be pulled.
- dst_file: The destiation file path the file is pulled to.
+ dst_file: The destination file path the file is pushed to.
"""
scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
scp_command.append(src_file)
@@ -252,7 +275,7 @@
Args:
src_file: The source file path to be pulled.
- dst_file: The destiation file path the file is pulled to.
+ dst_file: The destination file path the file is pulled to.
"""
scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
scp_command.append("%s@%s:%s" %(self._gce_user, self._ip, src_file))
diff --git a/internal/lib/ssh_test.py b/internal/lib/ssh_test.py
index 5ec6a3a..bddefd5 100644
--- a/internal/lib/ssh_test.py
+++ b/internal/lib/ssh_test.py
@@ -193,6 +193,20 @@
expected_ip = "1.1.1.1"
self.assertEqual(ssh_object._ip, expected_ip)
+ def testWaitForSsh(self):
+ """Test WaitForSsh raises DeviceConnectionError if ssh never succeeds."""
+ ssh_object = ssh.Ssh(ip=self.FAKE_IP,
+ gce_user=self.FAKE_SSH_USER,
+ ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
+ report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
+ self.Patch(ssh, "_SshCall", return_value=-1)
+ # sleep_for_retry=0 avoids real sleeps between retries in a unit test.
+ self.assertRaises(errors.DeviceConnectionError,
+ ssh_object.WaitForSsh,
+ timeout=1,
+ sleep_for_retry=0,
+ max_retry=1)
+
if __name__ == "__main__":
unittest.main()