| #!/usr/bin/python2 |
| |
| import unittest, tempfile, os |
| import common |
| import MySQLdb |
| from autotest_lib.client.common_lib import global_config |
| from autotest_lib.database import database_connection, migrate |
| |
| # Which section of the global config to pull info from. We won't actually use |
| # that DB, we'll use the corresponding test DB (test_<db name>). |
| CONFIG_DB = 'AUTOTEST_WEB' |
| |
| NUM_MIGRATIONS = 3 |
| |
| class DummyMigration(object): |
| """\ |
| Dummy migration class that records all migrations done in a class |
| varaible. |
| """ |
| |
| migrations_done = [] |
| |
| def __init__(self, version): |
| self.version = version |
| self.name = '%03d_test' % version |
| |
| |
| @classmethod |
| def get_migrations_done(cls): |
| return cls.migrations_done |
| |
| |
| @classmethod |
| def clear_migrations_done(cls): |
| cls.migrations_done = [] |
| |
| |
| @classmethod |
| def do_migration(cls, version, direction): |
| cls.migrations_done.append((version, direction)) |
| |
| |
| def migrate_up(self, manager): |
| self.do_migration(self.version, 'up') |
| if self.version == 1: |
| manager.create_migrate_table() |
| |
| |
| def migrate_down(self, manager): |
| self.do_migration(self.version, 'down') |
| |
| |
| MIGRATIONS = [DummyMigration(n) for n in xrange(1, NUM_MIGRATIONS + 1)] |
| |
| |
| class TestableMigrationManager(migrate.MigrationManager): |
| def _set_migrations_dir(self, migrations_dir=None): |
| pass |
| |
| |
| def get_migrations(self, minimum_version=None, maximum_version=None): |
| minimum_version = minimum_version or 1 |
| maximum_version = maximum_version or len(MIGRATIONS) |
| return MIGRATIONS[minimum_version-1:maximum_version] |
| |
| |
| class MigrateManagerTest(unittest.TestCase): |
| def setUp(self): |
| self._database = ( |
| database_connection.DatabaseConnection.get_test_database()) |
| self._database.connect() |
| self.manager = TestableMigrationManager(self._database) |
| DummyMigration.clear_migrations_done() |
| |
| |
| def tearDown(self): |
| self._database.disconnect() |
| |
| |
| def test_sync(self): |
| self.manager.do_sync_db() |
| self.assertEquals(self.manager.get_db_version(), NUM_MIGRATIONS) |
| self.assertEquals(DummyMigration.get_migrations_done(), |
| [(1, 'up'), (2, 'up'), (3, 'up')]) |
| |
| DummyMigration.clear_migrations_done() |
| self.manager.do_sync_db(0) |
| self.assertEquals(self.manager.get_db_version(), 0) |
| self.assertEquals(DummyMigration.get_migrations_done(), |
| [(3, 'down'), (2, 'down'), (1, 'down')]) |
| |
| |
| def test_sync_one_by_one(self): |
| for version in xrange(1, NUM_MIGRATIONS + 1): |
| self.manager.do_sync_db(version) |
| self.assertEquals(self.manager.get_db_version(), |
| version) |
| self.assertEquals( |
| DummyMigration.get_migrations_done()[-1], |
| (version, 'up')) |
| |
| for version in xrange(NUM_MIGRATIONS - 1, -1, -1): |
| self.manager.do_sync_db(version) |
| self.assertEquals(self.manager.get_db_version(), |
| version) |
| self.assertEquals( |
| DummyMigration.get_migrations_done()[-1], |
| (version + 1, 'down')) |
| |
| |
| def test_null_sync(self): |
| self.manager.do_sync_db() |
| DummyMigration.clear_migrations_done() |
| self.manager.do_sync_db() |
| self.assertEquals(DummyMigration.get_migrations_done(), []) |
| |
| |
| class DummyMigrationManager(object): |
| def __init__(self): |
| self.calls = [] |
| |
| |
| def execute_script(self, script): |
| self.calls.append(script) |
| |
| |
| class MigrationTest(unittest.TestCase): |
| def setUp(self): |
| self.manager = DummyMigrationManager() |
| |
| |
| def _do_migration(self, migration_module): |
| migration = migrate.Migration('name', 1, migration_module) |
| migration.migrate_up(self.manager) |
| migration.migrate_down(self.manager) |
| |
| self.assertEquals(self.manager.calls, ['foo', 'bar']) |
| |
| |
| def test_migration_with_methods(self): |
| class DummyMigration(object): |
| @staticmethod |
| def migrate_up(manager): |
| manager.execute_script('foo') |
| |
| |
| @staticmethod |
| def migrate_down(manager): |
| manager.execute_script('bar') |
| |
| self._do_migration(DummyMigration) |
| |
| |
| def test_migration_with_strings(self): |
| class DummyMigration(object): |
| UP_SQL = 'foo' |
| DOWN_SQL = 'bar' |
| |
| self._do_migration(DummyMigration) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |