blob: 3289bcdf22f004e9279b0a346a86fe147e513115 [file] [log] [blame]
#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
5
6import unittest
7
8import common
9from autotest_lib.client.common_lib import error
10from autotest_lib.server.hosts import base_label_unittest, factory
11from autotest_lib.server.hosts import paramiko_host
12
13
class MockHost(object):
    """Stand-in host class that only records how it was constructed.

    The constructor keeps every keyword argument, plus the hostname, in
    `_init_args` so tests can inspect what was passed in.
    """

    def __init__(self, hostname, **args):
        # **args is already a fresh dict; fold the hostname into a copy of it
        # under the 'hostname' key.
        self._init_args = dict(args, hostname=hostname)


    def job_start(self):
        """No-op; per the original note, the only method the factory calls."""
        return None
24
25
class MockConnectivity(object):
    """Stand-in connectivity class that accepts and discards its arguments."""

    def __init__(self, hostname, **args):
        """Accept the constructor arguments and store nothing."""


    def close(self):
        """No-op; per the original note, the only method the factory calls."""
        return None
35
36
def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: tag stored on the new class as `_host_cls_name` and embedded
                 in the generated class name.
    @param check_host: value the generated class's static check_host()
                       returns, whatever arguments it is given.

    @return: a new subclass of MockHost.
    """
    class_attrs = {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host),
    }
    cls_name = 'mock_host_%s' % name
    return type(cls_name, (MockHost,), class_attrs)
44
45
def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: tag stored on the new class as `_conn_cls_name` and embedded
                 in the generated class name.

    @return: a new subclass of MockConnectivity.
    """
    class_attrs = {'_conn_cls_name': name}
    cls_name = 'mock_conn_%s' % name
    return type(cls_name, (MockConnectivity,), class_attrs)
51
52
def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; a new empty list when omitted
    @param attributes: dict of host attributes; a new empty dict when omitted

    @return: machine dict with mocked AFE Host object.
    """
    # Build fresh containers per call. Mutable defaults in the signature would
    # be shared by every MockAFEHost created without explicit arguments, so
    # any mutation of one mock's labels/attributes could leak into later tests.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    return {'hostname': hostname, 'afe_host': afe_host}
64
65
class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.

        The original factory settings and classes are saved on the test
        instance so tearDown can put them back.
        """
        self._orig_ssh_engine = factory.SSH_ENGINE
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost
        self._orig_paramiko_host = paramiko_host.ParamikoHost

        # The test keeps a reference to the very same (initially empty)
        # containers installed on the factory, so individual tests can
        # register mock host types through self.host_types/self.os_host_dict.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')
        paramiko_host.ParamikoHost = _gen_mock_conn('paramiko')


    def tearDown(self):
        """Clean up mocks by restoring everything saved in setUp."""
        factory.SSH_ENGINE = self._orig_ssh_engine
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host
        paramiko_host.ParamikoHost = self._orig_paramiko_host


    def _gen_machine_dict(self, hostname='localhost', labels=None,
                          attributes=None):
        """Generate a machine dictionary with the specified parameters.

        @param hostname: hostname of machine
        @param labels: list of host labels; a new empty list when omitted
        @param attributes: dict of host attributes; a new empty dict when
                           omitted

        @return: machine dict with mocked AFE Host object.
        """
        # Fresh containers per call: mutable defaults in the signature would
        # be shared between every mock AFE host built without arguments.
        if labels is None:
            labels = []
        if attributes is None:
            attributes = {}
        afe_host = base_label_unittest.MockAFEHost(labels, attributes)
        return {'hostname': hostname, 'afe_host': afe_host}


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        # Only the middle candidate reports a positive check_host result.
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_paramiko(self):
        """Confirm paramiko connectivity class used when configured and
        hostname is not localhost.
        """
        factory.SSH_ENGINE = 'paramiko'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'paramiko')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        # assertIn reports the missing key and the container on failure.
        self.assertIn('afe_host', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        factory.ssh_user = 'foo'
        factory.ssh_pass = 'bar'
        factory.ssh_port = 1
        factory.ssh_verbosity_flag = 'baz'
        factory.ssh_options = 'zip'
        machine = _gen_machine_dict()
        try:
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
        finally:
            # Remove the injected module attributes even when an assertion
            # fails, so they cannot influence other tests.
            del factory.ssh_user
            del factory.ssh_pass
            del factory.ssh_port
            del factory.ssh_verbosity_flag
            del factory.ssh_options


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')
250
251
class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.

        The real class is saved on the test instance so tearDown can restore
        it.
        """
        self._orig_testbed = factory.testbed.TestBed
        factory.testbed.TestBed = _gen_mock_host('testbed')


    def tearDown(self):
        """Clean up mock by restoring the TestBed class saved in setUp.
        """
        factory.testbed.TestBed = self._orig_testbed


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        # assertIn reports the missing key and the container on failure.
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        factory.ssh_user = 'foo'
        factory.ssh_pass = 'bar'
        factory.ssh_port = 1
        factory.ssh_verbosity_flag = 'baz'
        factory.ssh_options = 'zip'
        machine = _gen_machine_dict()
        try:
            testbed_obj = factory.create_testbed(machine)
            self.assertEqual(testbed_obj._init_args['user'], 'foo')
            self.assertEqual(testbed_obj._init_args['password'], 'bar')
            self.assertEqual(testbed_obj._init_args['port'], 1)
            self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
                             'baz')
            self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')
        finally:
            # Remove the injected module attributes even when an assertion
            # fails, so they cannot influence other tests.
            del factory.ssh_user
            del factory.ssh_pass
            del factory.ssh_port
            del factory.ssh_verbosity_flag
            del factory.ssh_options


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')
316
317
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
320