blob: fa870c96a0206023e0ef3615cfb14b907b16f54e [file] [log] [blame]
showard0e73c852008-10-03 10:15:50 +00001import traceback, time
2import common
3from autotest_lib.client.common_lib import global_config
4
# Sentinel for DatabaseConnection.max_reconnect_attempts meaning "no limit".
RECONNECT_FOREVER = object()

# Names of the DB-API exception classes that _copy_exceptions mirrors from a
# driver module / backend onto wrapper objects, so callers can catch them
# without importing the driver themselves.
_DB_EXCEPTIONS = (
    'DatabaseError',
    'OperationalError',
    'ProgrammingError',
)

# DatabaseConnection option name -> key used in the global config, for the
# options whose names differ; unlisted options use the same name in both.
_GLOBAL_CONFIG_NAMES = dict(
    username='user',
    db_name='database',
)
12
def _copy_exceptions(source, destination):
    """
    Copy each exception class named in _DB_EXCEPTIONS from source onto
    destination as an attribute of the same name.
    """
    for name in _DB_EXCEPTIONS:
        exception_class = getattr(source, name)
        setattr(destination, name, exception_class)
16
17
class _GenericBackend(object):
    """
    Base class for database backends wrapping a DB-API style module.

    Subclasses implement connect(), which must set self._connection and
    self._cursor. The exception classes listed in _DB_EXCEPTIONS are copied
    from the module onto the instance so callers can write
    `except backend.OperationalError`.
    """
    def __init__(self, database_module):
        self._database_module = database_module
        self._connection = None
        self._cursor = None
        # cursor.rowcount from the most recent execute()
        self.rowcount = None
        _copy_exceptions(database_module, self)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open a connection and cursor. This is assumed to enable autocommit.
        """
        raise NotImplementedError


    def disconnect(self):
        """Close the connection, if any, and forget the connection/cursor."""
        if self._connection:
            self._connection.close()
        self._connection = None
        self._cursor = None


    def execute(self, query, parameters=None):
        """
        Execute query with parameters on the current cursor.

        Records cursor.rowcount in self.rowcount and returns
        cursor.fetchall().
        """
        self._cursor.execute(query, parameters)
        self.rowcount = self._cursor.rowcount
        return self._cursor.fetchall()


    def get_exception_details(self, exception):
        """
        Return backend-specific details about exception.

        The generic backend cannot interpret exceptions, so this returns
        None, meaning "unknown". (The original lacked `self` and referenced
        an undefined ExceptionDetails name, so it could never be called.)
        """
        return None
49
50
class _MySqlBackend(_GenericBackend):
    """Backend for MySQL databases, built on the MySQLdb module."""
    def __init__(self):
        # imported here so MySQLdb is only required when this backend is used
        import MySQLdb
        super(_MySqlBackend, self).__init__(MySQLdb)


    @staticmethod
    def convert_boolean(boolean, conversion_dict):
        """Convert booleans to integer strings ('0' / '1')."""
        return str(int(boolean))


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Connect to MySQL with autocommit enabled and open a cursor.

        A bool converter (convert_boolean) is registered unless the
        conversion dict already has one.
        """
        import MySQLdb.converters
        # Work on a copy: calling setdefault() on the module-level
        # conversions dict would mutate MySQLdb's process-wide defaults,
        # which every other MySQLdb connection also uses.
        convert_dict = MySQLdb.converters.conversions.copy()
        convert_dict.setdefault(bool, self.convert_boolean)

        self._connection = self._database_module.connect(
            host=host, user=username, passwd=password, db=db_name,
            conv=convert_dict)
        self._connection.autocommit(True)
        self._cursor = self._connection.cursor()


    def get_exception_details(self, exception):
        """
        Return MySQL-specific details about exception.

        Not implemented yet: always returns None ("unknown"). The original
        lacked `self`, so calling it on an instance raised TypeError.
        """
        return None
77
78
class _SqliteBackend(_GenericBackend):
    """Backend for sqlite databases."""
    def __init__(self):
        try:
            from pysqlite2 import dbapi2
        except ImportError:
            # The standard library's sqlite3 module exposes the same DB-API
            # interface (including DatabaseError, OperationalError and
            # ProgrammingError), so fall back to it when the external
            # pysqlite2 package is not installed.
            import sqlite3 as dbapi2
        super(_SqliteBackend, self).__init__(dbapi2)


    def connect(self, host=None, username=None, password=None, db_name=None):
        """
        Open the sqlite database db_name with autocommit enabled.

        host, username and password are accepted for interface compatibility
        with other backends and are ignored.
        """
        self._connection = self._database_module.connect(db_name)
        self._connection.isolation_level = None # enable autocommit
        self._cursor = self._connection.cursor()


    def execute(self, query, parameters=None):
        """
        Execute a query written with '%s' placeholders against sqlite.
        """
        # pysqlite2 uses paramstyle=qmark
        # TODO: make this more sophisticated if necessary
        # NOTE: this rewrites every '%s' in the query, including any that
        # appear inside SQL string literals.
        query = query.replace('%s', '?')
        # pysqlite2 can't handle parameters=None (it throws a nonsense
        # exception)
        if parameters is None:
            parameters = ()
        return super(_SqliteBackend, self).execute(query, parameters)
showard0e73c852008-10-03 10:15:50 +0000100
101
# db_type string accepted by DatabaseConnection -> backend class to
# instantiate for it.
_BACKEND_MAP = dict(
    mysql=_MySqlBackend,
    sqlite=_SqliteBackend,
)
106
107
class DatabaseConnection(object):
    """
    Generic wrapper for a database connection. Supports both mysql and sqlite
    backends.

    Public attributes:
    * reconnect_enabled: if True, when an OperationalError occurs the class will
      try to reconnect to the database automatically.
    * reconnect_delay_sec: seconds to wait before reconnecting
    * max_reconnect_attempts: maximum number of times to try reconnecting before
      giving up. Setting to RECONNECT_FOREVER removes the limit.
    * rowcount - will hold cursor.rowcount after each call to execute().
    * global_config_section - the section in which to find DB information. This
      should be passed to the constructor, not set later, and may be None, in
      which case information must be passed to connect().
    * debug - if set True, all queries will be printed before being executed
    """
    _DATABASE_ATTRIBUTES = ('db_type', 'host', 'username', 'password',
                            'db_name')

    def __init__(self, global_config_section=None):
        self.global_config_section = global_config_section
        self._backend = None
        self.rowcount = None
        self.debug = False

        # reconnect defaults
        self.reconnect_enabled = True
        self.reconnect_delay_sec = 20
        self.max_reconnect_attempts = 10

        self._read_options()


    def _get_option(self, name, provided_value):
        """
        Resolve a single connection option.

        Precedence: provided_value if not None; otherwise the value from the
        global config section, if one was given; otherwise the value
        currently stored on self (None if never set).
        """
        if provided_value is not None:
            return provided_value
        if self.global_config_section:
            # some options are stored under a different key in global config
            global_config_name = _GLOBAL_CONFIG_NAMES.get(name, name)
            return global_config.global_config.get_config_value(
                self.global_config_section, global_config_name)
        return getattr(self, name, None)


    def _read_options(self, db_type=None, host=None, username=None,
                      password=None, db_name=None):
        """Resolve every connection option via _get_option() and store it."""
        self.db_type = self._get_option('db_type', db_type)
        self.host = self._get_option('host', host)
        self.username = self._get_option('username', username)
        self.password = self._get_option('password', password)
        self.db_name = self._get_option('db_name', db_name)


    def _get_backend(self, db_type):
        """
        Instantiate the backend class registered for db_type.

        Raises ValueError if db_type is not a key of _BACKEND_MAP.
        """
        if db_type not in _BACKEND_MAP:
            raise ValueError('Invalid database type: %s, should be one of %s' %
                             (db_type, ', '.join(_BACKEND_MAP.keys())))
        backend_class = _BACKEND_MAP[db_type]
        return backend_class()


    def _reached_max_attempts(self, num_attempts):
        """
        True when num_attempts exceeds max_reconnect_attempts; never true
        when max_reconnect_attempts is RECONNECT_FOREVER.
        """
        return (self.max_reconnect_attempts is not RECONNECT_FOREVER and
                num_attempts > self.max_reconnect_attempts)


    def _is_reconnect_enabled(self, supplied_param):
        """A non-None per-call supplied_param overrides reconnect_enabled."""
        if supplied_param is not None:
            return supplied_param
        return self.reconnect_enabled


    def _connect_backend(self, try_reconnecting=None):
        """
        Connect self._backend using the stored options.

        On OperationalError, retries after sleeping reconnect_delay_sec if
        reconnecting is enabled. The error is re-raised when reconnecting is
        disabled or when max_reconnect_attempts has been exceeded.
        """
        num_attempts = 0
        while True:
            try:
                self._backend.connect(host=self.host, username=self.username,
                                      password=self.password,
                                      db_name=self.db_name)
                return
            except self._backend.OperationalError:
                num_attempts += 1
                if not self._is_reconnect_enabled(try_reconnecting):
                    raise
                if self._reached_max_attempts(num_attempts):
                    raise
                traceback.print_exc()
                print ("Can't connect to database; reconnecting in %s sec" %
                       self.reconnect_delay_sec)
                time.sleep(self.reconnect_delay_sec)
                # reset the backend's connection/cursor before the next try
                self.disconnect()


    def connect(self, db_type=None, host=None, username=None, password=None,
                db_name=None, try_reconnecting=None):
        """
        Parameters passed to this function will override defaults from global
        config. try_reconnecting, if passed, will override
        self.reconnect_enabled.
        """
        self.disconnect()
        self._read_options(db_type, host, username, password, db_name)

        self._backend = self._get_backend(self.db_type)
        # expose the backend's exception classes on this object
        _copy_exceptions(self._backend, self)
        self._connect_backend(try_reconnecting)


    def disconnect(self):
        """Disconnect the current backend, if one has been created."""
        if self._backend:
            self._backend.disconnect()


    def execute(self, query, parameters=None, try_reconnecting=None):
        """
        Execute a query and return cursor.fetchall(). try_reconnecting, if
        passed, will override self.reconnect_enabled.

        If the query fails with OperationalError and reconnecting is enabled,
        the connection is re-established via _connect_backend() and the query
        is retried once. self.rowcount is updated afterwards.
        """
        if self.debug:
            # Single parenthesized argument: prints the same under Python 2's
            # print statement and Python 3's print() function.
            print('Executing %s, %s' % (query, parameters))
        # _connect_backend() contains a retry loop, so don't loop here
        try:
            results = self._backend.execute(query, parameters)
        except self._backend.OperationalError:
            if not self._is_reconnect_enabled(try_reconnecting):
                raise
            traceback.print_exc()
            print ("MYSQL connection died; reconnecting")
            self.disconnect()
            self._connect_backend(try_reconnecting)
            results = self._backend.execute(query, parameters)

        self.rowcount = self._backend.rowcount
        return results


    def get_database_info(self):
        """Return a dict of the current connection options (_DATABASE_ATTRIBUTES)."""
        return dict((attribute, getattr(self, attribute))
                    for attribute in self._DATABASE_ATTRIBUTES)


    @classmethod
    def get_test_database(cls, file_path=':memory:'):
        """
        Factory method returning a connected sqlite DatabaseConnection with
        automatic reconnection disabled.

        file_path is the sqlite database to open; the default ':memory:'
        gives a temporary in-memory database.
        """
        database = cls()
        database.reconnect_enabled = False
        database.connect(db_type='sqlite', db_name=file_path)
        return database