blob: d89f808c63ab720eadc3d5d47526d0e573ea88b9 [file] [log] [blame]
mblighb5ec4872008-03-14 22:40:37 +00001#!/usr/bin/python2.4 -u
mblighe8819cd2008-02-15 16:48:40 +00002
3import os, sys, re, subprocess, tempfile
4import MySQLdb, MySQLdb.constants.ER
mblighb090f142008-02-27 21:33:46 +00005from optparse import OptionParser
6from common import global_config
mblighe8819cd2008-02-15 16:48:40 +00007
# Table that records the schema version the database is currently at.
MIGRATE_TABLE = 'migrate_info'
# Directory searched for NNN_*.py migration scripts when none is given;
# MigrationManager resolves it with os.path.abspath.
DEFAULT_MIGRATIONS_DIR = 'migrations'
10
class Migration(object):
    """One numbered migration script.

    ``filename`` looks like ``NNN_description.py``: the leading three
    digits give the version, and the name without ``.py`` is imported as a
    module that must provide ``migrate_up(manager)`` and
    ``migrate_down(manager)``.
    """

    def __init__(self, filename):
        version_digits = filename[:3]
        module_name = filename[:-3]
        self.version = int(version_digits)
        self.name = module_name
        # imported by bare module name, so the script's directory has to be
        # on sys.path already
        loaded = __import__(module_name, globals(), locals(), [])
        self.module = loaded
        for required in ('migrate_up', 'migrate_down'):
            assert hasattr(loaded, required)


    def migrate_up(self, manager):
        """Run the script's forward step against ``manager``."""
        self.module.migrate_up(manager)


    def migrate_down(self, manager):
        """Run the script's reverse step against ``manager``."""
        self.module.migrate_down(manager)
mblighe8819cd2008-02-15 16:48:40 +000026
27
class MigrationManager(object):
    """Bring a MySQL database schema to a given migration version.

    The current schema version is kept in the single-row MIGRATE_TABLE
    table.  Migrations are the NNN_*.py scripts in ``migrations_dir`` (see
    Migration), applied strictly one version at a time.  Connection
    settings come from the global_config section named by ``database``.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database, migrations_dir=None):
        """
        @param database: global_config section holding the host, database,
                user and password settings.
        @param migrations_dir: directory of NNN_*.py migration scripts;
                defaults to DEFAULT_MIGRATIONS_DIR resolved against the
                current working directory.
        """
        self.database = database
        if migrations_dir is None:
            migrations_dir = os.path.abspath(DEFAULT_MIGRATIONS_DIR)
        self.migrations_dir = migrations_dir
        # Migration imports each script by module name, so the directory
        # must be importable
        sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir)

        # filled in by read_db_info()
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None


    def read_db_info(self):
        """Load host/database/user/password from the global_config section."""
        c = global_config.global_config
        self.db_host = c.get_config_value(self.database, "host")
        self.db_name = c.get_config_value(self.database, "database")
        self.username = c.get_config_value(self.database, "user")
        self.password = c.get_config_value(self.database, "password")


    def connect(self, host, db_name, username, password):
        """Open and return a new MySQLdb connection."""
        return MySQLdb.connect(host=host, db=db_name, user=username,
                               passwd=password)


    def open_connection(self):
        """Connect with the stored settings; every statement autocommits."""
        self.connection = self.connect(self.db_host, self.db_name,
                                       self.username, self.password)
        self.connection.autocommit(True)
        self.cursor = self.connection.cursor()


    def close_connection(self):
        """Close the connection opened by open_connection()."""
        self.connection.close()


    def execute(self, query, *parameters):
        """Execute one SQL statement on the current cursor.

        NOTE(review): the parameter tuple is always passed to the driver,
        even when empty, so presumably MySQLdb %-formats every query and a
        literal '%' must be written as '%%' -- confirm before changing.
        """
        return self.cursor.execute(query, parameters)


    def execute_script(self, script):
        """Execute every non-blank ';'-separated statement in ``script``.

        The split is purely textual: a ';' inside a string literal or a
        comment would cut that statement in two.
        """
        sql_statements = [statement.strip() for statement
                          in script.split(';')]
        for statement in sql_statements:
            if statement:
                self.execute(statement)


    def check_migrate_table_exists(self):
        """Return True if MIGRATE_TABLE exists, False on NO_SUCH_TABLE.

        Any other ProgrammingError is re-raised.
        """
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except MySQLdb.ProgrammingError:
            # sys.exc_info() rather than "except X, e" / "except X as e"
            # keeps this valid on both old and new Python versions
            exc = sys.exc_info()[1]
            error_code = exc.args[0]
            if error_code == MySQLdb.constants.ER.NO_SUCH_TABLE:
                return False
            raise


    def create_migrate_table(self):
        """Create (or empty) MIGRATE_TABLE and reset it to version 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self.cursor.rowcount == 1


    def set_db_version(self, version):
        """Store ``version`` (an int) in the single MIGRATE_TABLE row."""
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self.cursor.rowcount == 1


    def get_db_version(self):
        """Return the stored version; 0 if the table is missing or empty."""
        if not self.check_migrate_table_exists():
            return 0
        self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        rows = self.cursor.fetchall()
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return Migration objects sorted by version.

        Only files named NNN_*.py are considered.  The bounds, when given,
        are inclusive.
        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration(filename) for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Apply one migration up or down, then record the new version.

        The stored version must be exactly one step from the migration's
        version, so migrations can only be applied in sequence.
        """
        if migrate_up:
            print('Applying migration %s up' % migration.name)
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            print('Applying migration %s down' % migration.name)
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Migrate up or down, one step at a time, until at ``version``."""
        current_version = self.get_db_version()
        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            # undo the newest migration first
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)


    def get_latest_version(self):
        """Return the highest available migration version (0 if none)."""
        migrations = self.get_migrations()
        if not migrations:
            # nothing to apply: the empty schema is already the latest
            return 0
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrate to the highest available migration version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Create 'test_<db_name>' and leave the connection open on it."""
        self.read_db_info()
        test_db_name = 'test_' + self.db_name
        # first, connect to no DB so we can create a test DB
        self.db_name = ''
        self.open_connection()
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self.close_connection()
        # now connect to the test DB
        self.db_name = test_db_name
        self.open_connection()


    def remove_test_db(self):
        """Drop the current (test) database and close the connection."""
        print('Removing test DB')
        try:
            self.execute('DROP DATABASE ' + self.db_name)
        finally:
            self.close_connection()


    def _mysql_arg_list(self):
        """Return the mysql/mysqldump connection arguments as an argv list.

        -p is omitted when no password is configured, because a bare '-p'
        makes the client prompt for one interactively.
        NOTE(review): a password on the command line is visible to other
        local users via the process list.
        """
        args = ['-u', str(self.username)]
        if self.password:
            args.append('-p' + str(self.password))
        args.extend(['-h', str(self.db_host), str(self.db_name)])
        return args


    def get_mysql_args(self):
        """Return the connection arguments as one space-separated string."""
        return ' '.join(self._mysql_arg_list())


    @staticmethod
    def _run_command(argv, stdin_path=None, stdout_path=None):
        """Run ``argv`` without a shell and return its exit status.

        If given, ``stdin_path`` is attached to the child's stdin and
        ``stdout_path`` receives its stdout (the '<' / '>' a shell would do).
        """
        stdin_file = None
        stdout_file = None
        try:
            if stdin_path is not None:
                stdin_file = open(stdin_path, 'rb')
            if stdout_path is not None:
                stdout_file = open(stdout_path, 'wb')
            return subprocess.call(argv, stdin=stdin_file, stdout=stdout_file)
        finally:
            for handle in (stdin_file, stdout_file):
                if handle is not None:
                    handle.close()


    def _run_mysql_tool(self, program, extra_args=(), stdin_path=None,
                        stdout_path=None):
        """Run a MySQL client ``program`` against the current database.

        Raises RuntimeError if the program exits with a non-zero status.
        """
        argv = [program] + self._mysql_arg_list() + list(extra_args)
        status = self._run_command(argv, stdin_path, stdout_path)
        if status != 0:
            raise RuntimeError('%s exited with status %d' % (program, status))


    def migrate_to_version_or_latest(self, version):
        """Migrate to ``version``, or to the latest one when it is None."""
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrate the configured database to ``version`` (default latest)."""
        self.read_db_info()
        self.open_connection()
        print('Migration starting for database %s' % self.db_name)
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """\
        Create a fresh DB and run all migrations on it.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.db_name)
            self.migrate_to_version_or_latest(version)
            # show schema to the user; informational only, so a failing
            # mysqldump is reported rather than treated as fatal
            argv = (['mysqldump'] + self._mysql_arg_list() +
                    ['--no-data=true', '--add-drop-table=false'])
            try:
                status = self._run_command(argv)
            except OSError:
                status = None
            if status != 0:
                print('Warning: could not dump the test DB schema')
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """\
        Create a fresh DB, copy the existing DB to it, and then
        try to synchronize it.

        Raises RuntimeError if the data cannot be dumped or loaded; the
        temporary dump file and test DB are cleaned up either way.
        """
        self.read_db_info()
        self.open_connection()
        db_version = self.get_db_version()
        self.close_connection()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.read_db_info()
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        os.close(dump_fd)
        try:
            # a failed dump must abort: otherwise the migrations could
            # appear to succeed against an empty copy, and a following real
            # sync would proceed on the strength of that
            self._run_mysql_tool('mysqldump', stdout_path=dump_file)
            # fill in test DB
            self.initialize_test_db()
            try:
                print('Filling in test DB')
                self._run_mysql_tool('mysql', stdin_path=dump_file)
                print('Starting migration test on DB %s' % self.db_name)
                self.migrate_to_version_or_latest(version)
            finally:
                self.remove_test_db()
        finally:
            os.remove(dump_file)
        print('Test finished successfully')
273
274
# Help text printed by main() for missing or unrecognised arguments.
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
    -d --database  Which database to act on
    -a --action    Which action to perform
Actions:
    sync      Migrate the database to [version] (default: the latest)
    test      Create an empty test database, migrate it to [version] and
              show the resulting schema
    simulate  Migrate a test copy of the database to [version]
    safesync  Run simulate, then sync if the simulation succeeded"""\
    % sys.argv[0]
mblighe8819cd2008-02-15 16:48:40 +0000281
282
def main():
    """Parse the command line and run the requested migration action.

    The action is the first positional argument, falling back to
    -a/--action when no positional argument is given; an optional second
    positional argument is the target version.  Prints USAGE for an
    unknown action or a non-integer version.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database")
    parser.add_option("-a", "--action", help="what action to perform",
                      dest="action")
    (options, args) = parser.parse_args()

    if args:
        action = args[0]
        version_args = args[1:]
    else:
        # USAGE advertises -a/--action, so honour it when no positional
        # action was given
        action = options.action
        version_args = []

    version = None
    if version_args:
        try:
            version = int(version_args[0])
        except ValueError:
            print(USAGE)
            return

    if action not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return

    manager = MigrationManager(options.database)
    if action == 'sync':
        manager.do_sync_db(version)
    elif action == 'test':
        manager.test_sync_db(version)
    elif action == 'simulate':
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy first; an exception there stops us
        # before the real database is touched
        print('Simulating migration')
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.do_sync_db(version)


if __name__ == '__main__':
    main()