| #!/usr/bin/python2 |
| # Copyright 2016 The Chromium OS Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| import unittest |
| |
| import common |
| from autotest_lib.client.common_lib import error |
| from autotest_lib.server.hosts import base_label_unittest, factory |
| from autotest_lib.server.hosts import host_info |
| |
| |
class MockHost(object):
    """Side-effect-free stand-in for a real host class.

    Every keyword argument given to the constructor, together with the
    hostname, is recorded in _init_args so tests can inspect exactly what
    the factory passed through.
    """
    def __init__(self, hostname, **args):
        # Merge the hostname into a dict of the keyword arguments.
        self._init_args = dict(args, hostname=hostname)


    def job_start(self):
        """Do nothing; this is the only method the factory calls."""
| |
| |
class MockConnectivity(object):
    """Side-effect-free stand-in for a connectivity (transport) class.

    The constructor and every method accept their arguments and ignore them.
    """
    def __init__(self, hostname, **args):
        """Accept and discard the hostname and any keyword arguments."""

    def run(self, *args, **kwargs):
        """Pretend to run a command; always returns None."""

    def close(self):
        """Pretend to close the connection; always returns None."""
| |
| |
def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: identifier exposed on the class as _host_cls_name and used
                 to build the class name 'mock_host_<name>'.
    @param check_host: fixed value returned by the class's static
                       check_host(host, timeout=None) method.

    @return: a new subclass of MockHost.
    """
    def _fixed_check_host(host, timeout=None):
        # Ignore the arguments and report the preconfigured answer.
        return check_host

    namespace = {
        '_host_cls_name': name,
        'check_host': staticmethod(_fixed_check_host),
    }
    return type('mock_host_%s' % name, (MockHost,), namespace)
| |
| |
def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: identifier exposed on the class as _conn_cls_name and used
                 to build the class name 'mock_conn_<name>'.

    @return: a new subclass of MockConnectivity.
    """
    class_name = 'mock_conn_%s' % name
    bases = (MockConnectivity,)
    return type(class_name, bases, {'_conn_cls_name': name})
| |
| |
def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; defaults to a fresh empty list.
    @param attributes: dict of host attributes; defaults to a fresh empty
                       dict.

    @return: machine dict with a mocked AFE Host object and an in-memory
             host info store holding the given labels and attributes.
    """
    # Use None sentinels so each call gets its own containers; a shared
    # mutable default would leak any mutation from one test into the next.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    store = host_info.InMemoryHostInfoStore()
    store.commit(host_info.HostInfo(labels, attributes))
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': store}
| |
| |
class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.

        Saves the factory module's host_types, OS_HOST_DICT and the real
        CrosHost/LocalHost/SSHHost classes, then swaps in empty registries
        and mock classes. tearDown restores the saved originals.
        """
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Keep handles on the replacement registries so tests can populate
        # them directly.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks by restoring everything saved in setUp."""
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host class is used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                         attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        ssh_globals = {
            'ssh_user': 'foo',
            'ssh_pass': 'bar',
            'ssh_port': 1,
            'ssh_verbosity_flag': 'baz',
            'ssh_options': 'zip',
        }
        # Remember any pre-existing module globals so cleanup restores them
        # instead of unconditionally deleting them.
        missing = object()
        saved = dict((name, getattr(factory, name, missing))
                     for name in ssh_globals)
        try:
            for name, value in ssh_globals.items():
                setattr(factory, name, value)
            machine = _gen_machine_dict()
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
        finally:
            for name, value in saved.items():
                if value is missing:
                    if hasattr(factory, name):
                        delattr(factory, name)
                else:
                    setattr(factory, name, value)


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')
| |
| |
if __name__ == '__main__':
    # Run this module's TestCases when the file is executed directly.
    unittest.main()
| |