Get the scheduler unittest to run against SQLite!

* get rid of monitor_db.DatabaseConn, and make monitor_db use the new DatabaseConnection
* modify some queries in monitor_db that weren't SQLite-compatible (SQLite doesn't support TRUE and FALSE literals)
* add frontend/django_test_utils.py, which contains utilities to
 * setup a django environment (something manage.py normally does for you)
 * replace the configured DB with a SQLite one, either in-memory or on disk
 * run syncdb on the test DB
 * backup and restore the test DB, handy because then we can syncdb once, save the fresh DB, and quickly restore it between unittests without having to run syncdb again (syncdb is terribly slow for whatever reason)
* modify monitor_db_unittest to use these methods to set up a temporary SQLite DB, run syncdb on it, and test against it
* replace much of the data modification code in monitor_db_unittest with use of the django models.  The INSERTs were very problematic with SQLite because syncdb doesn't set database defaults, but using the models solves that (django inserts the defaults itself).  Using the models is much cleaner anyway, as you can see.  It was just difficult to do before, but now that we've got the infrastructure to set up the environment anyway, it's easy.  This is a good model for how we can make the scheduler use the django models eventually.
* reorder fields of Label model to match actual DB ordering; this is necessary since monitor_db depends on field ordering
* add defaults to some fields in AFE models that should've had them
* make DatabaseConnection.get_test_database support SQLite in files, which gives us persistence that is necessary and handy in the scheduler unittest
* add a fix to _SqliteBackend for pysqlite2 crappiness

The following are extras that weren't strictly necessary to get things working:
* add a debug feature to DatabaseConnection to print all queries
* add an execute_script method to DatabaseConnection (it was duplicated in migrate and monitor_db_unittest)
* rename "arguments" to "parameters" in _GenericBackend.execute, to match the DB-API names
* get rid of some debug code that was left in monitor_db, and one unnecessary statement

Signed-off-by: Steve Howard <showard@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@2252 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/database/database_connection.py b/database/database_connection.py
index 6d61728..fa870c9 100644
--- a/database/database_connection.py
+++ b/database/database_connection.py
@@ -38,8 +38,8 @@
         self._cursor = None
 
 
-    def execute(self, query, arguments=None):
-        self._cursor.execute(query, arguments)
+    def execute(self, query, parameters=None):
+        self._cursor.execute(query, parameters)
         self.rowcount = self._cursor.rowcount
         return self._cursor.fetchall()
 
@@ -88,11 +88,15 @@
         self._cursor = self._connection.cursor()
 
 
-    def execute(self, query, arguments=None):
+    def execute(self, query, parameters=None):
         # pysqlite2 uses paramstyle=qmark
         # TODO: make this more sophisticated if necessary
         query = query.replace('%s', '?')
-        return super(_SqliteBackend, self).execute(query, arguments)
+        # pysqlite2 can't handle parameters=None (it throws a nonsense
+        # exception)
+        if parameters is None:
+            parameters = ()
+        return super(_SqliteBackend, self).execute(query, parameters)
 
 
 _BACKEND_MAP = {
@@ -116,6 +120,7 @@
     * global_config_section - the section in which to find DB information. this
       should be passed to the constructor, not set later, and may be None, in
       which case information must be passed to connect().
+    * debug - if set True, all queries will be printed before being executed
     """
     _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                             'db_name')
@@ -124,6 +129,7 @@
         self.global_config_section = global_config_section
         self._backend = None
         self.rowcount = None
+        self.debug = False
 
         # reconnect defaults
         self.reconnect_enabled = True
@@ -217,6 +223,8 @@
         Execute a query and return cursor.fetchall(). try_reconnecting, if
         passed, will override self.reconnect_enabled.
         """
+        if self.debug:
+            print 'Executing %s, %s' % (query, parameters)
         # _connect_backend() contains a retry loop, so don't loop here
         try:
             results = self._backend.execute(query, parameters)
@@ -239,12 +247,12 @@
 
 
     @classmethod
-    def get_test_database(cls):
+    def get_test_database(cls, file_path=':memory:'):
         """
         Factory method returning a DatabaseConnection for a temporary in-memory
         database.
         """
         database = cls()
         database.reconnect_enabled = False
-        database.connect(db_type='sqlite', db_name=':memory:')
+        database.connect(db_type='sqlite', db_name=file_path)
         return database
diff --git a/database/migrate.py b/database/migrate.py
index 37c7857..0d57c8b 100644
--- a/database/migrate.py
+++ b/database/migrate.py
@@ -59,16 +59,15 @@
 
 
     def execute(self, query, *parameters):
-        #print 'SQL:', query % parameters
         return self._database.execute(query, parameters)
 
 
     def execute_script(self, script):
-        sql_statements = [statement.strip() for statement
-                          in script.split(';')]
+        sql_statements = [statement.strip()
+                          for statement in script.split(';')
+                          if statement.strip()]
         for statement in sql_statements:
-            if statement:
-                self.execute(statement)
+            self.execute(statement)
 
 
     def check_migrate_table_exists(self):
diff --git a/frontend/afe/models.py b/frontend/afe/models.py
index d578fd5..4f94a5c 100644
--- a/frontend/afe/models.py
+++ b/frontend/afe/models.py
@@ -2,7 +2,7 @@
 from django.db import models as dbmodels, connection
 from frontend.afe import model_logic
 from frontend import settings, thread_local
-from autotest_lib.client.common_lib import enum, host_protections
+from autotest_lib.client.common_lib import enum, host_protections, global_config
 
 
 class AclAccessViolation(Exception):
@@ -27,9 +27,9 @@
     name = dbmodels.CharField(maxlength=255, unique=True)
     kernel_config = dbmodels.CharField(maxlength=255, blank=True)
     platform = dbmodels.BooleanField(default=False)
-    only_if_needed = dbmodels.BooleanField(default=False)
     invalid = dbmodels.BooleanField(default=False,
                                     editable=settings.FULL_ADMIN)
+    only_if_needed = dbmodels.BooleanField(default=False)
 
     name_field = 'name'
     objects = model_logic.ExtendedManager()
@@ -520,6 +520,9 @@
     dependency_labels: many-to-many relationship with labels corresponding to
                        job dependencies
     """
+    DEFAULT_TIMEOUT = global_config.global_config.get_config_value(
+        'AUTOTEST_WEB', 'job_timeout_default', default=240)
+
     Priority = enum.Enum('Low', 'Medium', 'High', 'Urgent')
     ControlType = enum.Enum('Server', 'Client', start_value=1)
     Status = enum.Enum('Created', 'Queued', 'Pending', 'Running',
@@ -533,14 +536,15 @@
                                           default=Priority.MEDIUM)
     control_file = dbmodels.TextField()
     control_type = dbmodels.SmallIntegerField(choices=ControlType.choices(),
-                                              blank=True) # to allow 0
+                                              blank=True, # to allow 0
+                                              default=ControlType.CLIENT)
     created_on = dbmodels.DateTimeField(auto_now_add=True)
     synch_type = dbmodels.SmallIntegerField(
         blank=True, null=True, choices=Test.SynchType.choices())
     synch_count = dbmodels.IntegerField(blank=True, null=True)
     synchronizing = dbmodels.BooleanField(default=False)
     run_verify = dbmodels.BooleanField(default=True)
-    timeout = dbmodels.IntegerField()
+    timeout = dbmodels.IntegerField(default=DEFAULT_TIMEOUT)
     email_list = dbmodels.CharField(maxlength=250, blank=True)
     dependency_labels = dbmodels.ManyToManyField(
         Label, blank=True, filter_interface=dbmodels.HORIZONTAL)
diff --git a/frontend/afe/rpc_interface.py b/frontend/afe/rpc_interface.py
index 17fcf70..41048ab 100644
--- a/frontend/afe/rpc_interface.py
+++ b/frontend/afe/rpc_interface.py
@@ -546,8 +546,7 @@
     result['user_login'] = thread_local.get_user().login
     result['host_statuses'] = sorted(models.Host.Status.names)
     result['job_statuses'] = sorted(models.Job.Status.names)
-    result['job_timeout_default'] = global_config.global_config.get_config_value(
-        'AUTOTEST_WEB', 'job_timeout_default')
+    result['job_timeout_default'] = models.Job.DEFAULT_TIMEOUT
 
     result['status_dictionary'] = {"Abort": "Abort",
                                    "Aborted": "Aborted",
diff --git a/frontend/django_test_utils.py b/frontend/django_test_utils.py
new file mode 100644
index 0000000..89ec7d9
--- /dev/null
+++ b/frontend/django_test_utils.py
@@ -0,0 +1,42 @@
+import tempfile, shutil, os
+from django.core import management
+from django.conf import settings
+# can't import any django code that depends on the environment setup
+import common
+
+def setup_test_environ():
+    from autotest_lib.frontend import settings
+    management.setup_environ(settings)
+    from django.conf import settings
+    # django.conf.settings.LazySettings is buggy and requires us to get
+    # something from it before we set stuff on it
+    getattr(settings, 'DATABASE_ENGINE')
+    settings.DATABASE_ENGINE = 'sqlite3'
+    settings.DATABASE_NAME = ':memory:'
+
+
+def set_test_database(database):
+    from django.db import connection
+    settings.DATABASE_NAME = database
+    connection.close()
+
+
+def backup_test_database():
+    temp_fd, backup_path = tempfile.mkstemp(suffix='.test_db_backup')
+    os.close(temp_fd)
+    shutil.copyfile(settings.DATABASE_NAME, backup_path)
+    return backup_path
+
+
+def restore_test_database(backup_path):
+    from django.db import connection
+    connection.close()
+    shutil.copyfile(backup_path, settings.DATABASE_NAME)
+
+
+def cleanup_database_backup(backup_path):
+    os.remove(backup_path)
+
+
+def run_syncdb(verbosity=0):
+    management.syncdb(verbosity, interactive=False)
diff --git a/scheduler/monitor_db.py b/scheduler/monitor_db.py
index cb237ad..10cac39 100644
--- a/scheduler/monitor_db.py
+++ b/scheduler/monitor_db.py
@@ -10,10 +10,12 @@
 import common
 from autotest_lib.client.common_lib import global_config
 from autotest_lib.client.common_lib import host_protections, utils
+from autotest_lib.database import database_connection
 
 
 RESULTS_DIR = '.'
 AUTOSERV_NICE_LEVEL = 10
+CONFIG_SECTION = 'AUTOTEST_WEB'
 
 AUTOTEST_PATH = os.path.join(os.path.dirname(__file__), '..')
 
@@ -78,7 +80,7 @@
 
     # read in base url
     global _base_url
-    val = c.get_config_value("AUTOTEST_WEB", "base_url")
+    val = c.get_config_value(CONFIG_SECTION, "base_url")
     if val:
         _base_url = val
     else:
@@ -111,9 +113,13 @@
     print "%s> dispatcher starting" % time.strftime("%X %x")
     print "My PID is %d" % os.getpid()
 
+    if _testing_mode:
+        global_config.global_config.override_config_value(
+            CONFIG_SECTION, 'database', 'stresstest_autotest_web')
+
     os.environ['PATH'] = AUTOTEST_SERVER_DIR + ':' + os.environ['PATH']
     global _db
-    _db = DatabaseConn()
+    _db = database_connection.DatabaseConnection(CONFIG_SECTION)
     _db.connect()
 
     print "Setting signal handler"
@@ -152,73 +158,6 @@
         os.remove(path)
 
 
-class DatabaseConn:
-    def __init__(self):
-        self.reconnect_wait = 20
-        self.conn = None
-        self.cur = None
-
-        import MySQLdb.converters
-        self.convert_dict = MySQLdb.converters.conversions
-        self.convert_dict.setdefault(bool, self.convert_boolean)
-
-
-    @staticmethod
-    def convert_boolean(boolean, conversion_dict):
-        'Convert booleans to integer strings'
-        return str(int(boolean))
-
-
-    def connect(self, db_name=None):
-        self.disconnect()
-
-        # get global config and parse for info
-        c = global_config.global_config
-        dbase = "AUTOTEST_WEB"
-        db_host = c.get_config_value(dbase, "host")
-        if db_name is None:
-            db_name = c.get_config_value(dbase, "database")
-
-        if _testing_mode:
-            db_name = 'stresstest_autotest_web'
-
-        db_user = c.get_config_value(dbase, "user")
-        db_pass = c.get_config_value(dbase, "password")
-
-        while not self.conn:
-            try:
-                self.conn = MySQLdb.connect(
-                    host=db_host, user=db_user, passwd=db_pass,
-                    db=db_name, conv=self.convert_dict)
-
-                self.conn.autocommit(True)
-                self.cur = self.conn.cursor()
-            except MySQLdb.OperationalError:
-                traceback.print_exc()
-                print "Can't connect to MYSQL; reconnecting"
-                time.sleep(self.reconnect_wait)
-                self.disconnect()
-
-
-    def disconnect(self):
-        if self.conn:
-            self.conn.close()
-        self.conn = None
-        self.cur = None
-
-
-    def execute(self, *args, **dargs):
-        while (True):
-            try:
-                self.cur.execute(*args, **dargs)
-                return self.cur.fetchall()
-            except MySQLdb.OperationalError:
-                traceback.print_exc()
-                print "MYSQL connection died; reconnecting"
-                time.sleep(self.reconnect_wait)
-                self.connect()
-
-
 def generate_parse_command(results_dir, flags=""):
     parse = os.path.abspath(os.path.join(AUTOTEST_TKO_DIR, 'parse'))
     output = os.path.abspath(os.path.join(results_dir, '.parse.log'))
@@ -320,9 +259,9 @@
         hosts = Host.fetch(
             joins='LEFT JOIN host_queue_entries AS active_hqe '
                   'ON (hosts.id = active_hqe.host_id AND '
-                      'active_hqe.active = TRUE)',
+                      'active_hqe.active)',
             where="active_hqe.host_id IS NULL "
-                  "AND hosts.locked = FALSE "
+                  "AND NOT hosts.locked "
                   "AND (hosts.status IS NULL OR hosts.status = 'Ready')")
         return dict((host.id, host) for host in hosts)
 
@@ -763,14 +702,14 @@
 
         _db.execute(update + """
             SET host_queue_entries.status = 'Abort'
-            WHERE host_queue_entries.active IS TRUE""" + timed_out)
+            WHERE host_queue_entries.active""" + timed_out)
 
         _db.execute(update + """
             SET host_queue_entries.status = 'Aborted',
-                host_queue_entries.active = FALSE,
-                host_queue_entries.complete = TRUE
-            WHERE host_queue_entries.active IS FALSE
-                AND host_queue_entries.complete IS FALSE""" + timed_out)
+                host_queue_entries.active = 0,
+                host_queue_entries.complete = 1
+            WHERE NOT host_queue_entries.active
+                AND NOT host_queue_entries.complete""" + timed_out)
 
 
     def _clear_inactive_blocks(self):
@@ -1660,7 +1599,6 @@
 
     @classmethod
     def fetch(cls, where='', params=(), joins='', order_by=''):
-        table = cls._get_table()
         order_by = cls._prefix_with(order_by, 'ORDER BY ')
         where = cls._prefix_with(where, 'WHERE ')
         query = ('SELECT %(table)s.* FROM %(table)s %(joins)s '
diff --git a/scheduler/monitor_db_unittest.py b/scheduler/monitor_db_unittest.py
index 1216448..f4cfa83 100644
--- a/scheduler/monitor_db_unittest.py
+++ b/scheduler/monitor_db_unittest.py
@@ -6,32 +6,14 @@
 from autotest_lib.client.common_lib import global_config, host_protections
 from autotest_lib.client.common_lib.test_utils import mock
 from autotest_lib.database import database_connection, migrate
+from autotest_lib.scheduler import monitor_db
 
-import monitor_db
+from autotest_lib.frontend import django_test_utils
+django_test_utils.setup_test_environ()
+from autotest_lib.frontend.afe import models
 
 _DEBUG = False
 
-_TEST_DATA = """
--- create a user and an ACL group
-INSERT INTO users (login) VALUES ('my_user');
-INSERT INTO acl_groups (name) VALUES ('my_acl');
-INSERT INTO acl_groups_users (user_id, acl_group_id) VALUES (1, 1);
-
--- create some hosts
-INSERT INTO hosts (hostname) VALUES ('host1'), ('host2'), ('host3'), ('host4');
--- add hosts to the ACL group
-INSERT INTO acl_groups_hosts (host_id, acl_group_id) VALUES
-  (1, 1), (2, 1), (3, 1), (4, 1);
-
--- create a label for each of two hosts
-INSERT INTO labels (name) VALUES ('label1'), ('label2');
-INSERT INTO labels (name, only_if_needed) VALUES ('label3', true);
-
--- add hosts to labels
-INSERT INTO hosts_labels (host_id, label_id) VALUES
-  (1, 1), (2, 2);
-"""
-
 class Dummy(object):
     'Dummy object that can have attribute assigned to it'
 
@@ -51,88 +33,65 @@
 
 class BaseDispatcherTest(unittest.TestCase):
     _config_section = 'AUTOTEST_WEB'
-
-
-    def _setup_test_db_name(self):
-        global_config.global_config.reset_config_values()
-        real_db_name = global_config.global_config.get_config_value(
-            self._config_section, 'database')
-        test_db_name = 'test_' + real_db_name
-        global_config.global_config.override_config_value(self._config_section,
-                                                          'database',
-                                                          test_db_name)
-
-
-    def _read_db_info(self):
-        config = global_config.global_config
-        section = self._config_section
-        self._host = config.get_config_value(section, "host")
-        self._db_name = config.get_config_value(section, "database")
-        self._user = config.get_config_value(section, "user")
-        self._password = config.get_config_value(section, "password")
-
-
-    def _connect_to_db(self, db_name=''):
-        self._con = MySQLdb.connect(host=self._host, user=self._user,
-                                    passwd=self._password, db=db_name)
-        self._con.autocommit(True)
-        self._cur = self._con.cursor()
-
-
-    def _disconnect_from_db(self):
-        self._con.close()
-
+    _test_db_initialized = False
 
     def _do_query(self, sql):
-        if _DEBUG:
-            print 'SQL:', sql
-        self._cur.execute(sql)
+        self._database.execute(sql)
 
 
-    def _do_queries(self, sql_queries):
-        for query in sql_queries.split(';'):
-            query = query.strip()
-            if query:
-                self._do_query(query)
+    @classmethod
+    def _initialize_test_db(cls):
+        if cls._test_db_initialized:
+            return
+        temp_fd, cls._test_db_file = tempfile.mkstemp(suffix='.monitor_test')
+        os.close(temp_fd)
+        django_test_utils.set_test_database(cls._test_db_file)
+        django_test_utils.run_syncdb()
+        cls._test_db_backup = django_test_utils.backup_test_database()
+        cls._test_db_initialized = True
 
 
     def _open_test_db(self):
-        self._connect_to_db()
-        self._do_query('DROP DATABASE IF EXISTS ' + self._db_name)
-        self._do_query('CREATE DATABASE ' + self._db_name)
-        self._disconnect_from_db()
-
-        database = database_connection.DatabaseConnection('AUTOTEST_WEB')
-        database.connect(db_name=self._db_name)
-        manager = migrate.MigrationManager(database, force=True)
-        manager.do_sync_db()
-
-        self._connect_to_db(self._db_name)
+        self._initialize_test_db()
+        django_test_utils.restore_test_database(self._test_db_backup)
+        self._database = (
+            database_connection.DatabaseConnection.get_test_database(
+                self._test_db_file))
+        self._database.connect()
+        self._database.debug = _DEBUG
 
 
     def _close_test_db(self):
-        self._do_query('DROP DATABASE ' + self._db_name)
-        self._disconnect_from_db()
+        self._database.disconnect()
 
 
     def _set_monitor_stubs(self):
-        monitor_db._db = monitor_db.DatabaseConn()
-        monitor_db._db.connect(db_name=self._db_name)
+        monitor_db._db = self._database
 
 
     def _fill_in_test_data(self):
-        self._do_queries(_TEST_DATA)
+        user = models.User.objects.create(login='my_user')
+        acl_group = models.AclGroup.objects.create(name='my_acl')
+        acl_group.users.add(user)
+
+        hosts = [models.Host.objects.create(hostname=hostname) for hostname in
+                 ('host1', 'host2', 'host3', 'host4')]
+        acl_group.hosts = hosts
+
+        labels = [models.Label.objects.create(name=name) for name in
+                  ('label1', 'label2', 'label3')]
+        labels[2].only_if_needed = True
+        labels[2].save()
+        hosts[0].labels.add(labels[0])
+        hosts[1].labels.add(labels[1])
 
 
     def setUp(self):
         self.god = mock.mock_god()
-        self._setup_test_db_name()
-        self._read_db_info()
         self._open_test_db()
         self._fill_in_test_data()
         self._set_monitor_stubs()
         self._dispatcher = monitor_db.Dispatcher()
-        self._job_counter = 0
 
 
     def tearDown(self):
@@ -143,23 +102,18 @@
     def _create_job(self, hosts=[], metahosts=[], priority=0, active=0,
                     synchronous=False):
         synch_type = synchronous and 2 or 1
-        self._do_query('INSERT INTO jobs (name, owner, priority, synch_type) '
-                       'VALUES ("test", "my_user", %d, %d)' %
-                       (priority, synch_type))
-        self._job_counter += 1
-        job_id = self._job_counter
-        queue_entry_sql = (
-            'INSERT INTO host_queue_entries '
-            '(job_id, priority, host_id, meta_host, active) '
-            'VALUES (%d, %d, %%s, %%s, %d)' %
-            (job_id, priority, active))
+        job = models.Job.objects.create(name='test', owner='my_user',
+                                        priority=priority,
+                                        synch_type=synch_type)
         for host_id in hosts:
-            self._do_query(queue_entry_sql % (host_id, 'NULL'))
-            self._do_query('INSERT INTO ineligible_host_queues '
-                           '(job_id, host_id) VALUES (%d, %d)' %
-                           (job_id, host_id))
+            models.HostQueueEntry.objects.create(job=job, priority=priority,
+                                                 host_id=host_id, active=active)
+            models.IneligibleHostQueue.objects.create(job=job, host_id=host_id)
         for label_id in metahosts:
-            self._do_query(queue_entry_sql % ('NULL', label_id))
+            models.HostQueueEntry.objects.create(job=job, priority=priority,
+                                                 meta_host_id=label_id,
+                                                 active=active)
+        return job
 
 
     def _create_job_simple(self, hosts, use_metahost=False,
@@ -170,7 +124,7 @@
             args['metahosts'] = hosts
         else:
             args['hosts'] = hosts
-        self._create_job(priority=priority, active=active, **args)
+        return self._create_job(priority=priority, active=active, **args)
 
 
     def _update_hqe(self, set, where=''):
@@ -233,7 +187,6 @@
 
     def setUp(self):
         super(DispatcherSchedulingTest, self).setUp()
-        self._fill_in_test_data()
         self._jobs_scheduled = []
 
 
@@ -299,16 +252,16 @@
 
     def _test_only_if_needed_labels_helper(self, use_metahosts):
         # apply only_if_needed label3 to host1
-        self._do_query('INSERT INTO hosts_labels (host_id, label_id) '
-                       'VALUES (1, 3)')
-        self._create_job_simple([1], use_metahosts)
+        label3 = models.Label.smart_get('label3')
+        models.Host.smart_get('host1').labels.add(label3)
+
+        job = self._create_job_simple([1], use_metahosts)
         # if the job doesn't depend on label3, there should be no scheduling
         self._dispatcher._schedule_new_jobs()
         self._check_for_extra_schedulings()
 
         # now make the job depend on label3
-        self._do_query('INSERT INTO jobs_dependency_labels (job_id, label_id) '
-                       'VALUES (1, 3)')
+        job.dependency_labels.add(label3)
         self._dispatcher._schedule_new_jobs()
         self._assert_job_scheduled_on(1, 1)
         self._check_for_extra_schedulings()