blob: 8012eaedde692f8c0673a9c0bf04888bdae1b014 [file] [log] [blame]
Aviv Keshet47539902013-06-21 15:29:31 -07001# pylint: disable-msg=C0111
2
showard89f84db2009-03-12 20:39:13 +00003import re, time, traceback
showard0e73c852008-10-03 10:15:50 +00004import common
5from autotest_lib.client.common_lib import global_config
6
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit on
# reconnect attempts" (compared by identity in _reached_max_attempts).
RECONNECT_FOREVER = object()

# DB-API exception class names mirrored onto backend and connection wrappers.
_DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError')

# Option names used by DatabaseConnection whose global-config key differs;
# options not listed here use their own name as the key.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')

def _copy_exceptions(source, destination):
    """
    Copy each exception class named in _DB_EXCEPTIONS from source onto
    destination as an attribute of the same name.

    Under the django backend, Django 1.3 does not provide OperationalError
    and ProgrammingError, so any missing name is aliased to
    source.DatabaseError instead.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
        except AttributeError:
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exception_class)
showard0e73c852008-10-03 10:15:50 +000026
27
class _GenericBackend(object):
    """
    Base class wrapping a DB-API style module with one connection and one
    cursor. Subclasses implement connect(); the DB exception classes of the
    wrapped module are exposed as attributes of the instance.
    """

    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # Populated from the cursor after every execute().
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection and cursor. Implementations are expected to
        enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection, if any, and forget the connection/cursor."""
        if not self._connection:
            return
        self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """
        Run query with parameters (an empty tuple when None), record the
        cursor's rowcount, and return all fetched rows.
        """
        cursor = self._cursor
        cursor.execute(query, () if parameters is None else parameters)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()
57
58
class _MySqlBackend(_GenericBackend):
    """Backend built on the MySQLdb module."""

    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert a bool to the integer string '1' or '0' for MySQL."""
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Connect to MySQL with autocommit enabled and open a cursor."""
        import MySQLdb.converters
        # NOTE(review): this is MySQLdb's module-level conversions table, not
        # a copy, so registering the bool converter affects every MySQLdb
        # connection in the process.
        conversions = MySQLdb.converters.conversions
        conversions.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=conversions)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
81
82
class _SqliteBackend(_GenericBackend):
    """
    Backend for SQLite, using pysqlite2 when installed and the standard
    library sqlite3 otherwise. Rewrites MySQL-flavoured queries on the fly.
    """

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # Matches a MySQL LAST_INSERT_ID() call preceded by whitespace.
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open db_name (e.g. a file path or ':memory:'); other args unused."""
        self._connection = self._database_module.connect(db_name)
        # An isolation_level of None puts the connection in autocommit mode.
        self._connection.isolation_level = None
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate query for SQLite, then execute it via the base class."""
        # pysqlite2 throws a nonsense exception when given parameters=None.
        if parameters is None:
            parameters = ()
        # pysqlite2 uses paramstyle=qmark; this naive rewrite assumes '%s'
        # only ever appears as a placeholder.
        # TODO: make this more sophisticated if necessary
        translated = query.replace('%s', '?')
        # sqlite3 lacks MySQL's LAST_INSERT_ID(); LAST_INSERT_ROWID() is
        # close enough for the non-concurrent unittest use case.
        translated = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()',
                                                 translated)
        return super(_SqliteBackend, self).execute(translated, parameters)
showard0e73c852008-10-03 10:15:50 +0000113
114
class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses django.db.connection instead of opening its own
    connection. The host/username/password/db_name arguments to connect()
    are ignored.
    """

    def __init__(self):
        # Only connection and transaction are used. The unused
        # django.db.backend import that used to be here is absent from later
        # Django releases and would make construction fail there.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Adopt Django's connection and open a cursor on it."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute via the base class, then commit (unless Django is managing the
        transaction) whether or not the query succeeded.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            # NOTE(review): commit_unless_managed() only exists in older
            # Django versions — confirm the Django release this runs against.
            self._django_transaction.commit_unless_managed()
135
136
# Maps the db_type option accepted by DatabaseConnection to its backend class.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
142
143
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports the mysql, sqlite and
    django backends (see _BACKEND_MAP).

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting before
      giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. This
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve one connection option. Precedence: an explicitly provided
        (non-None) value, then the global config section (if one was given),
        then the value already stored on self (None if never set).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve and store every connection option via _get_option()."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts failed connects exceed the configured limit."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Per-call override if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend, retrying on OperationalError every
        reconnect_delay_sec seconds while reconnecting is enabled and the
        attempt limit has not been exceeded; otherwise re-raise.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's DB exception classes on this object too.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            # Single parenthesized argument: prints the same on Python 2 and
            # is also valid Python 3 (the old print statement was not).
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the current connection options."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary sqlite
        database (in-memory by default), with reconnecting disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard34ab0992009-10-05 22:47:57 +0000295
296
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that passes each query string through a sequence of
    translator callables (for example regexp substitutions built with
    make_regexp_translator) before executing it. Useful for SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: list of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param global_config_section: forwarded to DatabaseConnection.
        @param debug: forwarded to DatabaseConnection.
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to query in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator
327 return translator