The paramiko.Transport.connect method doesn't seem to support any
equivalent to openssh's ConnectTimeout option. This can lead to some
situations where a connection potentially hangs (forever?) during the
initial SSH handshake, or during authentication.

To deal with this I've replace the Transport.connect call with some
more complex code that puts a 30-second limit on the initial
negotiation and the authentication, as well as wrapping both calls
in a couple of retries before giving up entirely and throwing a
timeout exception.

Risk: High
Visibility: Changs how we initiate paramiko connections so that we can
add a timeout.

Signed-off-by: John Admanski <jadmanski@google.com>



git-svn-id: http://test.kernel.org/svn/autotest/trunk@2986 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/server/hosts/paramiko_host.py b/server/hosts/paramiko_host.py
index 619f172..31db932 100644
--- a/server/hosts/paramiko_host.py
+++ b/server/hosts/paramiko_host.py
@@ -1,12 +1,15 @@
-import os, sys, time, signal, socket, re, fnmatch, logging
+import os, sys, time, signal, socket, re, fnmatch, logging, threading
 import paramiko
 
-from autotest_lib.client.common_lib import utils, error, debug
+from autotest_lib.client.common_lib import utils, error
+from autotest_lib.server import subcommand
 from autotest_lib.server.hosts import abstract_ssh
 
 
 class ParamikoHost(abstract_ssh.AbstractSSHHost):
     KEEPALIVE_TIMEOUT_SECONDS = 30
+    CONNECT_TIMEOUT_SECONDS = 30
+    CONNECT_TIMEOUT_RETRIES = 3
 
     def _initialize(self, hostname, *args, **dargs):
         super(ParamikoHost, self)._initialize(hostname=hostname, *args, **dargs)
@@ -17,8 +20,6 @@
         self.keys = self.get_user_keys(hostname)
         self.pid = None
 
-        self.host_log = debug.get_logger()
-
 
     @staticmethod
     def _load_key(path):
@@ -86,18 +87,46 @@
         return user_keys
 
 
+    @staticmethod
+    def _check_transport_error(transport):
+        error = transport.get_exception()
+        if error:
+            transport.close()
+            raise error
+
+
+    def _connect_transport(self, pkey):
+        for _ in xrange(self.CONNECT_TIMEOUT_RETRIES):
+            transport = paramiko.Transport((self.hostname, self.port))
+            completed = threading.Event()
+            transport.start_client(completed)
+            completed.wait(self.CONNECT_TIMEOUT_SECONDS)
+            if completed.isSet():
+                self._check_transport_error(transport)
+                completed.clear()
+                transport.auth_publickey(self.user, pkey, completed)
+                completed.wait(self.CONNECT_TIMEOUT_SECONDS)
+                if completed.isSet():
+                    self._check_transport_error(transport)
+                    return transport
+            logging.warn("SSH negotiation timed out, retrying")
+            transport.close()
+        logging.error("SSH negotation has timed out %s times, giving up",
+                      self.CONNECT_TIMEOUT_RETRIES)
+        raise error.AutoservSSHTimeout("SSH negotiation timed out")
+
+
     def _init_transport(self):
         for path, key in self.keys.iteritems():
             try:
-                self.host_log.debug("Connecting with %s", path)
-                transport = paramiko.Transport((self.hostname, self.port))
-                transport.connect(username=self.user, pkey=key)
+                logging.debug("Connecting with %s", path)
+                transport = self._connect_transport(key)
                 transport.set_keepalive(self.KEEPALIVE_TIMEOUT_SECONDS)
                 self.transport = transport
                 self.pid = os.getpid()
                 return
             except paramiko.AuthenticationException:
-                self.host_log.debug("Authentication failure")
+                logging.debug("Authentication failure")
         else:
             raise error.AutoservSshPermissionDeniedError(
                 "Permission denied using all keys available to ParamikoHost",
@@ -111,6 +140,9 @@
                 # and this just hangs on linux after a fork()
                 self.transport.join = lambda: None
                 self.transport.atfork()
+                join_hook = lambda cmd: self._close_transport()
+                subcommand.subcommand.register_join_hook(join_hook)
+                logging.debug("Reopening SSH connection after a process fork")
             self._init_transport()
 
         channel = None
@@ -131,12 +163,16 @@
             return channel
 
 
-    def close(self):
-        super(ParamikoHost, self).close()
+    def _close_transport(self):
         if os.getpid() == self.pid:
             self.transport.close()
 
 
+    def close(self):
+        super(ParamikoHost, self).close()
+        self._close_transport()
+
+
     @staticmethod
     def _exhaust_stream(tee, output_list, recvfunc):
         while True:
@@ -172,7 +208,7 @@
         # tee to std* if no tees are provided
         stdout = stdout_tee or abstract_ssh.LoggerFile()
         stderr = stderr_tee or abstract_ssh.LoggerFile()
-        self.host_log.debug("ssh-paramiko: %s" % command)
+        logging.debug("ssh-paramiko: %s" % command)
 
         # start up the command
         echo_cmd = "echo `date '+%m/%d/%y %H:%M:%S'` Connected. >&2"