blob: ec4ece1d36767b347144d2ad85e14288f74350b3 [file] [log] [blame]
showardf208abc2008-04-01 23:57:11 +00001#!/usr/bin/python2.4
2
3import unittest
4import MySQLdb
5import migrate
mbligh9b907d62008-05-13 17:56:24 +00006import common
7from autotest_lib.client.common_lib import global_config
showardf208abc2008-04-01 23:57:11 +00008
# Global-config section that supplies the DB connection settings. The tests
# never touch that database itself; they work on its "test_"-prefixed twin
# (test_<db name>) instead.
CONFIG_DB = "AUTOTEST_WEB"

# How many DummyMigration objects the fake migration list contains.
NUM_MIGRATIONS = 3
14
class DummyMigration(object):
    """\
    Fake migration that logs every step it performs.

    Each up/down step is appended as a (version, direction) tuple to the
    class-level list ``migrations_done``. Tests can then check the order in
    which migrations were applied.
    """

    # Shared step log; reset with clear_migrations_done().
    migrations_done = []

    def __init__(self, version):
        self.version = version
        # Zero-padded three-digit name, e.g. version 7 -> '007_test'.
        self.name = '%03d_test' % version


    @classmethod
    def get_migrations_done(cls):
        """Return the shared step log itself (the live list, not a copy)."""
        return cls.migrations_done


    @classmethod
    def clear_migrations_done(cls):
        """Reset the step log by binding a fresh empty list on the class."""
        cls.migrations_done = []


    @classmethod
    def do_migration(cls, version, direction):
        """Append one (version, direction) step to the shared log."""
        step = (version, direction)
        cls.migrations_done.append(step)


    def migrate_up(self, manager):
        """\
        Log an 'up' step for this version.

        Version 1 additionally calls manager.create_migrate_table().
        """
        self.do_migration(self.version, 'up')
        if self.version != 1:
            return
        manager.create_migrate_table()


    def migrate_down(self, manager):
        """Log a 'down' step for this version; manager is not used."""
        self.do_migration(self.version, 'down')
52
# One DummyMigration per version, 1 .. NUM_MIGRATIONS, in ascending order.
MIGRATIONS = list(map(DummyMigration, xrange(1, NUM_MIGRATIONS + 1)))
54
55
class TestableMigrationManager(migrate.MigrationManager):
    """\
    MigrationManager that serves the module-level MIGRATIONS list and
    targets the test_-prefixed copy of the configured database.
    """

    def __init__(self, database, migrations_dir=None):
        # Does not call migrate.MigrationManager.__init__; only the
        # attributes below are initialised here. The connection settings
        # start as None (presumably filled in later by read_db_info()).
        self.database = database
        self.migrations_dir = migrations_dir
        self.db_host = None
        self.db_name = None
        self.username = None
        self.password = None


    def read_db_info(self):
        """Read DB settings via the base class, then switch to test_<db>."""
        migrate.MigrationManager.read_db_info(self)
        self.db_name = 'test_' + self.db_name


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """\
        Return the MIGRATIONS entries whose version lies in the inclusive
        range [minimum_version, maximum_version], in ascending order.

        A bound of None leaves that side unbounded. A bound of 0 is honoured
        as a real value rather than being mistaken for "unset", so
        get_migrations(1, 0) returns [] instead of every migration.
        """
        return [migration for migration in MIGRATIONS
                if (minimum_version is None
                    or migration.version >= minimum_version)
                and (maximum_version is None
                     or migration.version <= maximum_version)]
75
76
class MigrateManagerTest(unittest.TestCase):
    """Exercise TestableMigrationManager against a scratch MySQL database."""

    # Connection settings are read from the CONFIG_DB section of the global
    # config when the class is defined. The tests run against the
    # test_-prefixed copy of the configured database.
    config = global_config.global_config
    host = config.get_config_value(CONFIG_DB, 'host')
    db_name = 'test_' + config.get_config_value(CONFIG_DB, 'database')
    user = config.get_config_value(CONFIG_DB, 'user')
    password = config.get_config_value(CONFIG_DB, 'password')

    def do_sql(self, sql):
        """\
        Execute a single SQL statement on a fresh autocommit connection.

        No database is selected on the connection. Both the cursor and the
        connection are always closed, even if setting up the cursor fails.
        """
        con = MySQLdb.connect(host=self.host, user=self.user,
                              passwd=self.password)
        try:
            con.autocommit(True)
            cur = con.cursor()
            try:
                cur.execute(sql)
            finally:
                cur.close()
        finally:
            con.close()


    def remove_db(self):
        """Drop the scratch test database."""
        self.do_sql('DROP DATABASE ' + self.db_name)


    def setUp(self):
        """Create the scratch database, the manager and a clean step log."""
        self.do_sql('CREATE DATABASE ' + self.db_name)
        try:
            self.manager = TestableMigrationManager(CONFIG_DB)
        except MySQLdb.OperationalError:
            # Don't leave the freshly created database behind on failure.
            self.remove_db()
            raise
        DummyMigration.clear_migrations_done()


    def tearDown(self):
        self.remove_db()


    def test_sync(self):
        """A full sync runs every 'up' in order; syncing to 0 undoes them."""
        # Derive the expected step lists from NUM_MIGRATIONS so the test
        # stays correct if the number of dummy migrations changes.
        expected_up = [(version, 'up')
                       for version in xrange(1, NUM_MIGRATIONS + 1)]
        expected_down = [(version, 'down')
                         for version in xrange(NUM_MIGRATIONS, 0, -1)]

        self.manager.do_sync_db()
        self.assertEqual(self.manager.get_db_version(), NUM_MIGRATIONS)
        self.assertEqual(DummyMigration.get_migrations_done(), expected_up)

        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db(0)
        self.assertEqual(self.manager.get_db_version(), 0)
        self.assertEqual(DummyMigration.get_migrations_done(), expected_down)


    def test_sync_one_by_one(self):
        """Stepping one version at a time records the expected last step."""
        # Upward: after syncing to `version`, the most recent step must be
        # that version's 'up'.
        for version in xrange(1, NUM_MIGRATIONS + 1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version, 'up'))

        # Downward to 0: reaching `version` means the most recent step
        # undid migration `version + 1`.
        for version in xrange(NUM_MIGRATIONS - 1, -1, -1):
            self.manager.do_sync_db(version)
            self.assertEqual(self.manager.get_db_version(), version)
            self.assertEqual(DummyMigration.get_migrations_done()[-1],
                             (version + 1, 'down'))


    def test_null_sync(self):
        """Syncing an already up-to-date database runs no migrations."""
        self.manager.do_sync_db()
        DummyMigration.clear_migrations_done()
        self.manager.do_sync_db()
        self.assertEqual(DummyMigration.get_migrations_done(), [])
149
150
# Run the test cases when executed as a script.
if __name__ == "__main__":
    unittest.main()