| #!/usr/bin/python |
| #pylint: disable-msg=C0111 |
| |
| import datetime |
| import unittest |
| |
| import common |
| from autotest_lib.frontend import setup_django_environment |
| from autotest_lib.frontend.afe import frontend_test_utils |
| from autotest_lib.client.common_lib import host_queue_entry_states |
| from autotest_lib.database import database_connection |
| from autotest_lib.frontend.afe import models, model_attributes |
| from autotest_lib.scheduler import monitor_db |
| from autotest_lib.scheduler import scheduler_lib |
| from autotest_lib.scheduler import scheduler_models |
| |
# Debug switch for the test database connection: the value is copied onto
# `self._database.debug` in BaseSchedulerModelsTest._set_monitor_stubs().
_DEBUG = False
| |
| |
class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Shared fixture for the scheduler_models test cases in this file.

    setUp() runs the frontend mixin's common setup and then stubs
    scheduler_models._db with a test database connection; tearDown()
    disconnects that connection and runs the mixin's common teardown.
    """

    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement on the test database connection.

        @param sql: SQL statement string handed to the connection's execute().
        """
        self._database.execute(sql)


    def _set_monitor_stubs(self):
        """Connect the test database and stub it in as scheduler_models._db."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        self.god.stub_with(scheduler_models, '_db', self._database)


    def setUp(self):
        """Run the frontend common setup, then install the database stub."""
        self._frontend_common_setup()
        self._set_monitor_stubs()


    def tearDown(self):
        """Disconnect the test database and run the frontend teardown."""
        # The mixin's teardown must not be skipped when disconnecting fails.
        try:
            self._database.disconnect()
        finally:
            self._frontend_common_teardown()


    def _update_hqe(self, set, where=''):
        """Run an UPDATE statement against afe_host_queue_entries.

        @param set: SQL text placed after SET. The name shadows the builtin
                but is kept so that keyword callers keep working.
        @param where: optional SQL text for the WHERE clause; when empty the
                statement is issued without a WHERE clause.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)
| |
| |
class DBObjectTest(BaseSchedulerModelsTest):
    """Tests of DBObject behaviour, exercised through its subclasses."""

    def test_compare_fields_in_row(self):
        """Differences are reported as {field: (object value, row value)}."""
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))


    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        """Datetimes differing only in microseconds compare as equal."""
        # Plain decimal literals: a leading-zero form such as `07` is an
        # octal literal on Python 2 and a syntax error on Python 3.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)

        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')

        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))


    def test_always_query(self):
        """always_query=True refreshes the cached instance; False does not."""
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertTrue(host_a is host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertTrue(host_a is host_c, 'Cached instance not returned')


    def test_delete(self):
        """A deleted host can no longer be constructed by id."""
        host = scheduler_models.Host(id=3)
        host.delete()
        # The lookup must fail both with and without always_query.
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)


    def test_save(self):
        """A new HostQueueEntry saved to the database reads back unchanged."""
        # Dummy Job to avoid creating a real one in HostQueueEntry.__init__.
        class MockJob(object):
            def __init__(self, id, row):
                pass

            def tag(self):
                return 'MockJob'

        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)
| |
| |
class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host label handling and host ordering."""

    def setUp(self):
        """Remember RESPECT_STATIC_LABELS so that tearDown can restore it."""
        super(HostTest, self).setUp()
        self.old_config = scheduler_models.RESPECT_STATIC_LABELS


    def tearDown(self):
        """Run the base teardown and restore RESPECT_STATIC_LABELS."""
        # Restore the module global even if the base teardown raises, so the
        # value set by one test is not left in place.
        try:
            super(HostTest, self).tearDown()
        finally:
            scheduler_models.RESPECT_STATIC_LABELS = self.old_config


    def _setup_static_labels(self):
        """Create 'test_host' with two regular and two static labels.

        Regular labels: 'non_static_label' and 'static_platform'
        (platform=False, also recorded as a ReplacedLabel).
        Static labels: 'no_reference_label' (platform=False) and
        'static_platform' (platform=True).

        @return the created models.Host.
        """
        label1 = models.Label.objects.create(name='non_static_label')
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)

        static_label1 = models.StaticLabel.objects.create(
                name='no_reference_label', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name=non_static_platform.name, platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label1)
        host1.labels.add(non_static_platform)
        host1.static_labels.add(static_label1)
        host1.static_labels.add(static_platform)
        host1.save()
        return host1


    def test_platform_and_labels_with_respect(self):
        """With RESPECT_STATIC_LABELS on, the static platform is reported."""
        scheduler_models.RESPECT_STATIC_LABELS = True
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertEqual(platform, 'static_platform')
        self.assertNotIn('no_reference_label', all_labels)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_platform_and_labels_without_respect(self):
        """With RESPECT_STATIC_LABELS off, no platform is reported."""
        scheduler_models.RESPECT_STATIC_LABELS = False
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertIsNone(platform)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_cmp_for_sort(self):
        """Host.cmp_for_sort yields the ordering listed in expected_order."""
        # Imported locally so this test does not depend on a module-level
        # import being present.
        import functools

        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]
        compare = scheduler_models.Host.cmp_for_sort
        self.assertEqual(0, compare(host10, host10))
        self.assertEqual(1, compare(host10, host010))
        self.assertEqual(-1, compare(host010, host10))
        self.assertEqual(-1, compare(host1, host10))
        self.assertEqual(-1, compare(host1, host010))
        self.assertEqual(-1, compare(host3, host10))
        self.assertEqual(-1, compare(host3, host010))
        self.assertEqual(1, compare(host3, host1))
        self.assertEqual(-1, compare(host1, host3))
        self.assertEqual(-1, compare(alice, host3))
        self.assertEqual(1, compare(host3, alice))
        self.assertEqual(0, compare(alice, alice))

        # list.sort(cmp=...) exists only on Python 2; cmp_to_key (2.7+) wraps
        # the same comparison function as a sort key.
        sort_key = functools.cmp_to_key(compare)

        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        # Sorting from the reversed order must give the same result.
        hosts.reverse()
        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])
| |
| |
class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a job and return its single scheduler HostQueueEntry.

        @param dependency_labels: labels added to the job's dependency_labels.
        @param create_job_kwargs: keyword arguments passed to _create_job().
        @return the one HostQueueEntry fetched for the new job; the helper
                asserts that exactly one exists.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]


    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert that hqe.get_labels() yields exactly the expected names.

        @param hqe: the HostQueueEntry to inspect.
        @param expected_labels: iterable of label names; order is ignored.
        """
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)


    def test_get_labels_empty(self):
        """An entry created with hosts=[1] has no labels."""
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)


    def test_get_labels_metahost(self):
        """An entry created with metahosts=[2] reports 'label2'."""
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])


    def test_get_labels_dependencies(self):
        """Dependency labels are reported alongside the metahost label."""
        hqe = self._create_hqe(dependency_labels=(self.label3,),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3'])


    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: the HostQueueEntry under test (aborted, not
                                 complete, status STARTING) and a mock
                                 dispatcher that expects one
                                 get_agents_for_entry(hqe) call.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.Dispatcher,
                                                'Dispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher


    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)


    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)


    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            # Name the status in failure messages; otherwise a failure inside
            # this loop does not say which status triggered it.
            context = 'status=%r' % (status,)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                self.assertIsNotNone(hqe.finished_on, context)
                self.assertIsNone(hqe.job.shard_id, context)
            else:
                self.assertIsNone(hqe.finished_on, context)
                self.assertEqual(hqe.job.shard_id, 3, context)
| |
| |
class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job and the pre-job special tasks."""

    def setUp(self):
        """Stub SpecialTask.objects.create to record each created task."""
        super(JobTest, self).setUp()
        # Initialised here so that the stub below is safe to call even before
        # _test_pre_job_tasks_helper() resets the list.
        self._tasks = []

        def _mock_create(**kwargs):
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)
        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)


    def _test_pre_job_tasks_helper(self,
            reboot_before=model_attributes.RebootBefore.ALWAYS):
        """
        Calls HQE._do_schedule_pre_job_tasks() on the queue entry with id 1
        and returns the created special tasks.

        @param reboot_before: value assigned to the entry's job.reboot_before
                before scheduling.
        @return list of the SpecialTask objects recorded by the create stub.
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks


    def test_job_request_abort(self):
        """request_abort() marks every queue entry of the job as aborted."""
        django_job = self._create_job(hosts=[5, 6])
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)


    def _check_special_tasks(self, tasks, task_types):
        """Assert that tasks match the expected (type, queue entry id) pairs.

        @param tasks: list of created SpecialTask objects.
        @param task_types: list of (task_type, queue_entry_id) tuples, in the
                same order as tasks; a falsy queue_entry_id skips the queue
                entry check. Every task is expected on host id 1.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)


    def test_run_asynchronous(self):
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_verify(self):
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_do_not_reset(self):
        # NOTE(review): a second method of the same name used to follow in
        # this class and replaced this one, so this test never executed.
        # The later method has been renamed; confirm that the expectation
        # below (no tasks) still matches the scheduler's behaviour.
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])


    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        # Renamed from test_run_asynchronous_do_not_reset, which duplicated
        # the name of the test above; the new name mirrors the synchronous
        # variant.
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_reboot_before_always(self):
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
        ])


    def _test_reboot_before_if_dirty_helper(self):
        """Create an IF_DIRTY job on host 1 and expect a single RESET task."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)


    def test_reboot_before_if_dirty(self):
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()


    def test_reboot_before_not_dirty(self):
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()
| |
| |
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()