Add an extra parameter to the ssh tunnel

add the extra_args_ssh_tunnel args in the acloud.config
extra_args_ssh_tunnel: "-o ProxyCommand='ssh -W %h:%p firewall.example.org'"

The command will become to
ssh -o ProxyCommand='ssh -W %h:%p firewall.example.org' server2.example.org

Bug: 117625814
Test: atest acloud_test --host&
      acloud create (cloudtop)
      acloud reconnect (cloudtop)

Change-Id: I791568aa4829bb30567be38d6e08d75e37195f84
diff --git a/internal/lib/utils.py b/internal/lib/utils.py
index 5b085f4..ed9d343 100755
--- a/internal/lib/utils.py
+++ b/internal/lib/utils.py
@@ -25,6 +25,7 @@
 import logging
 import os
 import platform
+import shlex
 import shutil
 import signal
 import struct
@@ -830,7 +831,7 @@
 
 # pylint: disable=too-many-locals
 def AutoConnect(ip_addr, rsa_key_file, target_vnc_port, target_adb_port,
-                ssh_user, client_adb_port=None):
+                ssh_user, client_adb_port=None, extra_args_ssh_tunnel=None):
     """Autoconnect to an AVD instance.
 
     Args:
@@ -842,6 +843,7 @@
         target_adb_port: Integer of target adb port number.
         ssh_user: String of user login into the instance.
         client_adb_port: Integer, Specified adb port to establish connection.
+        extra_args_ssh_tunnel: String, extra args for ssh tunnel connection.
 
     Returns:
         NamedTuple of (vnc_port, adb_port) SSHTUNNEL of the connect, both are
@@ -858,10 +860,13 @@
             "target_adb_port": target_adb_port,
             "ssh_user": ssh_user,
             "ip_addr": ip_addr}
-        _ExecuteCommand(constants.SSH_BIN, ssh_tunnel_args.split())
-    except subprocess.CalledProcessError:
-        PrintColorString("Failed to create ssh tunnels, retry with '#acloud "
-                         "reconnect'.", TextColors.FAIL)
+        ssh_tunnel_args_list = shlex.split(ssh_tunnel_args)
+        if extra_args_ssh_tunnel:
+            ssh_tunnel_args_list.extend(shlex.split(extra_args_ssh_tunnel))
+        _ExecuteCommand(constants.SSH_BIN, ssh_tunnel_args_list)
+    except subprocess.CalledProcessError as e:
+        PrintColorString("\n%s\nFailed to create ssh tunnels, retry with '#acloud "
+                         "reconnect'." % e, TextColors.FAIL)
         return ForwardedPorts(vnc_port=None, adb_port=None)
 
     try:
diff --git a/internal/lib/utils_test.py b/internal/lib/utils_test.py
index 10d7a70..792ed8a 100644
--- a/internal/lib/utils_test.py
+++ b/internal/lib/utils_test.py
@@ -402,7 +402,7 @@
             self.fail("shouldn't timeout")
 
     def testAutoConnectCreateSSHTunnelFail(self):
-        """test auto connect."""
+        """Test auto connect."""
         fake_ip_addr = "1.1.1.1"
         fake_rsa_key_file = "/tmp/rsa_file"
         fake_target_vnc_port = 8888
@@ -418,6 +418,37 @@
                                                    target_adb_port,
                                                    ssh_user))
 
+    # pylint: disable=protected-access,no-member
+    def testExtraArgsSSHTunnel(self):
+        """Tesg extra args will be the same with expanded args."""
+        fake_ip_addr = "1.1.1.1"
+        fake_rsa_key_file = "/tmp/rsa_file"
+        fake_target_vnc_port = 8888
+        target_adb_port = 9999
+        ssh_user = "fake_user"
+        fake_port = 12345
+        self.Patch(utils, "PickFreePort", return_value=fake_port)
+        self.Patch(utils, "_ExecuteCommand")
+        self.Patch(subprocess, "check_call", return_value=True)
+        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
+        utils.AutoConnect(ip_addr=fake_ip_addr,
+                          rsa_key_file=fake_rsa_key_file,
+                          target_vnc_port=fake_target_vnc_port,
+                          target_adb_port=target_adb_port,
+                          ssh_user=ssh_user,
+                          client_adb_port=fake_port,
+                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
+        args_list = ["-i", "/tmp/rsa_file",
+                     "-o", "UserKnownHostsFile=/dev/null",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-L", "12345:127.0.0.1:8888",
+                     "-L", "12345:127.0.0.1:9999",
+                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
+                     "-o", "command=shell %s %h",
+                     "-o", "command1=ls -la"]
+        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
+        self.assertEqual(first_call_args[1], args_list)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/internal/proto/user_config.proto b/internal/proto/user_config.proto
index 131b42e..c3d618f 100755
--- a/internal/proto/user_config.proto
+++ b/internal/proto/user_config.proto
@@ -98,4 +98,7 @@
   // List of scopes that will be given to the instance
   // https://cloud.google.com/compute/docs/access/create-enable-service-accounts-for-instances#changeserviceaccountandscopes
   repeated string extra_scopes = 26;
+
+  // Provide some additional parameters to build the ssh tunnel.
+  optional string extra_args_ssh_tunnel = 27;
 }
diff --git a/public/actions/common_operations.py b/public/actions/common_operations.py
index 5a94c97..b2c5727 100644
--- a/public/actions/common_operations.py
+++ b/public/actions/common_operations.py
@@ -311,10 +311,13 @@
                 device_dict.update(device.build_info)
             if autoconnect:
                 forwarded_ports = utils.AutoConnect(
-                    ip, cfg.ssh_private_key_path,
-                    utils.AVD_PORT_DICT[avd_type].vnc_port,
-                    utils.AVD_PORT_DICT[avd_type].adb_port,
-                    getpass.getuser(), client_adb_port)
+                    ip_addr=ip,
+                    rsa_key_file=cfg.ssh_private_key_path,
+                    target_vnc_port=utils.AVD_PORT_DICT[avd_type].vnc_port,
+                    target_adb_port=utils.AVD_PORT_DICT[avd_type].adb_port,
+                    ssh_user=getpass.getuser(),
+                    client_adb_port=client_adb_port,
+                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
                 device_dict[constants.VNC_PORT] = forwarded_ports.vnc_port
                 device_dict[constants.ADB_PORT] = forwarded_ports.adb_port
             if device.instance_name in failures:
diff --git a/public/config.py b/public/config.py
index 78f4d1b..5a0566b 100755
--- a/public/config.py
+++ b/public/config.py
@@ -200,6 +200,8 @@
             usr_cfg.stable_cheeps_host_image_project or
             internal_cfg.default_usr_cfg.stable_cheeps_host_image_project)
 
+        self.extra_args_ssh_tunnel = usr_cfg.extra_args_ssh_tunnel
+
         self.common_hw_property_map = internal_cfg.common_hw_property_map
         self.hw_property = usr_cfg.hw_property
 
diff --git a/public/config_test.py b/public/config_test.py
index 272f6fc..3409242 100644
--- a/public/config_test.py
+++ b/public/config_test.py
@@ -44,6 +44,7 @@
 resolution: "1200x1200x1200x1200"
 client_id: "fake_client_id"
 client_secret: "fake_client_secret"
+extra_args_ssh_tunnel: "fake_extra_args_ssh_tunnel"
 metadata_variable {
     key: "metadata_1"
     value: "metadata_value_1"
@@ -132,6 +133,7 @@
         self.assertEqual(cfg.resolution, "1200x1200x1200x1200")
         self.assertEqual(cfg.client_id, "fake_client_id")
         self.assertEqual(cfg.client_secret, "fake_client_secret")
+        self.assertEqual(cfg.extra_args_ssh_tunnel, "fake_extra_args_ssh_tunnel")
         self.assertEqual(
             {key: val for key, val in cfg.metadata_variable.iteritems()},
             {"metadata_1": "metadata_value_1"})
diff --git a/public/device_driver.py b/public/device_driver.py
index d97e313..0de988d 100755
--- a/public/device_driver.py
+++ b/public/device_driver.py
@@ -405,8 +405,13 @@
             }
             if autoconnect:
                 forwarded_ports = utils.AutoConnect(
-                    ip, cfg.ssh_private_key_path, constants.GCE_VNC_PORT,
-                    constants.GCE_ADB_PORT, _SSH_USER, avd_spec.client_adb_port)
+                    ip_addr=ip,
+                    rsa_key_file=cfg.ssh_private_key_path,
+                    target_vnc_port=constants.GCE_VNC_PORT,
+                    target_adb_port=constants.GCE_ADB_PORT,
+                    ssh_user=_SSH_USER,
+                    client_adb_port=avd_spec.adb_port,
+                    extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
                 device_dict[constants.VNC_PORT] = forwarded_ports.vnc_port
                 device_dict[constants.ADB_PORT] = forwarded_ports.adb_port
             if device.instance_name in failures:
diff --git a/reconnect/reconnect.py b/reconnect/reconnect.py
index f8a434a..52f41e6 100644
--- a/reconnect/reconnect.py
+++ b/reconnect/reconnect.py
@@ -84,7 +84,10 @@
 
 
 @utils.TimeExecute(function_description="Reconnect instances")
-def ReconnectInstance(ssh_private_key_path, instance, reconnect_report):
+def ReconnectInstance(ssh_private_key_path,
+                      instance,
+                      reconnect_report,
+                      extra_args_ssh_tunnel=None):
     """Reconnect to the specified instance.
 
     It will:
@@ -98,6 +101,7 @@
                               e.g. ~/.ssh/acloud_rsa
         instance: list.Instance() object.
         reconnect_report: Report object.
+        extra_args_ssh_tunnel: String, extra args for ssh tunnel connection.
 
     Raises:
         errors.UnknownAvdType: Unable to reconnect to instance of unknown avd
@@ -119,11 +123,12 @@
     elif not instance.ssh_tunnel_is_connected and not instance.islocal:
         adb_cmd.DisconnectAdb()
         forwarded_ports = utils.AutoConnect(
-            instance.ip,
-            ssh_private_key_path,
-            utils.AVD_PORT_DICT[instance.avd_type].vnc_port,
-            utils.AVD_PORT_DICT[instance.avd_type].adb_port,
-            getpass.getuser())
+            ip_addr=instance.ip,
+            rsa_key_file=ssh_private_key_path,
+            target_vnc_port=utils.AVD_PORT_DICT[instance.avd_type].vnc_port,
+            target_adb_port=utils.AVD_PORT_DICT[instance.avd_type].adb_port,
+            ssh_user=getpass.getuser(),
+            extra_args_ssh_tunnel=extra_args_ssh_tunnel)
         vnc_port = forwarded_ports.vnc_port
         adb_port = forwarded_ports.adb_port
 
@@ -172,6 +177,9 @@
             continue
         if not instance.islocal:
             AddPublicSshRsaToInstance(cfg, getpass.getuser(), instance.name)
-        ReconnectInstance(cfg.ssh_private_key_path, instance, reconnect_report)
+        ReconnectInstance(cfg.ssh_private_key_path,
+                          instance,
+                          reconnect_report,
+                          cfg.extra_args_ssh_tunnel)
 
     utils.PrintDeviceSummary(reconnect_report)
diff --git a/reconnect/reconnect_test.py b/reconnect/reconnect_test.py
index 4673909..60522e6 100644
--- a/reconnect/reconnect_test.py
+++ b/reconnect/reconnect_test.py
@@ -72,24 +72,31 @@
         instance_object.display = ""
         utils.AutoConnect.call_count = 0
         instance_object.forwarding_vnc_port = 5555
+        extra_args_ssh_tunnel = None
         self.Patch(utils, "AutoConnect",
                    return_value=ForwardedPorts(vnc_port=11111, adb_port=22222))
         reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report)
-        utils.AutoConnect.assert_called_with(instance_object.ip,
-                                             ssh_private_key_path,
-                                             constants.CF_VNC_PORT,
-                                             constants.CF_ADB_PORT,
-                                             "fake_user")
+        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
+                                             rsa_key_file=ssh_private_key_path,
+                                             target_vnc_port=constants.CF_VNC_PORT,
+                                             target_adb_port=constants.CF_ADB_PORT,
+                                             ssh_user="fake_user",
+                                             extra_args_ssh_tunnel=extra_args_ssh_tunnel)
         utils.LaunchVncClient.assert_called_with(11111)
 
         instance_object.display = "999x777 (99)"
+        extra_args_ssh_tunnel = "fake_extra_args_ssh_tunnel"
         utils.AutoConnect.call_count = 0
-        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report)
-        utils.AutoConnect.assert_called_with(instance_object.ip,
-                                             ssh_private_key_path,
-                                             constants.CF_VNC_PORT,
-                                             constants.CF_ADB_PORT,
-                                             "fake_user")
+        reconnect.ReconnectInstance(ssh_private_key_path,
+                                    instance_object,
+                                    fake_report,
+                                    extra_args_ssh_tunnel)
+        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
+                                             rsa_key_file=ssh_private_key_path,
+                                             target_vnc_port=constants.CF_VNC_PORT,
+                                             target_adb_port=constants.CF_ADB_PORT,
+                                             ssh_user="fake_user",
+                                             extra_args_ssh_tunnel=extra_args_ssh_tunnel)
         utils.LaunchVncClient.assert_called_with(11111, "999", "777")
 
         #test reconnect local instance.
@@ -98,7 +105,9 @@
         instance_object.forwarding_vnc_port = 5555
         instance_object.ssh_tunnel_is_connected = False
         utils.AutoConnect.call_count = 0
-        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report)
+        reconnect.ReconnectInstance(ssh_private_key_path,
+                                    instance_object,
+                                    fake_report)
         utils.AutoConnect.assert_not_called()
         utils.LaunchVncClient.assert_called_with(5555)
 
@@ -115,24 +124,25 @@
         self.Patch(getpass, "getuser", return_value="fake_user")
         self.Patch(utils, "AutoConnect")
         self.Patch(reconnect, "StartVnc")
-
         #test reconnect remote instance when avd_type as gce.
         instance_object.avd_type = "gce"
         reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report)
-        utils.AutoConnect.assert_called_with(instance_object.ip,
-                                             ssh_private_key_path,
-                                             constants.GCE_VNC_PORT,
-                                             constants.GCE_ADB_PORT,
-                                             "fake_user")
+        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
+                                             rsa_key_file=ssh_private_key_path,
+                                             target_vnc_port=constants.GCE_VNC_PORT,
+                                             target_adb_port=constants.GCE_ADB_PORT,
+                                             ssh_user="fake_user",
+                                             extra_args_ssh_tunnel=None)
 
         #test reconnect remote instance when avd_type as cuttlefish.
         instance_object.avd_type = "cuttlefish"
         reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report)
-        utils.AutoConnect.assert_called_with(instance_object.ip,
-                                             ssh_private_key_path,
-                                             constants.CF_VNC_PORT,
-                                             constants.CF_ADB_PORT,
-                                             "fake_user")
+        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
+                                             rsa_key_file=ssh_private_key_path,
+                                             target_vnc_port=constants.CF_VNC_PORT,
+                                             target_adb_port=constants.CF_ADB_PORT,
+                                             ssh_user="fake_user",
+                                             extra_args_ssh_tunnel=None)
 
 
     def testReconnectInstanceUnknownAvdType(self):