blob: b5386af2e49f1db0ae1305b70f54ca00f83ab19c [file] [log] [blame]
mbligh7c8ea992009-06-22 19:03:08 +00001#!/usr/bin/python
showard0e73c852008-10-03 10:15:50 +00002
3import unittest, time
showard0e73c852008-10-03 10:15:50 +00004import common
5from autotest_lib.client.common_lib import global_config
6from autotest_lib.client.common_lib.test_utils import mock
7from autotest_lib.database import database_connection
8
9_CONFIG_SECTION = 'TKO'
10_HOST = 'myhost'
11_USER = 'myuser'
12_PASS = 'mypass'
13_DB_NAME = 'mydb'
14_DB_TYPE = 'mydbtype'
15
16_CONNECT_KWARGS = dict(host=_HOST, username=_USER, password=_PASS,
17 db_name=_DB_NAME)
18_RECONNECT_DELAY = 10
19
class FakeDatabaseError(Exception):
    """Stand-in for every backend DB exception raised by the mock backend."""
22
23
class DatabaseConnectionTest(unittest.TestCase):
    """Unit tests for database_connection.DatabaseConnection.

    The real backend is replaced by a mock of _GenericBackend and
    time.sleep is stubbed out.  Each test records the backend calls it
    expects through self.god and then verifies them with check_playback().
    """

    def setUp(self):
        self.god = mock.mock_god()
        # Stub sleep so reconnect delays can be asserted on without waiting.
        self.god.stub_function(time, 'sleep')


    def tearDown(self):
        # Undo any override_config_value() calls made during the test.
        global_config.global_config.reset_config_values()
        self.god.unstub_all()


    def _get_database_connection(self, config_section=_CONFIG_SECTION):
        """Build a DatabaseConnection wired to a mock backend.

        @param config_section: global-config section handed to the
                DatabaseConnection constructor.  When it equals
                _CONFIG_SECTION, the fake connection settings are written
                into that section first.
        @returns the DatabaseConnection under test.  As side effects,
                self._fake_backend holds the mock backend and
                self._db_type records the db_type passed to _get_backend
                (None until that happens).
        """
        if config_section == _CONFIG_SECTION:
            self._override_config()
        db = database_connection.DatabaseConnection(config_section)

        self._fake_backend = self.god.create_mock_class(
                database_connection._GenericBackend, 'fake_backend')
        # Alias each exception attribute listed in _DB_EXCEPTIONS to
        # FakeDatabaseError, so the FakeDatabaseError raised by the
        # expectations below plays the role of a backend DB error.
        for exception in database_connection._DB_EXCEPTIONS:
            setattr(self._fake_backend, exception, FakeDatabaseError)
        self._fake_backend.rowcount = 0

        # Initialised so a missing _get_backend call shows up as a clean
        # assertion failure instead of an AttributeError.
        self._db_type = None

        def get_fake_backend(db_type):
            self._db_type = db_type
            return self._fake_backend
        self.god.stub_with(db, '_get_backend', get_fake_backend)

        db.reconnect_delay_sec = _RECONNECT_DELAY
        return db


    def _override_config(self):
        """Write the fake connection settings into _CONFIG_SECTION."""
        c = global_config.global_config
        c.override_config_value(_CONFIG_SECTION, 'host', _HOST)
        c.override_config_value(_CONFIG_SECTION, 'user', _USER)
        c.override_config_value(_CONFIG_SECTION, 'password', _PASS)
        c.override_config_value(_CONFIG_SECTION, 'database', _DB_NAME)
        c.override_config_value(_CONFIG_SECTION, 'db_type', _DB_TYPE)


    def test_connect(self):
        # Settings passed explicitly to connect(); no config section used.
        db = self._get_database_connection(config_section=None)
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
                   password=_PASS, db_name=_DB_NAME)

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def test_global_config(self):
        # Settings come from the overridden global config section.
        db = self._get_database_connection()
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect()

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()


    def _expect_reconnect(self, fail=False):
        """Expect a disconnect followed by a connect attempt.

        @param fail: if True, the expected connect call raises
                FakeDatabaseError.
        """
        self._fake_backend.disconnect.expect_call()
        call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        if fail:
            call.and_raises(FakeDatabaseError())


    def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
        """Expect a failing connect followed by num_reconnects retries.

        Each retry is expected to sleep for _RECONNECT_DELAY, then
        disconnect and connect again.  Every retry except the last one
        fails; the last one fails only if fail_last is True.

        @param num_reconnects: number of retries to expect; 0 means only
                the initial failing connect is expected.
        @param fail_last: whether the final retry also fails.
        """
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
                FakeDatabaseError())
        for i in range(num_reconnects):
            time.sleep.expect_call(_RECONNECT_DELAY)
            if i < num_reconnects - 1:
                self._expect_reconnect(fail=True)
            else:
                self._expect_reconnect(fail=fail_last)


    def test_connect_retry(self):
        db = self._get_database_connection()
        # The first attempt fails and a single retry succeeds.
        self._expect_fail_and_reconnect(1)

        db.connect()
        self.god.check_playback()

        # With try_reconnecting=False the failure propagates immediately;
        # a disconnect is expected before the new connect attempt.
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect,
                          try_reconnecting=False)
        self.god.check_playback()

        # Disabling reconnects on the object has the same effect.
        db.reconnect_enabled = False
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_max_reconnect(self):
        # All 5 allowed retries fail, so connect() gives up and raises.
        db = self._get_database_connection()
        db.max_reconnect_attempts = 5
        self._expect_fail_and_reconnect(5, fail_last=True)

        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()


    def test_reconnect_forever(self):
        # With no retry limit, connect() keeps going until the 30th retry
        # succeeds.
        db = self._get_database_connection()
        db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
        self._expect_fail_and_reconnect(30)

        db.connect()
        self.god.check_playback()


    def _simple_connect(self, db):
        """Connect db successfully and verify the recorded expectations."""
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        db.connect()
        self.god.check_playback()


    def test_disconnect(self):
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.disconnect.expect_call()

        db.disconnect()
        self.god.check_playback()


    def test_execute(self):
        # execute() forwards the query and its params to the backend.
        db = self._get_database_connection()
        self._simple_connect(db)
        params = object()
        self._fake_backend.execute.expect_call('query', params)

        db.execute('query', params)
        self.god.check_playback()


    def test_execute_retry(self):
        db = self._get_database_connection()
        self._simple_connect(db)
        # A failed execute triggers a reconnect and the query is retried.
        self._fake_backend.execute.expect_call('query', None).and_raises(
                FakeDatabaseError())
        self._expect_reconnect()
        self._fake_backend.execute.expect_call('query', None)

        db.execute('query')
        self.god.check_playback()

        # With try_reconnecting=False the failure propagates immediately.
        self._fake_backend.execute.expect_call('query', None).and_raises(
                FakeDatabaseError())
        self.assertRaises(FakeDatabaseError, db.execute, 'query',
                          try_reconnecting=False)
        # Verify the expectation above was actually consumed (and that no
        # reconnect was attempted).
        self.god.check_playback()
181
182
183if __name__ == '__main__':
184 unittest.main()