showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 1 | import traceback, time |
| 2 | import common |
| 3 | from autotest_lib.client.common_lib import global_config |
| 4 | |
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
# It is compared by identity, so any unique object works.
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions() mirrors from a
# database module onto backends and connections.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# DatabaseConnection option name -> global config key, for the options whose
# key differs from the option name; any other option uses its own name.
_GLOBAL_CONFIG_NAMES = dict(username='user', db_name='database')
| 12 | |
def _copy_exceptions(source, destination):
    """
    Copy every exception class named in _DB_EXCEPTIONS from source onto
    destination as an attribute of the same name.
    """
    for attr_name in _DB_EXCEPTIONS:
        exc_class = getattr(source, attr_name)
        setattr(destination, attr_name, exc_class)
| 16 | |
| 17 | |
class _GenericBackend(object):
    """
    Base class for database backends wrapping a DB-API module.

    Subclasses implement connect(), which is expected to set self._connection
    and self._cursor. The module's DatabaseError/OperationalError/
    ProgrammingError classes are copied onto the instance, so callers can
    write e.g. `except backend.OperationalError`.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # holds cursor.rowcount after each call to execute()
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection and cursor. This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the current connection, if any, and drop connection/cursor."""
        # Clear the references before closing so that a close() that raises
        # does not leave a dead connection stored on the backend.
        connection, self._connection = self._connection, None
        self._cursor = None
        if connection is not None:
            connection.close()


    def execute(self, query, parameters=None):
        """
        Execute query on the current cursor and return cursor.fetchall().

        parameters is handed to the DB-API cursor unchanged (including None).
        The cursor's rowcount is stored in self.rowcount.
        """
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()


    @staticmethod
    def get_exception_details(exception):
        """
        Classify a database exception.

        Declared static because it takes no self. Previously, calling it on an
        instance raised TypeError.
        """
        # NOTE(review): ExceptionDetails is not defined in this file; calling
        # this raises NameError until that name is provided.
        return ExceptionDetails.UNKNOWN
| 49 | |
| 50 | |
class _MySqlBackend(_GenericBackend):
    """MySQL backend built on the MySQLdb module."""

    def __init__(self):
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        'Convert booleans to integer strings'
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to MySQL with autocommit enabled and booleans sent as '0'/'1'.
        """
        import MySQLdb.converters
        # Work on a copy: calling setdefault() on MySQLdb's module-level table
        # would silently change the defaults for every other MySQLdb
        # connection in the process.
        convert_dict = MySQLdb.converters.conversions.copy()
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()


    @staticmethod
    def get_exception_details(exception):
        """
        Placeholder; returns None. Declared static because it takes no self
        (previously, calling it on an instance raised TypeError).
        """
        return None
| 77 | |
| 78 | |
class _SqliteBackend(_GenericBackend):
    """
    SQLite backend. Uses pysqlite2 when it is installed and otherwise falls
    back to the standard library's sqlite3 module.
    """

    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # The stdlib sqlite3 module exposes the same DB-API 2.0 interface
            # and also uses the qmark paramstyle relied on by execute().
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the database at db_name (a file path, or ':memory:').
        host, username and password are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Translate '%s' placeholders to '?' and execute via the base class.
        """
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary (this also rewrites
        # '%s' occurring inside SQL string literals)
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        return super(_SqliteBackend, self).execute(query, parameters)
showard | 0e73c85 | 2008-10-03 10:15:50 +0000 | [diff] [blame] | 100 | |
| 101 | |
# db_type string (as passed to DatabaseConnection.connect) -> backend class.
_BACKEND_MAP = dict(mysql=_MySqlBackend, sqlite=_SqliteBackend)
| 106 | |
| 107 | |
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports both mysql and sqlite
    backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting
      after a failed connection attempt before giving up. Setting to
      RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. this
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect() (values passed to an
      earlier connect() call are reused).
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = False

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve one connection option.

        Precedence: an explicitly provided (non-None) value; otherwise the
        global config section, if one was given; otherwise the value this
        object already holds for the attribute (None if unset).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve every connection option and store it as an attribute."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class for db_type.

        Raises ValueError for an unknown db_type.
        """
        if db_type not in _BACKEND_MAP:
            # sorted() so the error message does not depend on dict order
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(sorted(_BACKEND_MAP))))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """True once num_attempts exceeds the limit (never for FOREVER)."""
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """Return supplied_param if given, else self.reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect the current backend, retrying on OperationalError.

        The error is re-raised immediately when reconnecting is disabled, or
        after max_reconnect_attempts + 1 failed attempts (the initial attempt
        plus max_reconnect_attempts retries). Between attempts, the traceback
        is printed and the method sleeps reconnect_delay_sec seconds.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # expose the backend's DB-API exception classes on this object too
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        If an OperationalError occurs and reconnecting is enabled, the
        connection is re-established via _connect_backend() and the query is
        retried once. An error on that retry propagates to the caller.
        """
        if self.debug:
            # single-argument print() works the same on Python 2 and 3
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the current connection options."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:'):
        """
        Factory method returning a DatabaseConnection to a sqlite database at
        file_path, with automatic reconnection disabled. The default,
        ':memory:', gives a temporary in-memory database.
        """
        database = cls()
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database