showard | 89f84db | 2009-03-12 20:39:13 +0000 | [diff] [blame] | 1 | import re, time, traceback |
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 2 | import common |
| 3 | from autotest_lib.client.common_lib import global_config |
| 4 | |
# Sentinel value for DatabaseConnection.max_reconnect_attempts that removes the
# limit on reconnection attempts (compared by identity).
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions() mirrors from a
# database module onto backend and connection objects.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# DatabaseConnection option name -> global config key, for the options whose
# names differ; any option not listed here uses its own name as the key.
_GLOBAL_CONFIG_NAMES = dict(
    username='user',
    db_name='database',
)
| 12 | |
def _copy_exceptions(source, destination):
    """
    Set each exception class named in _DB_EXCEPTIONS on destination, taking
    it from the attribute of the same name on source.
    """
    for attr_name in _DB_EXCEPTIONS:
        exception_class = getattr(source, attr_name)
        setattr(destination, attr_name, exception_class)
| 16 | |
| 17 | |
class _GenericBackend(object):
    """
    Base class for database backends built on a DB-API module.

    Subclasses implement connect(), which is expected to populate
    self._connection and self._cursor and to enable autocommit.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        self.rowcount = None
        # Mirror the module's exception classes onto this object, so callers
        # can refer to them through the backend (DatabaseConnection catches
        # self._backend.OperationalError).
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the database connection. This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the open connection, if any, and forget connection and cursor."""
        connection = self._connection
        if connection:
            connection.close()
        self._connection = self._cursor = None


    def execute(self, query, parameters=None):
        """
        Execute query on the cursor, store the cursor's rowcount in
        self.rowcount and return cursor.fetchall(). A parameters value of
        None is passed to the cursor as an empty tuple.
        """
        cursor = self._cursor
        cursor.execute(query, () if parameters is None else parameters)
        self.rowcount = cursor.rowcount
        return cursor.fetchall()
| 47 | |
| 48 | |
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL, using the MySQLdb module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to MySQL with autocommit enabled, using a conversion table
        that maps bool through convert_boolean() unless MySQLdb already
        provides a bool converter.
        """
        import MySQLdb.converters
        # Work on a copy: calling setdefault() on
        # MySQLdb.converters.conversions itself would alter the module-global
        # conversion table for every MySQLdb user in the process.
        convert_dict = MySQLdb.converters.conversions.copy()
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        # _GenericBackend.connect() is documented as enabling autocommit.
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
| 71 | |
| 72 | |
class _SqliteBackend(_GenericBackend):
    """
    Backend for SQLite, used for the (non-concurrent) unittest databases.
    Translates MySQL-flavoured queries into SQLite's dialect before running
    them.
    """
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # pysqlite2 is the standalone predecessor of the standard
            # library's sqlite3 package, which exposes the same dbapi2 module.
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # Match LAST_INSERT_ID() as a whole word wherever it occurs, e.g.
        # after whitespace, '(', '=' or ',', or at the start of the query.
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """Open the SQLite database file db_name; other arguments are unused."""
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary (this naive
        # replacement also rewrites '%s' inside string literals)
        query = query.replace('%s', '?')
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        # _GenericBackend.execute() turns parameters=None into (), which
        # pysqlite2 needs (it throws a nonsense exception for None).
        return super(_SqliteBackend, self).execute(query, parameters)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 100 | |
| 101 | |
class _DjangoBackend(_GenericBackend):
    """
    Backend that runs queries on Django's database connection
    (django.db.connection) instead of opening its own.
    """
    def __init__(self):
        from django.db import backend, connection, transaction
        super(_DjangoBackend, self).__init__(backend.Database)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        # The connection arguments are ignored; Django's connection object is
        # reused as-is.
        django_connection = self._django_connection
        self._connection = django_connection
        self._cursor = django_connection.cursor()


    def execute(self, query, parameters=None):
        # transaction.commit_unless_managed() is called after every query,
        # whether the query succeeded or raised.
        try:
            results = super(_DjangoBackend, self).execute(
                    query, parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
        return results
| 121 | |
| 122 | |
# db_type option value -> backend class that DatabaseConnection instantiates.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
    django=_DjangoBackend,
)
| 128 | |
| 129 | |
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports mysql, sqlite and
    django backends (the keys of _BACKEND_MAP).

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve connection option `name`. Precedence: provided_value if not
        None; else the global config value (key mapped through
        _GLOBAL_CONFIG_NAMES) when global_config_section is set; else the
        value already stored on this object, or None if there is none.
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve every connection option via _get_option() and store it."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            # sorted() keeps the list of valid types stable across runs.
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        Return True when num_attempts exceeds max_reconnect_attempts; never
        True when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """A non-None supplied_param overrides self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend with the stored options. On
        OperationalError, print the traceback, wait reconnect_delay_sec,
        disconnect and retry; re-raise once reconnecting is disabled or the
        attempt limit is exceeded.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Expose the backend's DB-API exception classes on this object too.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            # Parenthesized so this line is valid on Python 2 and Python 3.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the connection options in _DATABASE_ATTRIBUTES."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a connected instance on a SQLite database
        at file_path (an in-memory database by default) with automatic
        reconnecting disabled. constructor_kwargs are passed to cls().
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard | 34ab099 | 2009-10-05 22:47:57 +0000 | [diff] [blame] | 281 | |
| 282 | |
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary translators (for example regexp
    substitutions) to each query string before executing it. Useful for
    SQLite testing.
    """
    def __init__(self, translators, global_config_section=None, debug=False):
        """
        @param translators: iterable of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param global_config_section: passed through to DatabaseConnection.
        @param debug: passed through to DatabaseConnection.
        """
        super(TranslatingDatabase, self).__init__(
                global_config_section=global_config_section, debug=debug)
        # Materialize once so a one-shot iterable (e.g. a generator) is not
        # exhausted by the first query, leaving later queries untranslated.
        self._translators = list(translators)


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Apply every translator to query, then execute it as usual."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.
        """
        def translator(query):
            return re.sub(search_re, replace_str, query)
        return translator