blob: 8bbcb5f970c5368a9d4d3f1da3e74cea525e4603 [file] [log] [blame]
mblighfa29a2a2008-05-16 22:48:09 +00001#!/usr/bin/python -u
mblighe8819cd2008-02-15 16:48:40 +00002
3import os, sys, re, subprocess, tempfile
4import MySQLdb, MySQLdb.constants.ER
mblighb090f142008-02-27 21:33:46 +00005from optparse import OptionParser
mbligh9b907d62008-05-13 17:56:24 +00006import common
7from autotest_lib.client.common_lib import global_config
mblighe8819cd2008-02-15 16:48:40 +00008
# Name of the table whose single row records the schema version that is
# currently applied to the database.
MIGRATE_TABLE = 'migrate_info'
# Directory searched for migration modules when MigrationManager is not given
# an explicit migrations_dir; it is resolved against the current working
# directory via os.path.abspath().
DEFAULT_MIGRATIONS_DIR = 'migrations'
11
class Migration(object):
    """One migration step, backed by a module imported from a migration file.

    The filename is expected to look like ``NNN_description.py``: the first
    three characters give the integer version and the name without its
    trailing three characters (``.py``) is the module that gets imported.
    The module must define ``migrate_up`` and ``migrate_down``.
    """

    # Functions every migration module has to provide.
    _REQUIRED_FUNCTIONS = ('migrate_up', 'migrate_down')


    def __init__(self, filename):
        version_prefix, module_name = filename[:3], filename[:-3]
        self.version = int(version_prefix)
        self.name = module_name
        # The module is looked up by bare name, so its directory has to be
        # on sys.path already.
        self.module = __import__(module_name, globals(), locals(), [])
        for required in self._REQUIRED_FUNCTIONS:
            assert hasattr(self.module, required)


    def migrate_up(self, manager):
        """Run the module's migrate_up(), handing it the manager."""
        self.module.migrate_up(manager)


    def migrate_down(self, manager):
        """Run the module's migrate_down(), handing it the manager."""
        self.module.migrate_down(manager)
mblighe8819cd2008-02-15 16:48:40 +000027
28
class MigrationManager(object):
    """Apply numbered migration modules to a MySQL database.

    Migration files are named ``NNN_<description>.py`` and live in
    ``migrations_dir``.  The version currently applied is kept as the single
    row of the MIGRATE_TABLE table.

    The code in this class is written so that it runs unchanged under both
    Python 2 and Python 3 (single-argument print calls, ``except ... as``).
    """
    connection = None
    cursor = None
    migrations_dir = None


    def __init__(self, database, migrations_dir=None):
        """
        database: name of the global_config section holding the connection
                settings (host, database, user, password).
        migrations_dir: directory containing the migration modules; defaults
                to DEFAULT_MIGRATIONS_DIR resolved against the current
                working directory.
        """
        self.database = database
        if migrations_dir is None:
            migrations_dir = os.path.abspath(DEFAULT_MIGRATIONS_DIR)
        self.migrations_dir = migrations_dir
        # Migration modules are imported by bare name, so their directory
        # has to be importable.  Don't stack up duplicate entries when
        # several managers are created in one process.
        if migrations_dir not in sys.path:
            sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir)

        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None


    def read_db_info(self):
        """Load host, database name, user and password from global_config."""
        # grab the config file and parse for info
        c = global_config.global_config
        self.db_host = c.get_config_value(self.database, "host")
        self.db_name = c.get_config_value(self.database, "database")
        self.username = c.get_config_value(self.database, "user")
        self.password = c.get_config_value(self.database, "password")


    def connect(self, host, db_name, username, password):
        """Return a new MySQLdb connection for the given credentials."""
        return MySQLdb.connect(host=host, db=db_name, user=username,
                               passwd=password)


    def open_connection(self):
        """Connect using the stored settings, enable autocommit and create
        the cursor used by execute()."""
        self.connection = self.connect(self.db_host, self.db_name,
                                       self.username, self.password)
        self.connection.autocommit(True)
        self.cursor = self.connection.cursor()


    def close_connection(self):
        """Close the current connection."""
        self.connection.close()


    def execute(self, query, *parameters):
        """Run query on the cursor with parameters as its bind values and
        return whatever cursor.execute() returns."""
        return self.cursor.execute(query, parameters)


    def execute_script(self, script):
        """Execute every non-empty statement of a ';'-separated SQL script.

        The script is split on every ';' character, so a semicolon inside a
        string literal would also split the statement.
        """
        for statement in script.split(';'):
            statement = statement.strip()
            if statement:
                self.execute(statement)


    def check_migrate_table_exists(self):
        """Return True if MIGRATE_TABLE can be selected from, False if MySQL
        reports that the table does not exist.  Other errors propagate."""
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except MySQLdb.ProgrammingError as exc:
            # Only look at the error code; unpacking args into a fixed
            # number of names would hide the real error behind a ValueError
            # if args had a different shape.
            if exc.args and exc.args[0] == MySQLdb.constants.ER.NO_SUCH_TABLE:
                return False
            raise


    def create_migrate_table(self):
        """Create MIGRATE_TABLE (or empty it if present) and set version 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self.cursor.rowcount == 1


    def set_db_version(self, version):
        """Store version (an int) as the database's schema version."""
        assert isinstance(version, int)
        # %%s survives the table-name interpolation as the bind placeholder
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self.cursor.rowcount == 1


    def get_db_version(self):
        """Return the stored schema version, or 0 when MIGRATE_TABLE is
        missing or empty."""
        if not self.check_migrate_table_exists():
            return 0
        self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        rows = self.cursor.fetchall()
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return the Migration objects found in migrations_dir, sorted by
        filename and limited to the inclusive version range given by
        minimum_version / maximum_version (None means unbounded)."""
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration(filename) for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Apply one migration up or down and record the resulting version.

        Going up requires the database to be at migration.version - 1; going
        down requires it to be at migration.version.
        """
        if migrate_up:
            direction = 'up'
        else:
            direction = 'down'
        # Printed before the version check so a failed assertion shows which
        # migration was being attempted.
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Migrate up or down, one step at a time, until the database is at
        the given version."""
        current_version = self.get_db_version()
        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        # migrations numbered lower+1 .. upper are the ones to (un)apply
        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)


    def get_latest_version(self):
        """Return the version of the highest-numbered migration file.

        Raises IndexError if migrations_dir contains no migration files.
        """
        migrations = self.get_migrations()
        if not migrations:
            # Same exception type as indexing the empty list would give, but
            # with a message that says what is actually wrong.
            raise IndexError('no migration files found in %s'
                             % self.migrations_dir)
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrate the database to the latest available version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Create the database 'test_<db_name>' and connect to it; db_name
        is left pointing at the test database."""
        self.read_db_info()
        test_db_name = 'test_' + self.db_name
        # first, connect to no DB so we can create a test DB
        self.db_name = ''
        self.open_connection()
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self.close_connection()
        # now connect to the test DB
        self.db_name = test_db_name
        self.open_connection()


    def remove_test_db(self):
        """Drop the database currently named by db_name."""
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.db_name)


    @staticmethod
    def _shell_quote(value):
        """Return value as a single-quoted shell word, with embedded single
        quotes escaped, so the shell passes it through literally."""
        return "'" + ('%s' % (value,)).replace("'", "'\\''") + "'"


    @staticmethod
    def _run_shell(command, description):
        """Run command through the shell and raise RuntimeError if it exits
        with a non-zero status.

        The command itself is kept out of the error message because it
        contains the database password.
        """
        status = os.system(command)
        if status != 0:
            raise RuntimeError('%s failed (exit status %s)'
                               % (description, status))


    def get_mysql_args(self):
        """Return the user/password/host/database arguments for the mysql
        command line tools, as a string to embed in a shell command.

        Each value is shell-quoted so that credentials containing spaces or
        shell metacharacters reach the tool unchanged.
        """
        quote = self._shell_quote
        return ('-u %(user)s -p%(password)s -h %(host)s %(db)s' % {
                'user' : quote(self.username),
                'password' : quote(self.password),
                'host' : quote(self.db_host),
                'db' : quote(self.db_name)})


    def migrate_to_version_or_latest(self, version):
        """Migrate to version, or to the latest version if it is None."""
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrate the configured database to version (default: latest)."""
        self.read_db_info()
        self.open_connection()
        print('Migration starting for database %s' % self.db_name)
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """\
        Create a fresh DB and run all migrations on it.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.db_name)
            self.migrate_to_version_or_latest(version)
            # show schema to the user; this is informational only, so its
            # exit status is not checked
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """\
        Create a fresh DB, copy the existing DB to it, and then
        try to synchronize it.

        Raises RuntimeError if dumping the existing data or loading it into
        the test DB fails, since a migration run on an incomplete copy would
        prove nothing.
        """
        self.read_db_info()
        self.open_connection()
        db_version = self.get_db_version()
        self.close_connection()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        os.close(dump_fd)
        quoted_dump_file = self._shell_quote(dump_file)
        try:
            self._run_shell('mysqldump %s >%s' %
                            (self.get_mysql_args(), quoted_dump_file),
                            'mysqldump of existing data')
            # fill in test DB
            self.initialize_test_db()
            try:
                print('Filling in test DB')
                self._run_shell('mysql %s <%s' %
                                (self.get_mysql_args(), quoted_dump_file),
                                'loading dump into test DB')
                print('Starting migration test on DB %s' % self.db_name)
                self.migrate_to_version_or_latest(version)
            finally:
                # the test DB is dropped even if loading the dump failed
                self.remove_test_db()
        finally:
            # never leave a copy of the data behind in the temp directory
            os.remove(dump_file)
        print('Test finished successfully')
274
275
# Help text printed by main() when no action, or an unrecognised action, is
# given on the command line.  The program name is filled in from sys.argv[0].
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
 -d --database Which database to act on
 -a --action Which action to perform"""\
 % sys.argv[0]
mblighe8819cd2008-02-15 16:48:40 +0000282
283
def main():
    """Command-line entry point: parse the arguments and run one action.

    The action is the first positional argument (sync, test, simulate or
    safesync) and an optional second positional argument gives the target
    version; without it the latest version is used.  The -a/--action option
    is parsed but not consulted.

    The usage text is printed, and nothing else is done, when the action is
    missing or unknown or when the version is not an integer.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database")
    parser.add_option("-a", "--action", help="what action to perform",
                      dest="action")
    (options, args) = parser.parse_args()

    # Validate the command line before building the manager: constructing it
    # asserts that the migrations directory exists, which must not get in the
    # way of simply showing the usage text.
    if not args or args[0] not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return
    action = args[0]

    version = None
    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            print('Invalid version: %s' % args[1])
            print(USAGE)
            return

    manager = MigrationManager(options.database)
    if action == 'sync':
        manager.do_sync_db(version)
    elif action == 'test':
        manager.test_sync_db(version)
    elif action == 'simulate':
        manager.simulate_sync_db(version)
    else:
        # safesync: only touch the real database once the simulation on a
        # copy of it has run through
        print('Simulating migration')
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.do_sync_db(version)
315
316
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()