blob: 90b43e53a02d23e2d4f0942cc5b8aa7aceaebc56 [file] [log] [blame]
Aviv Keshet47539902013-06-21 15:29:31 -07001# pylint: disable-msg=C0111
2
showard89f84db2009-03-12 20:39:13 +00003import re, time, traceback
showard0e73c852008-10-03 10:15:50 +00004import common
5from autotest_lib.client.common_lib import global_config
6
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions() mirrors onto
# backend and connection objects.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# Option names whose global-config key differs from the attribute name;
# any option not listed here uses the same name in both places.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')
14
def _copy_exceptions(source, destination):
    """Mirror the DB-API exception classes of `source` onto `destination`.

    Every name in _DB_EXCEPTIONS is looked up on `source` and stored under the
    same attribute name on `destination`. A name that `source` lacks is
    filled in with `source.DatabaseError` instead.

    @param source: object exposing the exception classes (a database module
            or a backend instance).
    @param destination: object that receives the exception attributes.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
        except AttributeError:
            # Under the django backend: Django 1.3 has no OperationalError or
            # ProgrammingError, so stand in the base DatabaseError for them.
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exception_class)
showard0e73c852008-10-03 10:15:50 +000026
27
class _GenericBackend(object):
    """Shared DB-API plumbing for the concrete backends.

    Subclasses implement connect(), which must store a connection in
    self._connection and a cursor in self._cursor. This class provides
    disconnect() and execute() on top of those two attributes.
    """
    def __init__(self, database_module):
        """
        @param database_module: the DB-API module (or module-like object)
                wrapped by this backend. Its exception classes are copied
                onto this instance.
        """
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount of the most recent execute(); None before any query.
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection; must be overridden by subclasses.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the open connection (if any) and forget connection and cursor."""
        if self._connection:
            self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """Run `query` on the current cursor and return all resulting rows.

        @param query: SQL text handed to cursor.execute().
        @param parameters: bind parameters; None is sent as an empty tuple.
        @return: the value of cursor.fetchall(). self.rowcount is updated
                from cursor.rowcount.
        """
        bind_args = () if parameters is None else parameters
        cursor = self._cursor
        cursor.execute(query, bind_args)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()
57
58
class _MySqlBackend(_GenericBackend):
    """Backend that talks to MySQL through the MySQLdb driver."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert booleans to integer strings ('1' / '0').

        @param boolean: the value to encode.
        @param conversion_dict: unused here; presumably present because
                MySQLdb invokes encoders as (value, conversion_dict) --
                confirm against the driver version in use.
        """
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open a MySQL connection with autocommit enabled.

        @param host: MySQL server host.
        @param username: user name (passed to MySQLdb as `user`).
        @param password: password (passed to MySQLdb as `passwd`).
        @param db_name: database name (passed to MySQLdb as `db`).
        """
        import MySQLdb.converters
        # Work on a private copy. MySQLdb.converters.conversions is
        # module-global state shared by every MySQLdb user in the process, so
        # adding our bool encoder to it in place would leak into them.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
81
82
class _SqliteBackend(_GenericBackend):
    """Backend built on pysqlite2, falling back to the stdlib sqlite3 module."""
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # Matches MySQL's LAST_INSERT_ID() case-insensitively, together with
        # the single whitespace character in front of it.
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open the SQLite database `db_name` in autocommit mode.

        host, username and password are accepted only for interface
        compatibility with the other backends and are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        # An isolation_level of None switches the driver to autocommit.
        self._connection.isolation_level = None
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate a MySQL-flavoured query for SQLite and run it.

        Rewrites applied before delegating to _GenericBackend.execute():
          * '%s' placeholders become '?', since pysqlite2 uses
            paramstyle=qmark. TODO: make this more sophisticated if
            necessary -- the plain string replace also hits '%s' inside
            string literals.
          * LAST_INSERT_ID() becomes LAST_INSERT_ROWID(). sqlite3 has no
            LAST_INSERT_ID(), and LAST_INSERT_ROWID() does enough of what we
            want for our non-concurrent unittest use case.
        A None `parameters` is replaced by an empty tuple because pysqlite2
        can't handle parameters=None (it throws a nonsense exception).
        """
        if parameters is None:
            parameters = ()
        qmark_query = query.replace('%s', '?')
        sqlite_query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()',
                                                   qmark_query)
        return super(_SqliteBackend, self).execute(sqlite_query, parameters)
showard0e73c852008-10-03 10:15:50 +0000113
114
class _DjangoBackend(_GenericBackend):
    """Backend that reuses the connection Django already manages."""
    def __init__(self):
        # NOTE: `backend` is imported but not otherwise used in this class.
        from django.db import backend, connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Attach to Django's connection; every argument is ignored."""
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Run the query, then always call commit_unless_managed().

        The commit runs whether or not the query raised.
        """
        try:
            result = super(_DjangoBackend, self).execute(
                    query, parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
        return result
135
136
# Maps the db_type option to the backend class that
# DatabaseConnection._get_backend() instantiates for it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
142
143
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports the mysql, sqlite and
    django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section to read database
                information from, or None.
        @param debug: if True, every query is printed before it is executed.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        # Populate db_type/host/username/password/db_name from the config.
        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Get value of given option from global config.

        @param name: Name of the config.
        @param provided_value: Value being provided to override the one from
                global config.
        @param use_afe_setting: Force to use the settings in AFE, default is
                False.
        @return: provided_value if it is not None; otherwise the value from
                the AUTOTEST_WEB section (when use_afe_setting) or from
                self.global_config_section; otherwise, when no section
                applies, the current attribute `name` (or None).
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        section = ('AUTOTEST_WEB' if use_afe_setting else
                   self.global_config_section)
        if section:
            # Some options are stored under a different key in the config.
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    section, global_config_name)

        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Read database information, preferring explicitly passed values.

        db_name comes from its argument if given, otherwise from
        self.global_config_section. If none of db_type, host, username or
        password is truthy, any of those four that was not passed is read
        from the AFE database settings (AUTOTEST_WEB). Otherwise each comes
        from its argument or, if not passed, from self.global_config_section.
        When no config section applies, the current attribute value (or None)
        is kept.

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        use_afe_setting = not bool(db_type or host or username or password)

        # Database credential can be provided by the caller, as passed in from
        # function connect.
        self.db_type = self._get_option('db_type', db_type, use_afe_setting)
        self.host = self._get_option('host', host, use_afe_setting)
        self.username = self._get_option('username', username, use_afe_setting)
        self.password = self._get_option('password', password, use_afe_setting)


    def _get_backend(self, db_type):
        """Instantiate the backend class registered for `db_type`.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """Return True once num_attempts exceeds max_reconnect_attempts.

        Never True when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param unless it is None, else reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """Connect the current backend, retrying on OperationalError.

        Between attempts it prints the traceback, sleeps reconnect_delay_sec
        seconds and disconnects. The OperationalError is re-raised when
        reconnecting is disabled or max_reconnect_attempts has been exceeded.

        @param try_reconnecting: if not None, overrides self.reconnect_enabled.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.

        Any existing connection is closed first. A fresh backend for
        self.db_type is created and its exception classes are copied onto
        this object.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        On OperationalError (with reconnecting enabled) the connection is
        re-established and the query is retried once.
        """
        if self.debug:
            # Function-call form: valid under both Python 2 and Python 3,
            # matching the other print calls in this class.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the attributes named in _DATABASE_ATTRIBUTES."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary in-memory
        database.

        @param file_path: SQLite database path; ':memory:' by default.
        @param constructor_kwargs: forwarded to the class constructor.
        @return: a connected instance with reconnecting disabled.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard34ab0992009-10-05 22:47:57 +0000333
334
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that passes each query string through a sequence of
    translator callables (for example regexp substitutions built with
    make_regexp_translator()) before executing it. Useful for SQLite testing.
    """
    def __init__(self, translators, **constructor_kwargs):
        """
        @param translators: list of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param constructor_kwargs: extra keyword arguments (for example
                global_config_section or debug) forwarded to
                DatabaseConnection.__init__. This also lets
                get_test_database(translators=..., ...) pass them through.
        """
        super(TranslatingDatabase, self).__init__(**constructor_kwargs)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to `query` in order, then execute it."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator