Merge "Retry for checking ssh connection."
diff --git a/internal/lib/ssh.py b/internal/lib/ssh.py
index 92d94ab..590f1de 100755
--- a/internal/lib/ssh.py
+++ b/internal/lib/ssh.py
@@ -30,6 +30,7 @@
 _SSH_IDENTITY = "-l %(login_user)s %(ip_addr)s"
 _SSH_CMD_MAX_RETRY = 4
 _SSH_CMD_RETRY_SLEEP = 3
+_WAIT_FOR_SSH_MAX_TIMEOUT = 20
 
 
 def _SshCall(cmd, timeout=None):
@@ -217,9 +218,8 @@
 
         raise errors.UnknownType("Don't support the execute bin %s." % execute_bin)
 
-    @utils.TimeExecute(function_description="Waiting for SSH server")
-    def WaitForSsh(self, timeout=20, max_retry=_SSH_CMD_MAX_RETRY):
-        """Wait until the remote instance is ready to accept commands over SSH.
+    def CheckSshConnection(self, timeout):
+        """Run remote 'uptime' ssh command to check ssh connection.
 
         Args:
             timeout: Integer, the maximum time to wait for the command to respond.
@@ -229,18 +229,41 @@
         """
         remote_cmd = [self.GetBaseCmd(constants.SSH_BIN)]
         remote_cmd.append("uptime")
-        for _ in range(max_retry):
-            if _SshCall(" ".join(remote_cmd), timeout) == 0:
-                return
+
+        if _SshCall(" ".join(remote_cmd), timeout) == 0:
+            return
         raise errors.DeviceConnectionError(
             "Ssh isn't ready in the remote instance.")
 
+    @utils.TimeExecute(function_description="Waiting for SSH server")
+    def WaitForSsh(self, timeout=_WAIT_FOR_SSH_MAX_TIMEOUT,
+                   sleep_for_retry=_SSH_CMD_RETRY_SLEEP,
+                   max_retry=_SSH_CMD_MAX_RETRY):
+        """Wait until the remote instance is ready to accept commands over SSH.
+
+        Args:
+            timeout: Integer, the maximum time in seconds to wait for the
+                     command to respond.
+            sleep_for_retry: Integer, the sleep time in seconds for retry.
+            max_retry: Integer, the maximum number of retries.
+
+        Raises:
+            errors.DeviceConnectionError: Ssh isn't ready in the remote instance.
+        """
+        utils.RetryExceptionType(
+            exception_types=errors.DeviceConnectionError,
+            max_retries=max_retry,
+            functor=self.CheckSshConnection,
+            sleep_multiplier=sleep_for_retry,
+            retry_backoff_factor=utils.DEFAULT_RETRY_BACKOFF_FACTOR,
+            timeout=timeout)
+
     def ScpPushFile(self, src_file, dst_file):
         """Scp push file to remote.
 
         Args:
             src_file: The source file path to be pulled.
-            dst_file: The destiation file path the file is pulled to.
+            dst_file: The destination file path the file is pushed to.
         """
         scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
         scp_command.append(src_file)
@@ -252,7 +275,7 @@
 
         Args:
             src_file: The source file path to be pulled.
-            dst_file: The destiation file path the file is pulled to.
+            dst_file: The destination file path the file is pulled to.
         """
         scp_command = [self.GetBaseCmd(constants.SCP_BIN)]
         scp_command.append("%s@%s:%s" %(self._gce_user, self._ip, src_file))
diff --git a/internal/lib/ssh_test.py b/internal/lib/ssh_test.py
index 5ec6a3a..bddefd5 100644
--- a/internal/lib/ssh_test.py
+++ b/internal/lib/ssh_test.py
@@ -193,6 +193,19 @@
         expected_ip = "1.1.1.1"
         self.assertEqual(ssh_object._ip, expected_ip)
 
+    def testWaitForSsh(self):
+        """Test WaitForSsh."""
+        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
+                             gce_user=self.FAKE_SSH_USER,
+                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
+                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
+        self.Patch(ssh, "_SshCall", return_value=-1)
+        self.assertRaises(errors.DeviceConnectionError,
+                          ssh_object.WaitForSsh,
+                          timeout=1,
+                          sleep_for_retry=1,
+                          max_retry=1)
+
 
 if __name__ == "__main__":
     unittest.main()