blob: 53903c9e3b16ee2bec0d44cf6c5a50b0bf674d4d [file] [log] [blame]
showard89f84db2009-03-12 20:39:13 +00001import re, time, traceback
showard0e73c852008-10-03 10:15:50 +00002import common
3from autotest_lib.client.common_lib import global_config
4
# Sentinel for DatabaseConnection.max_reconnect_attempts that removes the
# reconnection limit; it is compared by identity in _reached_max_attempts().
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes mirrored from a database module onto
# backend and connection objects, so they can be referenced as e.g.
# backend.OperationalError regardless of the underlying driver.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')
# Connection option name -> key used in the global config, for the options
# whose names differ; any option not listed uses the same name in both.
_GLOBAL_CONFIG_NAMES = {
    'username' : 'user',
    'db_name' : 'database',
}


def _copy_exceptions(source, destination):
    """
    Copy the exception classes named in _DB_EXCEPTIONS from source onto
    destination as attributes of the same name.

    Any class that source does not define is replaced by
    source.DatabaseError.

    @param source: object (a DB-API module or a backend) exposing the
            exception classes as attributes.
    @param destination: object on which the attributes are set.
    @raises AttributeError: if source lacks both the requested class and
            DatabaseError.
    """
    for exception_name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, exception_name)
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Let's just mock these classes with the base DatabaseError.
            exception_class = getattr(source, 'DatabaseError')
        # setattr() stays outside the try so an AttributeError raised by the
        # destination is not mistaken for a class missing from the source.
        setattr(destination, exception_name, exception_class)
showard0e73c852008-10-03 10:15:50 +000024
25
class _GenericBackend(object):
    """
    Base backend holding one connection and one cursor on top of a
    DB-API style module.

    Subclasses provide connect(); disconnect() and execute() are shared.
    """
    def __init__(self, database_module):
        """
        @param database_module: module (or module-like object) whose
                exception classes are copied onto this backend.
        """
        self._database_module = database_module
        self._connection = self._cursor = None
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection and cursor. Implementations are expected to turn
        on autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection if there is one, then forget it and the cursor."""
        connection = self._connection
        if connection:
            connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """
        Execute query on the cursor and return cursor.fetchall().

        @param parameters: query parameters; None is passed as an empty tuple.
        Stores the cursor's rowcount in self.rowcount.
        """
        cursor = self._cursor
        cursor.execute(query, () if parameters is None else parameters)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()
55
56
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL, built on the MySQLdb driver."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """
        Convert a boolean to the integer string '0' or '1'.

        Used as a MySQLdb converter; conversion_dict is part of the converter
        call signature and is not used.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to MySQL with autocommit enabled.

        Booleans are sent through convert_boolean unless MySQLdb's default
        conversions already define a converter for bool.
        """
        import MySQLdb.converters
        # Copy the defaults: MySQLdb.converters.conversions is module-global
        # state shared by every MySQLdb connection in the process, so calling
        # setdefault() on it directly would leak our converter into them.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
79
80
class _SqliteBackend(_GenericBackend):
    """
    Backend for SQLite. It translates the MySQL-style query features this
    module relies on ('%s' placeholders, LAST_INSERT_ID()).
    """
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # The standard library's sqlite3 module is the same pysqlite
            # DB-API implementation, so use it when pysqlite2 is absent.
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # \b rather than \s, so the call is also rewritten at the very start
        # of a query or right after punctuation, e.g.
        # "SELECT(LAST_INSERT_ID())".
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open db_name (a file path or ':memory:'); other arguments are ignored."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Rewrite MySQL-isms in query, then execute it via the generic path."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
showard0e73c852008-10-03 10:15:50 +0000108
109
class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses the connection object from django.db instead of
    opening its own; the arguments to connect() are ignored.

    NOTE(review): the inherited disconnect() closes this shared Django
    connection — confirm that is intended.
    """
    def __init__(self):
        # Only the connection and transaction objects are needed here.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Adopt Django's connection and open a cursor on it."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute the query via the generic cursor path. Afterwards, call
        transaction.commit_unless_managed(), whether or not the query raised.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
130
131
# db_type names accepted by DatabaseConnection.connect(), mapped to the
# backend class instantiated for each.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
137
138
class DatabaseConnection(object):
    """
    Wrapper around one database connection, hiding which backend
    (mysql, sqlite or django) is in use.

    Public attributes:
    * reconnect_enabled: when True, an OperationalError makes the object try
      to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to sleep before each reconnection attempt.
    * max_reconnect_attempts: how many reconnection attempts to make before
      giving up; RECONNECT_FOREVER removes the limit.
    * rowcount - cursor.rowcount from the most recent execute().
    * global_config_section - global config section holding the database
      settings. Give it to the constructor rather than assigning it later; if
      it is None, the settings must be passed to connect().
    * debug - when True, every query is printed before it is executed.
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section to read the
                connection settings from, or None.
        @param debug: if True, print each query before executing it.
        """
        self.global_config_section = global_config_section
        self.debug = debug
        self._backend = None
        self.rowcount = None

        # Default reconnection policy.
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve one connection setting.

        An explicit value wins. Otherwise the global config is consulted when
        a section is set; failing that, the value already stored on self
        (None if absent) is used.
        """
        if provided_value is not None:
            return provided_value
        if not self.global_config_section:
            return getattr(self, name, None)
        config_key = _GLOBAL_CONFIG_NAMES.get(name, name)
        return global_config.global_config.get_config_value(
                self.global_config_section, config_key)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve every connection setting and store it on self."""
        supplied = (('db_type', db_type),
                    ('host', host),
                    ('username', username),
                    ('password', password),
                    ('db_name', db_name))
        for attribute, value in supplied:
            setattr(self, attribute, self._get_option(attribute, value))


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        @raises ValueError: if db_type is not a known backend name.
        """
        backend_class = _BACKEND_MAP.get(db_type)
        if backend_class is None:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds a finite max_reconnect_attempts."""
        limit = self.max_reconnect_attempts
        if limit is RECONNECT_FOREVER:
            return False
        return num_attempts > limit


    def _is_reconnect_enabled(self, supplied_param):
        """Use supplied_param when given, else self.reconnect_enabled."""
        return (self.reconnect_enabled if supplied_param is None
                else supplied_param)


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend. On OperationalError, retry every
        reconnect_delay_sec seconds while reconnection is enabled and the
        attempt limit has not been reached; otherwise re-raise.
        """
        attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
            except self._backend.OperationalError:
                attempts += 1
                if (not self._is_reconnect_enabled(try_reconnecting) or
                        self._reached_max_attempts(attempts)):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()
            else:
                return


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        (Re)connect to the database, closing any existing connection first.

        Arguments that are not None override the global config and any
        previously stored settings. try_reconnecting, if given, overrides
        self.reconnect_enabled for this call.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)
        self._backend = self._get_backend(self.db_type)
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        backend = self._backend
        if backend:
            backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Run query and return the backend's results (cursor.fetchall()).

        On OperationalError, if reconnection is enabled (try_reconnecting,
        if given, overrides self.reconnect_enabled), reconnect and retry the
        query once.
        """
        if self.debug:
            print('Executing %s, %s' % (query, parameters))
        # No retry loop here: _connect_backend() already retries.
        try:
            rows = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            rows = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return rows


    def get_database_info(self):
        """
        Return a dict of the settings named in _DATABASE_ATTRIBUTES. The dict
        includes the password.
        """
        info = {}
        for attribute in self._DATABASE_ATTRIBUTES:
            info[attribute] = getattr(self, attribute)
        return info


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory returning a sqlite-backed instance with automatic
        reconnection disabled.

        @param file_path: sqlite database path; the default ':memory:' gives
                a temporary in-memory database.
        @param constructor_kwargs: passed through to the constructor.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard34ab0992009-10-05 22:47:57 +0000290
291
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary translators (for example regexp
    substitutions) to each query string before executing it. Useful for
    SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: list of callables to apply to each query
            string (in order). Each accepts a query string and returns a
            (possibly) modified query string.
        @param global_config_section: passed through to DatabaseConnection.
        @param debug: passed through to DatabaseConnection.
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run query through every translator in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator