Share SSH Master connection across control scripts.

BUG=chromium:726481
TEST=test_that locally. Ran trybots.

Change-Id: I39993f9179aa50690e987f5d2d34892cbe00ee44
Reviewed-on: https://chromium-review.googlesource.com/547077
Commit-Ready: Hidehiko Abe <hidehiko@chromium.org>
Tested-by: Hidehiko Abe <hidehiko@chromium.org>
Reviewed-by: Dan Shi <dshi@google.com>
diff --git a/server/autoserv b/server/autoserv
index 1bedcb1..90b899c 100755
--- a/server/autoserv
+++ b/server/autoserv
@@ -557,9 +557,7 @@
                         c['success'] = True
 
         finally:
-            while job.hosts:
-                host = job.hosts.pop()
-                host.close()
+            job.close()
     except:
         exit_code = 1
         traceback.print_exc()
diff --git a/server/hosts/abstract_ssh.py b/server/hosts/abstract_ssh.py
index c3a1a19..9d53514 100644
--- a/server/hosts/abstract_ssh.py
+++ b/server/hosts/abstract_ssh.py
@@ -28,7 +28,8 @@
 
     def _initialize(self, hostname, user="root", port=22, password="",
                     is_client_install_supported=True, afe_host=None,
-                    host_info_store=None, *args, **dargs):
+                    host_info_store=None, connection_pool=None,
+                    *args, **dargs):
         super(AbstractSSHHost, self)._initialize(hostname=hostname,
                                                  *args, **dargs)
         """
@@ -41,6 +42,8 @@
         @param afe_host: The host object attained from the AFE (get_hosts).
         @param host_info_store: Optional host_info.CachingHostInfoStore object
                 to obtain / update host information.
+        @param connection_pool: ssh_multiplex.ConnectionPool instance to share
+                the master ssh connection across control scripts.
         """
         # IP address is retrieved only on demand. Otherwise the host
         # initialization will fail for host is not online.
@@ -58,7 +61,11 @@
         control path option. If master-SSH is enabled, these fields will be
         initialized by start_master_ssh when a new SSH connection is initiated.
         """
-        self._master_ssh = ssh_multiplex.MasterSsh(hostname, user, port)
+        self._connection_pool = connection_pool
+        if connection_pool:
+            self._master_ssh = connection_pool.get(hostname, user, port)
+        else:
+            self._master_ssh = ssh_multiplex.MasterSsh(hostname, user, port)
 
         self._afe_host = afe_host or utils.EmptyAFEHost()
         self.host_info_store = (host_info_store or
@@ -726,7 +733,8 @@
     def close(self):
         super(AbstractSSHHost, self).close()
         self.rpc_server_tracker.disconnect_all()
-        self._master_ssh.close()
+        if not self._connection_pool:
+            self._master_ssh.close()
         if os.path.exists(self.known_hosts_file):
             os.remove(self.known_hosts_file)
 
diff --git a/server/hosts/factory.py b/server/hosts/factory.py
index bdde399..09a34a2 100644
--- a/server/hosts/factory.py
+++ b/server/hosts/factory.py
@@ -63,8 +63,8 @@
               afe_host, user, password, port, ssh_verbosity_flag and
               ssh_options.
     """
-    hostname, afe_host = server_utils.get_host_info_from_machine(
-            machine)
+    hostname, afe_host = server_utils.get_host_info_from_machine(machine)
+    connection_pool = server_utils.get_connection_pool_from_machine(machine)
     host_info_store = host_info.get_store_from_machine(machine)
     info = host_info_store.get()
 
@@ -92,6 +92,7 @@
             'port': int(port),
             'ssh_verbosity_flag': ssh_verbosity_flag,
             'ssh_options': ssh_options,
+            'connection_pool': connection_pool,
     }
     return host_args
 
diff --git a/server/hosts/ssh_multiplex.py b/server/hosts/ssh_multiplex.py
index 0852adc..4760f8c 100644
--- a/server/hosts/ssh_multiplex.py
+++ b/server/hosts/ssh_multiplex.py
@@ -123,3 +123,34 @@
             logging.debug('Cleaning ssh master_tempdir')
             self._master_tempdir.clean()
             self._master_tempdir = None
+
+
+class ConnectionPool(object):
+    """Holds SSH multiplex (MasterSsh) connections, one per endpoint."""
+
+    def __init__(self):
+        self._pool = {}
+
+    def get(self, hostname, user, port):
+        """Returns MasterSsh instance for the given endpoint.
+
+        If the pool already holds the instance, returns it. If not, creates
+        the instance and returns it.
+
+        Caller has the responsibility to call maybe_start() before using it.
+
+        @param hostname: Host name of the endpoint.
+        @param user: User name to log in as.
+        @param port: Port number sshd is listening on.
+        """
+        key = (hostname, user, port)
+        master_ssh = self._pool.get(key)
+        if not master_ssh:
+            master_ssh = MasterSsh(hostname, user, port)
+            self._pool[key] = master_ssh
+        return master_ssh
+
+    def shutdown(self):
+        """Closes all ssh multiplex connections."""
+        for ssh in self._pool.itervalues():
+            ssh.close()
diff --git a/server/server_job.py b/server/server_job.py
index 23359b5..2e2bd55 100644
--- a/server/server_job.py
+++ b/server/server_job.py
@@ -48,6 +48,7 @@
 from autotest_lib.server.hosts import afe_store
 from autotest_lib.server.hosts import factory as host_factory
 from autotest_lib.server.hosts import host_info
+from autotest_lib.server.hosts import ssh_multiplex
 from autotest_lib.tko import db as tko_db
 from autotest_lib.tko import models as tko_models
 from autotest_lib.tko import status_lib
@@ -81,19 +82,24 @@
 GET_NETWORK_STATS_CONTROL_FILE = _control_segment_path('get_network_stats')
 
 
-def get_machine_dicts(machine_names, in_lab, host_attributes=None):
+def get_machine_dicts(machine_names, in_lab, host_attributes=None,
+                      connection_pool=None):
     """Converts a list of machine names to list of dicts.
 
     @param machine_names: A list of machine names.
     @param in_lab: A boolean indicating whether we're running in lab.
     @param host_attributes: Optional list of host attributes to add for each
             host.
+    @param connection_pool: ssh_multiplex.ConnectionPool instance to share
+            master connections across control scripts.
     @returns: A list of dicts. Each dict has the following keys:
             'hostname': Name of the machine originally in machine_names (str).
             'afe_host': A frontend.Host object for the machine, or a stub if
                     in_lab is false.
             'host_info_store': A host_info.CachingHostInfoStore object to obtain
                     host information. A stub if in_lab is False.
+            'connection_pool': ssh_multiplex.ConnectionPool instance to share
+                    master ssh connection across control scripts.
     """
     machine_dict_list = []
     for machine in machine_names:
@@ -119,6 +125,7 @@
                 'hostname' : machine,
                 'afe_host' : afe_host,
                 'host_info_store': host_info_store,
+                'connection_pool': connection_pool,
         })
 
     return machine_dict_list
@@ -339,10 +346,13 @@
         # unexpected reboot.
         self.failed_with_device_error = False
 
+        self._connection_pool = ssh_multiplex.ConnectionPool()
+
         self.parent_job_id = parent_job_id
         self.in_lab = in_lab
         self.machine_dict_list = get_machine_dicts(
-                self.machines, self.in_lab, host_attributes)
+                self.machines, self.in_lab, host_attributes,
+                self._connection_pool)
 
         # TODO(jrbarnette) The harness attribute is only relevant to
         # client jobs, but it's required to be present, or we will fail
@@ -1499,6 +1509,16 @@
                 host.clear_known_hosts()
 
 
+    def close(self):
+        """Closes all hosts and shuts down the SSH connection pool."""
+
+        # Iterate over a copy: host.close() removes the host from self.hosts.
+        for host in list(self.hosts):
+            host.close()
+        assert not self.hosts
+        self._connection_pool.shutdown()
+
+
     def _get_job_data(self):
         """Add custom data to the job keyval info.
 
diff --git a/server/site_utils.py b/server/site_utils.py
index a99bbb6..4e8626f 100644
--- a/server/site_utils.py
+++ b/server/site_utils.py
@@ -725,6 +725,13 @@
     return afe_host
 
 
+def get_connection_pool_from_machine(machine):
+    """Returns machine's ssh_multiplex.ConnectionPool, or None if absent."""
+    if not isinstance(machine, dict):
+        return None
+    return machine.get('connection_pool')
+
+
 def get_creds_abspath(creds_file):
     """Returns the abspath of the credentials file.
 
@@ -839,6 +846,8 @@
 
     @param duts: List of duts to lock.
     @param afe: afe instance.
+    @param lock_msg: message for afe on locking this host.
+    @param max_wait: Max wait time in seconds.
 
     @returns Boolean lock_success where True if all duts locked successfully or
              False if we timed out waiting too long for hosts to go idle.