blob: d89f808c63ab720eadc3d5d47526d0e573ea88b9 [file] [log] [blame]
#!/usr/bin/python2.4 -u
import os, sys, re, subprocess, tempfile
import MySQLdb, MySQLdb.constants.ER
from optparse import OptionParser
from common import global_config
# Name of the bookkeeping table that holds the schema version: it is created
# with a single integer `version` column and is expected to contain one row.
MIGRATE_TABLE = 'migrate_info'
# Directory searched for migration scripts when the caller does not supply
# one; it is resolved with os.path.abspath, i.e. relative to the current
# working directory.
DEFAULT_MIGRATIONS_DIR = 'migrations'
class Migration(object):
    """A single migration script, imported as a module.

    The file name provides both pieces of identity: its first three
    characters are parsed as the integer version, and the name with its
    last three characters removed is used as the module name.  The
    imported module has to expose ``migrate_up`` and ``migrate_down``.
    """

    def __init__(self, filename):
        """Import the migration module that *filename* refers to.

        The module is imported by bare name, so the directory holding it
        must already be importable (e.g. present on ``sys.path``).
        """
        version = int(filename[:3])
        module_name = filename[:-3]
        self.version = version
        self.name = module_name
        self.module = __import__(module_name, globals(), locals(), [])
        # Both directions must be implemented by every migration module.
        for hook in ('migrate_up', 'migrate_down'):
            assert hasattr(self.module, hook)

    def migrate_up(self, manager):
        """Run the module's ``migrate_up`` hook, handing it *manager*."""
        upgrade = self.module.migrate_up
        upgrade(manager)

    def migrate_down(self, manager):
        """Run the module's ``migrate_down`` hook, handing it *manager*."""
        downgrade = self.module.migrate_down
        downgrade(manager)
class MigrationManager(object):
    """Apply numbered migration scripts to a MySQL database.

    The manager tracks the schema version in the MIGRATE_TABLE table and
    moves the database up or down one migration at a time.  Besides the
    real synchronisation (do_sync_db) it can rehearse the migrations on a
    throw-away ``test_<name>`` database, either empty (test_sync_db) or
    filled with a copy of the real data (simulate_sync_db).

    Compatibility notes: the code avoids ``with``, ``except ... as`` and
    combined try/except/finally so it still runs on the Python 2.4 named
    in the file's shebang, and every ``print`` is given exactly one
    pre-formatted string so the statement reads the same whether print is
    a statement or a function.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database, migrations_dir=None):
        """Remember which database section to use and where migrations live.

        database -- section name looked up in the global config by
                    read_db_info().
        migrations_dir -- directory containing the migration scripts;
                    defaults to DEFAULT_MIGRATIONS_DIR under the current
                    working directory.  It is added to sys.path so the
                    scripts can be imported by bare module name.
        """
        self.database = database
        if migrations_dir is None:
            migrations_dir = os.path.abspath(DEFAULT_MIGRATIONS_DIR)
        # Validate before touching sys.path so a bad directory leaves no
        # trace behind.
        assert os.path.exists(migrations_dir), (
            'migrations directory does not exist: %s' % (migrations_dir,))
        self.migrations_dir = migrations_dir
        if migrations_dir not in sys.path:
            sys.path.append(migrations_dir)
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None

    def read_db_info(self):
        """Load host, database name, user and password from the global
        config section named by self.database."""
        c = global_config.global_config
        self.db_host = c.get_config_value(self.database, "host")
        self.db_name = c.get_config_value(self.database, "database")
        self.username = c.get_config_value(self.database, "user")
        self.password = c.get_config_value(self.database, "password")

    def connect(self, host, db_name, username, password):
        """Return a new MySQLdb connection for the given credentials."""
        return MySQLdb.connect(host=host, db=db_name, user=username,
                               passwd=password)

    def open_connection(self):
        """Connect using the stored credentials, enable autocommit and
        create the cursor used by execute()."""
        self.connection = self.connect(self.db_host, self.db_name,
                                       self.username, self.password)
        self.connection.autocommit(True)
        self.cursor = self.connection.cursor()

    def close_connection(self):
        """Close the current connection, if any.

        Safe to call when nothing is open; the connection and cursor
        attributes are reset so stale handles are not reused by mistake.
        """
        if self.connection is None:
            return
        try:
            self.connection.close()
        finally:
            self.connection = None
            self.cursor = None

    def execute(self, query, *parameters):
        """Run *query* on the current cursor with *parameters* bound to
        its placeholders; returns whatever cursor.execute() returns."""
        return self.cursor.execute(query, parameters)

    def execute_script(self, script):
        """Execute each ';'-separated statement of *script* in order.

        The split is purely textual, so a semicolon inside a string
        literal or comment would also be treated as a separator.
        """
        for statement in script.split(';'):
            statement = statement.strip()
            if statement:
                self.execute(statement)

    def check_migrate_table_exists(self):
        """Return True if MIGRATE_TABLE can be queried, False if MySQL
        reports that the table does not exist.  Any other database error
        propagates unchanged."""
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        except MySQLdb.ProgrammingError:
            # sys.exc_info() is used instead of binding the exception in
            # the except clause, whose syntax differs between interpreter
            # generations.
            exc = sys.exc_info()[1]
            # Look only at the first element: unpacking a fixed number of
            # args could itself fail and hide the real error.
            if exc.args and exc.args[0] == MySQLdb.constants.ER.NO_SUCH_TABLE:
                return False
            raise
        return True

    def create_migrate_table(self):
        """Create (or empty) MIGRATE_TABLE and record version 0 in it."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self.cursor.rowcount == 1, (
            'expected to insert exactly one version row, inserted %s'
            % (self.cursor.rowcount,))

    def set_db_version(self, version):
        """Store *version* (an int) as the database's schema version."""
        assert isinstance(version, int), (
            'version must be an int, got %r' % (version,))
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self.cursor.rowcount == 1, (
            'expected to update exactly one version row, updated %s'
            % (self.cursor.rowcount,))

    def get_db_version(self):
        """Return the recorded schema version, or 0 when MIGRATE_TABLE is
        missing or empty."""
        if not self.check_migrate_table_exists():
            return 0
        self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        rows = self.cursor.fetchall()
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1, (
            'unexpected contents in %s: %r' % (MIGRATE_TABLE, rows))
        return rows[0][0]

    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return Migration objects for the scripts in migrations_dir,
        sorted by file name.

        Only files named like ``NNN_<anything>.py`` are considered.  When
        given, *minimum_version* and *maximum_version* are inclusive
        bounds on Migration.version.
        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration(filename) for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations

    def do_migration(self, migration, migrate_up=True):
        """Apply one migration in the requested direction and record the
        resulting version.

        Going up requires the database to be at ``migration.version - 1``;
        going down requires it to be at ``migration.version``.
        """
        if migrate_up:
            direction = 'up'
            required_version = migration.version - 1
            new_version = migration.version
        else:
            direction = 'down'
            required_version = migration.version
            new_version = migration.version - 1
        print('Applying migration %s %s' % (migration.name, direction))
        current_version = self.get_db_version()
        assert current_version == required_version, (
            'database is at version %s but migrating %s %s requires '
            'version %s' % (current_version, direction, migration.name,
                            required_version))
        if migrate_up:
            migration.migrate_up(self)
        else:
            migration.migrate_down(self)
        self.set_db_version(new_version)

    def migrate_to_version(self, version):
        """Run the migrations needed to move the database from its
        current version to *version*, upwards or downwards."""
        current_version = self.get_db_version()
        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False
        # The migration numbered `lower` is already applied (or is the
        # one to stay at), so the range starts just above it.
        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)
        final_version = self.get_db_version()
        assert final_version == version, (
            'wanted version %s but database ended at version %s'
            % (version, final_version))
        print('At version %s' % (version,))

    def get_latest_version(self):
        """Return the highest available migration version, or 0 when the
        migrations directory holds no migration scripts."""
        migrations = self.get_migrations()
        if not migrations:
            return 0
        return migrations[-1].version

    def migrate_to_latest(self):
        """Migrate the database to the highest available version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)

    def initialize_test_db(self):
        """Create the ``test_<db_name>`` database and connect to it.

        On return self.db_name names the test database and the
        connection/cursor point at it.
        """
        self.read_db_info()
        test_db_name = 'test_' + self.db_name
        # first, connect to no DB so we can create a test DB
        self.db_name = ''
        self.open_connection()
        try:
            print('Creating test DB %s' % (test_db_name,))
            self.execute('CREATE DATABASE ' + test_db_name)
        finally:
            # Do not leave the bootstrap connection open if the CREATE
            # fails.
            self.close_connection()
        # now connect to the test DB
        self.db_name = test_db_name
        self.open_connection()

    def remove_test_db(self):
        """Drop the database named by self.db_name and close the
        connection that was opened on it."""
        print('Removing test DB')
        try:
            self.execute('DROP DATABASE ' + self.db_name)
        finally:
            self.close_connection()

    def get_mysql_args(self):
        """Return the user/password/host/database arguments for the mysql
        command-line tools as a single space-separated string."""
        return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % {
            'user' : self.username,
            'password' : self.password,
            'host' : self.db_host,
            'db' : self.db_name})

    def _mysql_argv(self):
        """Return the same arguments as get_mysql_args(), as an argv list."""
        return ['-u', '%s' % (self.username,),
                '-p%s' % (self.password,),
                '-h', '%s' % (self.db_host,),
                '%s' % (self.db_name,)]

    def _run_mysql_tool(self, program, extra_args=(), stdin=None,
                        stdout=None):
        """Run a MySQL command-line tool directly (no shell) and return
        its exit status.

        Passing an argument list keeps credentials containing shell
        metacharacters from being re-interpreted by a shell.
        """
        argv = [program] + self._mysql_argv() + list(extra_args)
        return subprocess.call(argv, stdin=stdin, stdout=stdout)

    def migrate_to_version_or_latest(self, version):
        """Migrate to *version*, or to the latest version when it is
        None."""
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)

    def do_sync_db(self, version=None):
        """Migrate the real database to *version* (latest when None).

        The connection is left open afterwards so callers can keep
        querying through this manager.
        """
        self.read_db_info()
        self.open_connection()
        print('Migration starting for database %s' % (self.db_name,))
        self.migrate_to_version_or_latest(version)
        print('Migration complete')

    def test_sync_db(self, version=None):
        """\
        Create a fresh DB and run all migrations on it.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % (self.db_name,))
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            status = self._run_mysql_tool(
                'mysqldump', ['--no-data=true', '--add-drop-table=false'])
            if status != 0:
                # Displaying the schema is informational only, so a
                # failure here is reported but does not fail the test.
                print('Warning: mysqldump exited with status %d' % status)
        finally:
            self.remove_test_db()
        print('Test finished successfully')

    def simulate_sync_db(self, version=None):
        """\
        Create a fresh DB, copy the existing DB to it, and then
        try to synchronize it.

        Raises RuntimeError if the existing data cannot be dumped or
        loaded into the test DB, since the simulation would otherwise
        run against the wrong data.
        """
        self.read_db_info()
        self.open_connection()
        try:
            db_version = self.get_db_version()
        finally:
            self.close_connection()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        try:
            try:
                status = self._run_mysql_tool('mysqldump', stdout=dump_fd)
            finally:
                os.close(dump_fd)
            if status != 0:
                raise RuntimeError(
                    'mysqldump exited with status %d' % status)
            # fill in test DB
            self.initialize_test_db()
            try:
                print('Filling in test DB')
                dump_in = open(dump_file, 'rb')
                try:
                    status = self._run_mysql_tool('mysql', stdin=dump_in)
                finally:
                    dump_in.close()
                if status != 0:
                    raise RuntimeError(
                        'mysql exited with status %d while loading the '
                        'dump' % status)
                print('Starting migration test on DB %s'
                      % (self.db_name,))
                self.migrate_to_version_or_latest(version)
            finally:
                # The test DB is dropped even when loading or migrating
                # fails.
                self.remove_test_db()
        finally:
            os.remove(dump_file)
        print('Test finished successfully')
# Help text printed by main() when no recognised command is given; the
# placeholder is filled with the script name from sys.argv[0] at import time.
# NOTE(review): --action is listed here and parsed by main(), but main()
# never reads its value -- confirm whether it is still needed.
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
-d --database Which database to act on
-a --action Which action to perform"""\
% sys.argv[0]
def main():
    """Parse the command line and run the requested migration command.

    Usage: ``[options] sync|test|simulate|safesync [version]``.  When no
    command or an unknown command is given, USAGE is printed and nothing
    else happens.  A non-integer *version* is reported through the option
    parser's error handling instead of a traceback.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database")
    parser.add_option("-a", "--action", help="what action to perform",
                      dest="action")
    (options, args) = parser.parse_args()

    # Validate the command line before building the manager, so a bad
    # invocation prints the usage text rather than failing on an
    # unrelated check (such as a missing migrations directory).
    if not args or args[0] not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return
    command = args[0]

    version = None
    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            parser.error('version must be an integer, got %r' % (args[1],))

    manager = MigrationManager(options.database)
    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy first, then touch the real DB.
        print('Simulating migration')
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.do_sync_db(version)


if __name__ == '__main__':
    main()