blob: 5382252ee4506b1da922327f575148dfbf03e616 [file] [log] [blame]
mblighb5ec4872008-03-14 22:40:37 +00001#!/usr/bin/python2.4 -u
mblighe8819cd2008-02-15 16:48:40 +00002
3import os, sys, re, subprocess, tempfile
4import MySQLdb, MySQLdb.constants.ER
mblighb090f142008-02-27 21:33:46 +00005from optparse import OptionParser
mbligh9b907d62008-05-13 17:56:24 +00006import common
7from autotest_lib.client.common_lib import global_config
mblighe8819cd2008-02-15 16:48:40 +00008
# Table holding the database's current schema version as a single row.
MIGRATE_TABLE = 'migrate_info'
# Directory of NNN_*.py migration modules used when none is given; it is
# resolved against the current working directory.
DEFAULT_MIGRATIONS_DIR = 'migrations'


class Migration(object):
    """One numbered schema migration backed by an importable module.

    The file name starts with a three-digit version (e.g. ``003_foo.py``)
    and the module must define ``migrate_up(manager)`` and
    ``migrate_down(manager)``.
    """

    def __init__(self, filename):
        self.version = int(filename[:3])
        module_name = filename[:-3]     # strip the 3-character ".py" suffix
        self.name = module_name
        # Imported by bare name: the migrations directory must be on sys.path.
        self.module = __import__(module_name, globals(), locals(), [])
        for hook in ('migrate_up', 'migrate_down'):
            assert hasattr(self.module, hook)


    def migrate_up(self, manager):
        """Apply this migration via *manager*."""
        upgrade = self.module.migrate_up
        upgrade(manager)


    def migrate_down(self, manager):
        """Revert this migration via *manager*."""
        downgrade = self.module.migrate_down
        downgrade(manager)


class MigrationManager(object):
    """Bring a MySQL database schema to a target migration version.

    Connection settings come from the global config section named by
    *database*.  Migrations are NNN_*.py modules in *migrations_dir*; the
    current version is stored as a single row of the MIGRATE_TABLE table.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database, migrations_dir=None):
        """
        database: global config section holding host, database, user and
            password for the connection.
        migrations_dir: directory of NNN_*.py migration modules; defaults
            to DEFAULT_MIGRATIONS_DIR resolved against the current working
            directory.
        """
        self.database = database
        if migrations_dir is None:
            migrations_dir = os.path.abspath(DEFAULT_MIGRATIONS_DIR)
        self.migrations_dir = migrations_dir
        # Migration modules are imported by bare module name, so their
        # directory must be importable.
        sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir)

        # Filled in by read_db_info().
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None


    def read_db_info(self):
        """Read host/database/user/password from the global config."""
        c = global_config.global_config
        self.db_host = c.get_config_value(self.database, "host")
        self.db_name = c.get_config_value(self.database, "database")
        self.username = c.get_config_value(self.database, "user")
        self.password = c.get_config_value(self.database, "password")


    def connect(self, host, db_name, username, password):
        """Return a new MySQLdb connection."""
        return MySQLdb.connect(host=host, db=db_name, user=username,
                               passwd=password)


    def open_connection(self):
        """Connect with the current settings, in autocommit mode."""
        self.connection = self.connect(self.db_host, self.db_name,
                                       self.username, self.password)
        self.connection.autocommit(True)
        self.cursor = self.connection.cursor()


    def close_connection(self):
        self.connection.close()


    def execute(self, query, *parameters):
        """Run *query* with positional *parameters* on the open cursor and
        return the cursor's execute() result.

        NOTE(review): the parameter tuple is passed even when it is empty.
        MySQLdb versions that apply ``query % args`` whenever args is not
        None would then need a literal '%' in a parameterless query written
        as '%%'. Confirm against the MySQLdb version in use.
        """
        return self.cursor.execute(query, parameters)


    def execute_script(self, script):
        """Execute each ';'-separated statement of *script*, skipping blank
        ones.  The split is purely textual, so a ';' inside a string literal
        would cut a statement in two."""
        for statement in script.split(';'):
            statement = statement.strip()
            if statement:
                self.execute(statement)


    def check_migrate_table_exists(self):
        """Return True if MIGRATE_TABLE exists.

        Raises any MySQL programming error other than "no such table".
        """
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except MySQLdb.ProgrammingError:
            # sys.exc_info() instead of "except X, exc" keeps this parseable
            # by every Python version the script may run under.
            exc = sys.exc_info()[1]
            error_code, _ = exc.args
            if error_code == MySQLdb.constants.ER.NO_SUCH_TABLE:
                return False
            raise


    def create_migrate_table(self):
        """Create (or empty) MIGRATE_TABLE and set the version to 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self.cursor.rowcount == 1


    def set_db_version(self, version):
        """Record *version* (an int) as the current schema version."""
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self.cursor.rowcount == 1


    def get_db_version(self):
        """Return the stored schema version, or 0 if none is recorded."""
        if not self.check_migrate_table_exists():
            return 0
        self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        rows = self.cursor.fetchall()
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return Migration objects for the NNN_*.py files in
        migrations_dir, sorted by file name and limited to the inclusive
        [minimum_version, maximum_version] range when bounds are given."""
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration(filename) for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Apply *migration* up or down and record the resulting version."""
        if migrate_up:
            print('Applying migration %s up' % migration.name)
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            print('Applying migration %s down' % migration.name)
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Apply migrations one at a time, up or down, until the database
        is at *version*."""
        current_version = self.get_db_version()
        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)


    def get_latest_version(self):
        """Return the highest available migration version, or 0 if the
        migrations directory holds none (0 is also the "no version
        recorded" value of get_db_version)."""
        migrations = self.get_migrations()
        if not migrations:
            return 0
        return migrations[-1].version


    def migrate_to_latest(self):
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Create 'test_<db_name>' and leave a connection open to it."""
        self.read_db_info()
        test_db_name = 'test_' + self.db_name
        # first, connect to no DB so we can create a test DB
        self.db_name = ''
        self.open_connection()
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self.close_connection()
        # now connect to the test DB
        self.db_name = test_db_name
        self.open_connection()


    def remove_test_db(self):
        """Drop the current (test) database and close the connection."""
        print('Removing test DB')
        try:
            self.execute('DROP DATABASE ' + self.db_name)
        finally:
            # The connection points at the database just dropped; release it
            # instead of leaking it when open_connection() runs again.
            self.close_connection()


    def get_mysql_args(self):
        """Return the mysql CLI connection arguments as one string."""
        return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % {
                'user' : self.username,
                'password' : self.password,
                'host' : self.db_host,
                'db' : self.db_name})


    def _mysql_arg_list(self):
        """Return the same arguments as get_mysql_args() as an argv list,
        so passwords or names with spaces/shell metacharacters survive."""
        return ['-u', str(self.username), '-p%s' % self.password,
                '-h', str(self.db_host), str(self.db_name)]


    def _run_mysql_tool(self, tool, extra_args=(), stdin_path=None,
                        stdout_path=None, check=True):
        """Run mysql CLI *tool* against the current database without a shell.

        stdin_path/stdout_path: optional files to redirect from/to.
        check: raise RuntimeError when the tool exits non-zero.
        Returns the exit status.
        """
        argv = [tool] + self._mysql_arg_list() + list(extra_args)
        stdin = None
        stdout = None
        try:
            if stdin_path is not None:
                stdin = open(stdin_path, 'rb')
            if stdout_path is not None:
                stdout = open(stdout_path, 'wb')
            status = subprocess.call(argv, stdin=stdin, stdout=stdout)
        finally:
            for handle in (stdin, stdout):
                if handle is not None:
                    handle.close()
        if check and status != 0:
            raise RuntimeError('%s exited with status %d' % (tool, status))
        return status


    def migrate_to_version_or_latest(self, version):
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrate the configured database to *version* (default: latest)."""
        self.read_db_info()
        self.open_connection()
        print('Migration starting for database %s' % self.db_name)
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """\
        Create a fresh DB and run all migrations on it.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.db_name)
            self.migrate_to_version_or_latest(version)
            # show schema to the user (display only, so status not checked)
            self._run_mysql_tool('mysqldump',
                                 ['--no-data=true', '--add-drop-table=false'],
                                 check=False)
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """\
        Create a fresh DB, copy the existing DB to it, and then
        try to synchronize it.

        Raises RuntimeError if dumping or loading the data fails, rather
        than "succeeding" against an empty copy.  The temporary dump file
        and the test DB are removed even when a step fails.
        """
        self.read_db_info()
        self.open_connection()
        db_version = self.get_db_version()
        self.close_connection()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        os.close(dump_fd)
        try:
            self._run_mysql_tool('mysqldump', stdout_path=dump_file)
            self.initialize_test_db()
            try:
                print('Filling in test DB')
                self._run_mysql_tool('mysql', stdin_path=dump_file)
                print('Starting migration test on DB %s' % self.db_name)
                self.migrate_to_version_or_latest(version)
            finally:
                self.remove_test_db()
        finally:
            os.remove(dump_file)
        print('Test finished successfully')


276
mblighc2f24452008-03-31 16:46:13 +0000277USAGE = """\
278%s [options] sync|test|simulate|safesync [version]
279Options:
280 -d --database Which database to act on
281 -a --action Which action to perform"""\
282 % sys.argv[0]
mblighe8819cd2008-02-15 16:48:40 +0000283
284
285def main():
mblighb090f142008-02-27 21:33:46 +0000286 parser = OptionParser()
287 parser.add_option("-d", "--database",
288 help="which database to act on",
289 dest="database")
290 parser.add_option("-a", "--action", help="what action to perform",
291 dest="action")
292 (options, args) = parser.parse_args()
293 manager = MigrationManager(options.database)
294
295 if len(args) > 0:
mblighaa383b72008-03-12 20:11:56 +0000296 if len(args) > 1:
297 version = int(args[1])
298 else:
299 version = None
mblighb090f142008-02-27 21:33:46 +0000300 if args[0] == 'sync':
mblighaa383b72008-03-12 20:11:56 +0000301 manager.do_sync_db(version)
mblighb090f142008-02-27 21:33:46 +0000302 elif args[0] == 'test':
mblighaa383b72008-03-12 20:11:56 +0000303 manager.test_sync_db(version)
mblighb090f142008-02-27 21:33:46 +0000304 elif args[0] == 'simulate':
mblighaa383b72008-03-12 20:11:56 +0000305 manager.simulate_sync_db(version)
mblighb090f142008-02-27 21:33:46 +0000306 elif args[0] == 'safesync':
mblighe8819cd2008-02-15 16:48:40 +0000307 print 'Simluating migration'
mblighaa383b72008-03-12 20:11:56 +0000308 manager.simulate_sync_db(version)
mblighe8819cd2008-02-15 16:48:40 +0000309 print 'Performing real migration'
mblighaa383b72008-03-12 20:11:56 +0000310 manager.do_sync_db(version)
mblighe8819cd2008-02-15 16:48:40 +0000311 else:
312 print USAGE
313 return
mblighb090f142008-02-27 21:33:46 +0000314
mblighe8819cd2008-02-15 16:48:40 +0000315 print USAGE
316
317
318if __name__ == '__main__':
319 main()