blob: d6447540f4599c480f07dfa39081491cf08e311f [file] [log] [blame]
showardf208abc2008-04-01 23:57:11 +00001#!/usr/bin/python2.4
2
3import unittest
4import MySQLdb
5import migrate
6from common import global_config
7
# Section of the global config that supplies the MySQL connection settings.
# The database named in that section is never used directly: the tests work
# on its companion database, whose name carries a "test_" prefix.
CONFIG_DB = 'AUTOTEST_WEB'

# How many DummyMigration objects the tests build (versions 1..N).
NUM_MIGRATIONS = 3
13
class DummyMigration(object):
    """\
    Dummy migration class that records all migrations done in a class
    variable.

    Each instance has an integer ``version`` and a ``name`` of the form
    '%03d_test' (e.g. '002_test').  Running it up or down appends a
    (version, direction) tuple -- direction being 'up' or 'down' -- to the
    class-level ``migrations_done`` list, which is shared by all instances.
    """

    # Shared record of (version, direction) tuples, in the order performed.
    migrations_done = []

    def __init__(self, version):
        self.version = version
        self.name = '%03d_test' % version


    @classmethod
    def get_migrations_done(cls):
        """Return the shared list of (version, direction) tuples."""
        return cls.migrations_done


    @classmethod
    def clear_migrations_done(cls):
        """Reset the shared record.

        The attribute is rebound to a new list, so a list previously
        returned by get_migrations_done() keeps its old contents.
        """
        cls.migrations_done = []


    @classmethod
    def do_migration(cls, version, direction):
        """Append (version, direction) to the shared record."""
        cls.migrations_done.append((version, direction))


    def migrate_up(self, manager):
        """Record an 'up' migration of this version.

        Migration 1 additionally calls manager.create_migrate_table(); no
        other version touches *manager*.
        """
        self.do_migration(self.version, 'up')
        if self.version == 1:
            manager.create_migrate_table()


    def migrate_down(self, manager):
        """Record a 'down' migration of this version; *manager* is unused."""
        self.do_migration(self.version, 'down')
50
51
# The full ordered set of fake migrations: one DummyMigration for each
# version from 1 up to and including NUM_MIGRATIONS.
MIGRATIONS = [DummyMigration(version)
              for version in range(1, NUM_MIGRATIONS + 1)]
53
54
class TestableMigrationManager(migrate.MigrationManager):
    """MigrationManager wired to this module's DummyMigrations.

    Database info is read through the base class and the database name is
    then prefixed with 'test_'.  Migrations come from the module-level
    MIGRATIONS list instead of a migrations directory.
    """

    def __init__(self, database, migrations_dir=None):
        # MigrationManager.__init__ is not called; only these attributes
        # are set, with connection details left as None.
        self.database = database
        self.migrations_dir = migrations_dir
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None


    def read_db_info(self):
        """Read connection info via the base class, then switch the
        database name to its 'test_' counterpart."""
        migrate.MigrationManager.read_db_info(self)
        self.db_name = 'test_' + self.db_name


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Return the DummyMigrations with minimum_version <= version <=
        maximum_version, in ascending version order.

        A bound of None means "unbounded" on that side.  An explicit 0 is a
        real bound: the old `x or default` logic treated maximum_version=0
        as unspecified and returned every migration instead of none.

        A new list is always returned, so callers may reverse or otherwise
        mutate it without touching the module-level MIGRATIONS.
        """
        selected = list(MIGRATIONS)
        if minimum_version is not None:
            selected = [migration for migration in selected
                        if migration.version >= minimum_version]
        if maximum_version is not None:
            selected = [migration for migration in selected
                        if migration.version <= maximum_version]
        return selected
74
75
class MigrateManagerTest(unittest.TestCase):
    """Exercise TestableMigrationManager against a scratch MySQL database.

    setUp creates the 'test_'-prefixed database and tearDown drops it again,
    so every test starts from an empty database.
    """

    # Connection settings from the CONFIG_DB section of the global config;
    # these are evaluated once, when the class body executes.
    config = global_config.global_config
    host = config.get_config_value(CONFIG_DB, 'host')
    db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database')
    user = config.get_config_value(CONFIG_DB, 'user')
    password = config.get_config_value(CONFIG_DB, 'password')

    def do_sql(self, sql):
        """Execute one statement on a new autocommit connection.

        No database is selected on the connection, which makes it usable
        for CREATE/DROP DATABASE.  The connection is closed even when
        enabling autocommit, creating the cursor or executing fails (the
        original only protected the execute() call).
        """
        self.con = MySQLdb.connect(host=self.host, user=self.user,
                                   passwd=self.password)
        try:
            self.con.autocommit(True)
            self.cur = self.con.cursor()
            self.cur.execute(sql)
        finally:
            self.con.close()


    def remove_db(self):
        """Drop the scratch test database."""
        self.do_sql('DROP DATABASE ' + self.db_name)


    def setUp(self):
        self.do_sql('CREATE DATABASE ' + self.db_name)
        try:
            self.manager = TestableMigrationManager(CONFIG_DB)
        except MySQLdb.OperationalError:
            # Don't leave the freshly created database behind on failure.
            self.remove_db()
            raise
        DummyMigration.clear_migrations_done()


    def tearDown(self):
        self.remove_db()


    def test_sync(self):
        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        # Expected order derived from NUM_MIGRATIONS rather than hard-coded
        # for three migrations.
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'up')
                          for version in range(1, NUM_MIGRATIONS + 1)])

        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        self.assertEqual(DummyMigration.get_migrations_done(),
                         [(version, 'down')
                          for version in range(NUM_MIGRATIONS, 0, -1)])


    def test_sync_one_by_one(self):
        for version in range(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version, 'up'))

        for version in range(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            # Going down to `version` undoes migration `version + 1`.
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version + 1, 'down'))


    def test_null_sync(self):
        # Syncing an already up-to-date database must run no migrations.
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])
148
149
if __name__ == '__main__':
    # Load and run every TestCase defined in this module when executed as a
    # script.
    unittest.main(module='__main__')