Aviv Keshet | 4753990 | 2013-06-21 15:29:31 -0700 | [diff] [blame] | 1 | # pylint: disable-msg=C0111 |
| 2 | |
showard | 89f84db | 2009-03-12 20:39:13 +0000 | [diff] [blame] | 3 | import re, time, traceback |
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 4 | import common |
| 5 | from autotest_lib.client.common_lib import global_config |
| 6 | |
| 7 | RECONNECT_FOREVER = object() |
| 8 | |
| 9 | _DB_EXCEPTIONS = ('DatabaseError', 'OperationalError', 'ProgrammingError') |
| 10 | _GLOBAL_CONFIG_NAMES = { |
| 11 | 'username' : 'user', |
| 12 | 'db_name' : 'database', |
| 13 | } |
| 14 | |
def _copy_exceptions(source, destination):
    """Mirror the DB-API exception classes named in _DB_EXCEPTIONS.

    Every name is read from `source` and set on `destination`. If reading
    or setting it raises AttributeError, `source.DatabaseError` is installed
    under that name instead.
    """
    for name in _DB_EXCEPTIONS:
        try:
            exception_class = getattr(source, name)
            setattr(destination, name, exception_class)
        except AttributeError:
            # Under the django backend, Django 1.3 has no OperationalError or
            # ProgrammingError, so stand in the base DatabaseError for them.
            fallback_class = getattr(source, 'DatabaseError')
            setattr(destination, name, fallback_class)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 26 | |
| 27 | |
| 28 | class _GenericBackend(object): |
| 29 | def __init__(self, database_module): |
| 30 | self._database_module = database_module |
| 31 | self._connection = None |
| 32 | self._cursor = None |
| 33 | self.rowcount = None |
| 34 | _copy_exceptions(database_module, self) |
| 35 | |
| 36 | |
| 37 | def connect(self, host=None, username=None, password=None, db_name=None): |
| 38 | """ |
| 39 | This is assumed to enable autocommit. |
| 40 | """ |
| 41 | raise NotImplementedError |
| 42 | |
| 43 | |
| 44 | def disconnect(self): |
| 45 | if self._connection: |
| 46 | self._connection.close() |
| 47 | self._connection = None |
| 48 | self._cursor = None |
| 49 | |
| 50 | |
showard | b1e5187 | 2008-10-07 11:08:18 +0000 | [diff] [blame] | 51 | def execute(self, query, parameters=None): |
showard | 34ab099 | 2009-10-05 22:47:57 +0000 | [diff] [blame] | 52 | if parameters is None: |
| 53 | parameters = () |
showard | b1e5187 | 2008-10-07 11:08:18 +0000 | [diff] [blame] | 54 | self._cursor.execute(query, parameters) |
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 55 | self.rowcount = self._cursor.rowcount |
| 56 | return self._cursor.fetchall() |
| 57 | |
| 58 | |
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL through the MySQLdb module."""
    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open an autocommit MySQL connection and create its cursor.

        bool values are encoded with convert_boolean unless MySQLdb already
        provides a bool converter.
        """
        import MySQLdb.converters
        # Work on a copy: MySQLdb.converters.conversions is module-level state
        # shared by every MySQLdb connection in the process, so registering
        # our bool encoder on it directly would leak into those connections.
        convert_dict = dict(MySQLdb.converters.conversions)
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
                host=host, user=username, passwd=password, db=db_name,
                conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()
| 81 | |
| 82 | |
class _SqliteBackend(_GenericBackend):
    """SQLite backend (pysqlite2 if installed, else the stdlib sqlite3)."""
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            from sqlite3 import dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)
        # A word boundary rather than a mandatory leading whitespace char, so
        # the call is also found right after punctuation, e.g.
        # "SELECT(LAST_INSERT_ID())".
        self._last_insert_id_re = re.compile(r'\bLAST_INSERT_ID\(\)',
                                             re.IGNORECASE)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the SQLite database `db_name` in autocommit mode. host, username
        and password are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """Translate a MySQL-flavoured query for SQLite and execute it."""
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary (a '%s' inside a
        # quoted SQL string literal is rewritten too)
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        # sqlite3 doesn't support MySQL's LAST_INSERT_ID(). Instead it has
        # something similar called LAST_INSERT_ROWID() that will do enough of
        # what we want (for our non-concurrent unittest use case).
        query = self._last_insert_id_re.sub('LAST_INSERT_ROWID()', query)
        return super(_SqliteBackend, self).execute(query, parameters)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 113 | |
| 114 | |
class _DjangoBackend(_GenericBackend):
    """Backend that reuses Django's `django.db.connection`."""
    def __init__(self):
        # Only `connection` and `transaction` are used; the `backend` name
        # that used to be imported here was never referenced.
        from django.db import connection, transaction
        import django.db as django_db
        super(_DjangoBackend, self).__init__(django_db)
        self._django_connection = connection
        self._django_transaction = transaction


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Use Django's connection and open a cursor on it; the connection
        arguments are ignored.
        """
        self._connection = self._django_connection
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute `query` and return its rows. commit_unless_managed() is
        called in `finally`, i.e. also when the query raises.
        """
        try:
            return super(_DjangoBackend, self).execute(query,
                                                       parameters=parameters)
        finally:
            self._django_transaction.commit_unless_managed()
| 135 | |
| 136 | |
# Maps each db_type string accepted by DatabaseConnection.connect() to the
# backend class that _get_backend() instantiates for it.
_BACKEND_MAP = dict([
    ('mysql', _MySqlBackend),
    ('sqlite', _SqliteBackend),
    ('django', _DjangoBackend),
])
| 142 | |
| 143 | |
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports mysql, sqlite and
    django backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      before giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None, debug=False):
        """
        @param global_config_section: global config section consulted by
                _get_option(); may be None.
        @param debug: if True, execute() prints each query before running it.
        """
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = debug

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value, use_afe_setting=False):
        """Get value of given option from global config.

        @param name: Name of the config.
        @param provided_value: Value being provided to override the one from
                               global config.
        @param use_afe_setting: Force to use the settings in AFE, default is
                                False.
        @return: provided_value if not None; otherwise the value from the
                 AUTOTEST_WEB section (use_afe_setting) or from
                 self.global_config_section; otherwise the current attribute
                 value (None if unset).
        """
        # TODO(dshi): This function returns the option value depends on multiple
        # conditions. The value of `provided_value` has highest priority, then
        # the code checks if use_afe_setting is True, if that's the case, force
        # to use settings in AUTOTEST_WEB. At last the value is retrieved from
        # specified global config section.
        # The logic is too complicated for a generic function named like
        # _get_option. Ideally we want to make it clear from caller that it
        # wants to get database credential from one of the 3 ways:
        # 1. Use the credential from given config section
        # 2. Use the credential from AUTOTEST_WEB section
        # 3. Use the credential provided by caller.
        if provided_value is not None:
            return provided_value
        section = ('AUTOTEST_WEB' if use_afe_setting else
                   self.global_config_section)
        if section:
            # Some options are stored under a different key in global config.
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                    section, global_config_name)

        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Read database information from global config.

        Unless any parameter is specified a value, the connection will use
        database name from given configure section (self.global_config_section),
        and database credential from AFE database settings (AUTOTEST_WEB).

        @param db_type: database type, default to None.
        @param host: database hostname, default to None.
        @param username: user name for database connection, default to None.
        @param password: database password, default to None.
        @param db_name: database name, default to None.
        """
        self.db_name = self._get_option('db_name', db_name)
        use_afe_setting = not bool(db_type or host or username or password)

        # Database credential can be provided by the caller, as passed in from
        # function connect.
        self.db_type = self._get_option('db_type', db_type, use_afe_setting)
        self.host = self._get_option('host', host, use_afe_setting)
        self.username = self._get_option('username', username, use_afe_setting)
        self.password = self._get_option('password', password, use_afe_setting)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for `db_type`.

        @raises ValueError: if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        True when num_attempts exceeds max_reconnect_attempts; never True
        when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend, retrying on OperationalError every
        reconnect_delay_sec seconds while reconnecting is enabled and the
        attempt limit has not been exceeded; otherwise re-raise.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print("Can't connect to database; reconnecting in %s sec" %
                      self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # Let callers catch e.g. db.OperationalError without knowing the
        # backend module.
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.
        """
        if self.debug:
            # Call form prints the same text under Python 2 and also parses
            # under Python 3 (the old print statement did not).
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the _DATABASE_ATTRIBUTES names to their values."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:', **constructor_kwargs):
        """
        Factory method returning a DatabaseConnection for a temporary in-memory
        database.

        @param file_path: SQLite database path; ':memory:' by default.
        @param constructor_kwargs: forwarded to the class constructor.
        """
        database = cls(**constructor_kwargs)
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database
showard | 34ab099 | 2009-10-05 22:47:57 +0000 | [diff] [blame] | 333 | |
| 334 | |
class TranslatingDatabase(DatabaseConnection):
    """
    Database wrapper that applies arbitrary substitution regexps to each query
    string. Useful for SQLite testing.
    """
    def __init__(self, translators, **kwargs):
        """
        @param translators: list of callables to apply to each query
                string (in order). Each accepts a query string and returns a
                (possibly) modified query string.
        @param kwargs: forwarded to DatabaseConnection.__init__ (e.g.
                global_config_section, debug).
        """
        super(TranslatingDatabase, self).__init__(**kwargs)
        self._translators = translators


    def execute(self, query, parameters=None, try_reconnecting=None):
        """Run every translator over `query`, then execute the result."""
        for translator in self._translators:
            query = translator(query)
        return super(TranslatingDatabase, self).execute(
                query, parameters=parameters, try_reconnecting=try_reconnecting)


    @classmethod
    def make_regexp_translator(cls, search_re, replace_str):
        """
        Returns a translator that calls re.sub() on the query with the given
        search and replace arguments.

        The pattern is compiled once here, so an invalid pattern is reported
        when the translator is created rather than on the first query.
        """
        compiled_re = re.compile(search_re)
        def translator(query):
            return compiled_re.sub(replace_str, query)
        return translator