blob: d6447540f4599c480f07dfa39081491cf08e311f [file] [log] [blame]
#!/usr/bin/python2.4
import unittest
import MySQLdb
import migrate
from common import global_config
# How many DummyMigration objects to build (versions 1 through this value).
NUM_MIGRATIONS = 3
# Section of the global config whose connection settings the tests read.
# The database named there is never used directly; the tests work against
# the matching test database, test_<db name>, instead.
CONFIG_DB = 'AUTOTEST_WEB'
class DummyMigration(object):
    """\
    Fake migration that logs every step it performs in a class variable.

    All instances share the ``migrations_done`` list, so a test can run a
    sequence of migrations and then inspect, in order, which versions were
    applied and in which direction ('up' or 'down').
    """

    # Shared, class-wide log of (version, direction) tuples, in call order.
    migrations_done = []

    def __init__(self, version):
        # Zero-padded three-digit name, e.g. version 2 -> '002_test'.
        self.name = '%03d_test' % version
        self.version = version

    @classmethod
    def get_migrations_done(cls):
        """Return the shared log itself (the live list, not a copy)."""
        done = cls.migrations_done
        return done

    @classmethod
    def clear_migrations_done(cls):
        """Start a new, empty log.

        The attribute is rebound to a fresh list rather than emptied in
        place, so a list previously returned by get_migrations_done() keeps
        its contents.
        """
        cls.migrations_done = []

    @classmethod
    def do_migration(cls, version, direction):
        """Append a (version, direction) entry to the shared log."""
        cls.migrations_done.append((version, direction))

    def migrate_up(self, manager):
        """Record an 'up' step for this version.

        For version 1 only, manager.create_migrate_table() is called after
        the step has been recorded.
        """
        self.do_migration(self.version, 'up')
        is_first = self.version == 1
        if is_first:
            manager.create_migrate_table()

    def migrate_down(self, manager):
        """Record a 'down' step for this version; *manager* is not used."""
        self.do_migration(self.version, 'down')
# One DummyMigration per version, numbered 1 through NUM_MIGRATIONS, in
# ascending order.
MIGRATIONS = list(DummyMigration(version)
                  for version in range(1, NUM_MIGRATIONS + 1))
class TestableMigrationManager(migrate.MigrationManager):
    """\
    MigrationManager that targets the test_<db name> database and serves the
    in-memory MIGRATIONS list instead of real migration files.
    """

    def __init__(self, database, migrations_dir=None):
        # Does not call MigrationManager.__init__; only the attributes below
        # are initialised here.  The connection fields start out as None --
        # presumably populated later via read_db_info() (verify in migrate).
        self.database = database
        self.migrations_dir = migrations_dir
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None

    def read_db_info(self):
        """Read DB settings as the base class does, then redirect the
        manager to the test database by prefixing the name with 'test_'."""
        migrate.MigrationManager.read_db_info(self)
        self.db_name = 'test_' + self.db_name

    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return the dummy migrations whose version lies in a range.

        Both bounds are inclusive and optional; None means unbounded.  The
        bounds are checked with ``is None`` rather than truthiness so that an
        explicit 0 is honoured: maximum_version=0 selects no migrations
        (previously it selected all of them).  Selection is by each
        migration's ``version`` attribute, not by list position.  A new list
        is returned in MIGRATIONS order, so callers may reorder it freely.
        """
        selected = []
        for migration in MIGRATIONS:
            if (minimum_version is not None
                    and migration.version < minimum_version):
                continue
            if (maximum_version is not None
                    and migration.version > maximum_version):
                continue
            selected.append(migration)
        return selected
class MigrateManagerTest(unittest.TestCase):
    """\
    Exercises MigrationManager syncing against a scratch MySQL database.

    Each test gets a freshly created test_<db name> database, which is
    dropped again in tearDown.  It also gets a TestableMigrationManager that
    applies the DummyMigration objects from MIGRATIONS.
    """

    # Connection settings are read from the CONFIG_DB section once, when the
    # class body executes.
    config = global_config.global_config
    host = config.get_config_value(CONFIG_DB, 'host')
    db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database')
    user = config.get_config_value(CONFIG_DB, 'user')
    password = config.get_config_value(CONFIG_DB, 'password')

    def do_sql(self, sql):
        """Execute one SQL statement on a fresh autocommit connection.

        The connection is opened without selecting a database, so it can
        create and drop the test database itself.  It is always closed
        before returning, even if configuring it or running *sql* fails.
        """
        self.con = MySQLdb.connect(host=self.host, user=self.user,
                                   passwd=self.password)
        # Everything after connect() is inside the try so the connection is
        # never leaked if autocommit()/cursor() raise.
        try:
            self.con.autocommit(True)
            self.cur = self.con.cursor()
            self.cur.execute(sql)
        finally:
            self.con.close()

    def remove_db(self):
        """Drop the scratch test database."""
        self.do_sql('DROP DATABASE ' + self.db_name)

    def setUp(self):
        """Create the scratch database and a manager pointed at it."""
        self.do_sql('CREATE DATABASE ' + self.db_name)
        try:
            self.manager = TestableMigrationManager(CONFIG_DB)
        except MySQLdb.OperationalError:
            # tearDown is not run when setUp fails, so drop the database
            # here rather than leaving it behind for the next run.
            self.remove_db()
            raise
        # Start every test with an empty migration log.
        DummyMigration.clear_migrations_done()

    def tearDown(self):
        """Drop the scratch test database."""
        self.remove_db()

    def test_sync(self):
        """Syncing with no target applies every migration upward in order;
        syncing to 0 then undoes them all in reverse order."""
        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        # Derived from NUM_MIGRATIONS so the test tracks the fixture size.
        expected_up = [(version, 'up')
                       for version in range(1, NUM_MIGRATIONS + 1)]
        self.assertEqual(DummyMigration.get_migrations_done(), expected_up)

        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        expected_down = [(version, 'down')
                         for version in range(NUM_MIGRATIONS, 0, -1)]
        self.assertEqual(DummyMigration.get_migrations_done(), expected_down)

    def test_sync_one_by_one(self):
        """Step the schema up one version at a time, then back down to 0,
        checking the reported version and the most recent step each time."""
        for version in range(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version, 'up'))

        for version in range(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            # Going down to `version` undoes migration `version + 1`.
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version + 1, 'down'))

    def test_null_sync(self):
        """A second sync to the latest version performs no migrations."""
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module; '__main__' is
    # also unittest.main's default module.
    unittest.main(module='__main__')