showard | 89f84db | 2009-03-12 20:39:13 +0000 | [diff] [blame] | 1 | import re, time, traceback |
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 2 | import common |
| 3 | from autotest_lib.client.common_lib import global_config |
| 4 | |
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit";
# it is compared by identity.
RECONNECT_FOREVER = object()

# Names of the exception classes that _copy_exceptions() copies from a
# database module (or backend) onto a wrapper object.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# Option names used by DatabaseConnection whose key in the global config
# differs; options not listed here use the same name in both places.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')
| 12 | |
def _copy_exceptions(source, destination):
    """
    Copy the exception classes named in _DB_EXCEPTIONS from source onto
    destination as attributes of the same name.

    @param source: object (database module or backend) providing the
            exception classes; it must at least provide DatabaseError.
    @param destination: object that receives the attributes.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
        except AttributeError:
            # Under the django backend:
            # Django 1.3 does not have OperationalError and ProgrammingError.
            # Stand in the base DatabaseError for whichever class is missing.
            exception_class = getattr(source, 'DatabaseError')
        setattr(destination, name, exception_class)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 24 | |
| 25 | |
class _GenericBackend(object):
    """
    Common base class for the database backends.

    Keeps the database module, the current connection and cursor, and the
    rowcount of the most recent execute().  The exception classes named in
    _DB_EXCEPTIONS are copied from the database module onto the instance.
    """
    def __init__(self, database_module):
        """
        @param database_module: module (or module-like object) supplying the
                database exception classes; subclasses also use it to connect.
        """
        self.rowcount = None
        self._cursor = None
        self._connection = None
        self._database_module = database_module
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the connection and create the cursor.  Must be overridden.

        This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """
        Close the current connection (if there is one) and drop the
        references to the connection and its cursor.
        """
        connection = self._connection
        if connection:
            connection.close()
        self._connection = self._cursor = None


    def execute(self, query, parameters=None):
        """
        Run query on the cursor, record the cursor's rowcount on
        self.rowcount and return cursor.fetchall().

        @param parameters: query parameters; None is treated as an empty
                tuple.
        """
        cursor = self._cursor
        cursor.execute(query, () if parameters is None else parameters)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()
| 55 | |
| 56 | |
class _MySqlBackend(_GenericBackend):
    """
    Backend that connects through the MySQLdb module.
    """
    def __init__(self):
        # MySQLdb is imported here rather than at module level, so it is only
        # needed when this backend is actually instantiated.
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """
        Convert a boolean to an integer string ('1' or '0').

        conversion_dict is accepted only because MySQLdb converters are
        called with it; it is not used.
        """
        as_integer = int(boolean)
        return str(as_integer)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to the MySQL server, enable autocommit and create the cursor.
        """
        from MySQLdb import converters
        # NOTE: this is MySQLdb's module-level conversion table, so adding
        # the bool converter here is visible to every user of that table.
        conversions = converters.conversions
        if bool not in conversions:
            conversions[bool] = self.convert_boolean

        self._connection = connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=conversions)
        connection.autocommit(True)
        self._cursor = connection.cursor()
| 79 | |
| 80 | |
class _SqliteBackend(_GenericBackend):
    """
    Backend for SQLite databases.

    Queries are expected in the MySQL flavour used by the rest of the code
    and are minimally rewritten for SQLite in execute().
    """
    def __init__(self):
        # Prefer the standalone pysqlite2 package when it is installed, but
        # fall back to the standard library's sqlite3 module, which exposes
        # the same dbapi2 interface, so this backend also works without it.
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        self._last_insert_id_re = re.compile(r'\sLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the SQLite database named by db_name (a file path, or
        ':memory:') with autocommit enabled.  host, username and password
        are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Translate query from MySQL conventions to SQLite ones, then execute
        it and return cursor.fetchall().

        @param parameters: query parameters; None is treated as an empty
                tuple.
        """
        # The sqlite DB-API module uses paramstyle=qmark.
        # TODO: make this more sophisticated if necessary
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID().  Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub(' LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 108 | |
| 109 | |
class _DjangoBackend(_GenericBackend):
    """
    Backend that reuses Django's own database connection instead of opening
    a new one.
    """
    def __init__(self):
        # Only the names that are actually used are imported, so that
        # constructing the backend does not depend on anything else being
        # exported by django.db.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Adopt Django's connection and create a cursor on it.  All arguments
        are ignored.
        """
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute query and return cursor.fetchall().

        transaction.commit_unless_managed() is called afterwards whether or
        not the query raised.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
| 130 | |
| 131 | |
# Maps each db_type string accepted by DatabaseConnection to the backend
# class that implements it.
_BACKEND_MAP = dict([
    ('mysql', _MySqlBackend),
    ('sqlite', _SqliteBackend),
    ('django', _DjangoBackend),
])
| 137 | |
| 138 | |
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection.  Supports the mysql, sqlite
    and django backends (the keys of _BACKEND_MAP).

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up.  Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information.  this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed

    After connect(), the exception classes named in _DB_EXCEPTIONS are also
    available as attributes of the instance.
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section to read the
                database options from, or None.
        @param debug: if True, print every query before executing it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve one database option.  An explicitly provided value wins;
        otherwise the global config is consulted when a section was given;
        otherwise the value already stored on the instance (or None) is kept.
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """
        Resolve every database option through _get_option() and store the
        results as attributes of the same names.
        """
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Return a new backend instance for db_type.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        Return True if num_attempts failed attempts exceed
        max_reconnect_attempts; never True for RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """
        Return supplied_param unless it is None, in which case fall back to
        self.reconnect_enabled.
        """
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend, retrying on OperationalError.

        Between attempts the traceback is printed, the method sleeps for
        reconnect_delay_sec and the backend is disconnected.  The error is
        re-raised when reconnecting is disabled or the attempt limit has been
        exceeded.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config.  try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """
        Disconnect the backend, if one has been created.
        """
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall().  try_reconnecting, if
        passed, will override self.reconnect_enabled.

        If the backend raises OperationalError and reconnecting is enabled,
        the connection is re-established and the query is run once more.
        """
        if self.debug:
            # Parenthesised so this line is valid as both a Python 2 print
            # statement and a Python 3 print() call.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """
        Return a dict mapping each name in _DATABASE_ATTRIBUTES to its
        current value on this instance.
        """
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary
        database: an sqlite connection to file_path (in-memory by default)
        with reconnecting disabled.  constructor_kwargs are passed to the
        class constructor.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard | 34ab099 | 2009-10-05 22:47:57 +0000 | [diff] [blame] | 290 | |
| 291 | |
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that passes each query string through a list of
    translator callables before executing it.  Useful for SQLite testing.
    """
    def __init__(self, translators):
        """
        @param translators: list of callables to apply to each query string
                (in order).  Each accepts a query string and returns a
                (possibly) modified query string.
        """
        super(TranslatingDatabase, self).__init__()
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Apply every translator to query, in order, then execute the result
        exactly as DatabaseConnection.execute() would.
        """
        translated_query = query
        for translate in self._translators:
            translated_query = translate(translated_query)
        parent = super(TranslatingDatabase, self)
        return parent.execute(translated_query, parameters=parameters,
                              try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        return lambda query: re.sub(search_re, replace_str, query)