Catch errors caused by MySQL losing its connection. If the connection
is lost, wait, reconnect, and retry the query.
From: Jeremy Orlow <jorlow@google.com>
Signed-off-by: Steve Howard <showard@google.com>
git-svn-id: http://test.kernel.org/svn/autotest/trunk@1280 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/scheduler/monitor_db b/scheduler/monitor_db
index 475f849..f89c846 100755
--- a/scheduler/monitor_db
+++ b/scheduler/monitor_db
@@ -22,8 +22,7 @@
if AUTOTEST_SERVER_DIR not in sys.path:
sys.path.insert(0, AUTOTEST_SERVER_DIR)
-_connection = None
-_cursor = None
+_db = None
_shutdown = False
_notify_email = None
_autoserv_path = 'autoserv'
@@ -73,7 +72,7 @@
except:
log_stacktrace("Uncaught exception; terminating monitor_db")
- disconnect()
+ _db.disconnect()
def handle_sigint(signum, frame):
@@ -89,7 +88,8 @@
print "My PID is %d" % os.getpid()
os.environ['PATH'] = AUTOTEST_SERVER_DIR + ':' + os.environ['PATH']
- connect()
+ global _db
+ _db = DatabaseConn()
print "Setting signal handler"
signal.signal(signal.SIGINT, handle_sigint)
@@ -112,7 +112,7 @@
def idle_hosts():
- _cursor.execute("""
+ rows = _db.execute("""
SELECT * FROM hosts h WHERE
id NOT IN (SELECT host_id FROM host_queue_entries WHERE active) AND (
(id IN (SELECT host_id FROM host_queue_entries WHERE not complete AND not active))
@@ -121,39 +121,68 @@
INNER JOIN hosts_labels hl ON hqe.meta_host=hl.label_id WHERE not hqe.complete AND not hqe.active))
)
AND locked=false AND (h.status IS null OR h.status='Ready') """)
- hosts = [Host(row=i) for i in _cursor.fetchall()]
+ hosts = [Host(row=i) for i in rows]
return hosts
-def connect():
- path = os.path.dirname(os.path.abspath(sys.argv[0]))
- # get global config and parse for info
- c = global_config.global_config
- dbase = "AUTOTEST_WEB"
- DB_HOST = c.get_config_value(dbase, "host", "localhost")
- DB_SCHEMA = c.get_config_value(dbase, "database", "autotest_web")
- if _testing_mode:
- DB_SCHEMA = 'stresstest_autotest_web'
+class DatabaseConn:
+ def __init__(self):
+ self.reconnect_wait = 20
+ self.conn = None
+ self.cur = None
- DB_USER = c.get_config_value(dbase, "user", "autotest")
- DB_PASS = c.get_config_value(dbase, "password")
-
- global _connection, _cursor
- _connection = MySQLdb.connect(
- host=DB_HOST,
- user=DB_USER,
- passwd=DB_PASS,
- db=DB_SCHEMA
- )
- _connection.autocommit(True)
- _cursor = _connection.cursor()
+ self.connect()
-def disconnect():
- global _connection, _cursor
- _connection.close()
- _connection = None
- _cursor = None
+ def connect(self):
+ self.disconnect()
+
+ # get global config and parse for info
+ c = global_config.global_config
+ dbase = "AUTOTEST_WEB"
+ DB_HOST = c.get_config_value(dbase, "host", "localhost")
+ DB_SCHEMA = c.get_config_value(dbase, "database",
+ "autotest_web")
+
+ global _testing_mode
+ if _testing_mode:
+ DB_SCHEMA = 'stresstest_autotest_web'
+
+ DB_USER = c.get_config_value(dbase, "user", "autotest")
+ DB_PASS = c.get_config_value(dbase, "password", "google")
+
+ while not self.conn:
+ try:
+ self.conn = MySQLdb.connect(host=DB_HOST,
+ user=DB_USER,
+ passwd=DB_PASS,
+ db=DB_SCHEMA)
+
+ self.conn.autocommit(True)
+ self.cur = self.conn.cursor()
+ except MySQLdb.OperationalError:
+ #traceback.print_exc()
+ print "Can't connect to MYSQL; reconnecting"
+ time.sleep(self.reconnect_wait)
+ self.disconnect()
+
+
+ def disconnect(self):
+ if self.conn:
+ self.conn.close()
+ self.conn = None
+ self.cur = None
+
+
+ def execute(self, *args, **dargs):
+ while (True):
+ try:
+ self.cur.execute(*args, **dargs)
+ return self.cur.fetchall()
+ except MySQLdb.OperationalError:
+ print "MYSQL connection died; reconnecting"
+ time.sleep(self.reconnect_wait)
+ self.connect()
def parse_results(results_dir, flags=""):
@@ -213,9 +242,9 @@
def _recover_lost(self):
- _cursor.execute("""SELECT * FROM host_queue_entries WHERE active AND NOT complete""")
- if _cursor.rowcount:
- queue_entries = [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+ rows = _db.execute("""SELECT * FROM host_queue_entries WHERE active AND NOT complete""")
+ if len(rows) > 0:
+ queue_entries = [HostQueueEntry(row=i) for i in rows]
for queue_entry in queue_entries:
job = queue_entry.job
if job.is_synchronous():
@@ -225,10 +254,10 @@
queue_entry.requeue()
queue_entry.clear_results_dir()
- _cursor.execute("""SELECT * FROM hosts
+ rows = _db.execute("""SELECT * FROM hosts
WHERE status != 'Ready' AND NOT locked""")
- if _cursor.rowcount:
- hosts = [Host(row=i) for i in _cursor.fetchall()]
+ if len(rows) > 0:
+ hosts = [Host(row=i) for i in rows]
for host in hosts:
verify_task = VerifyTask(host = host)
self.add_agent(Agent(tasks = [verify_task]))
@@ -627,11 +656,11 @@
if row is None:
sql = 'SELECT * FROM %s WHERE ID=%%s' % self.__table
- _cursor.execute(sql, (id,))
- if not _cursor.rowcount:
+ rows = _db.execute(sql, (id,))
+ if len(rows) == 0:
raise "row not found (table=%s, id=%s)" % \
(self.__table, id)
- row = _cursor.fetchone()
+ row = rows[0]
assert len(row)==len(fields), "table = %s, row = %s/%d, fields = %s/%d" % (table, row, len(row), fields, len(fields))
@@ -646,13 +675,14 @@
if not table:
table = self.__table
- _cursor.execute("""
+ rows = _db.execute("""
SELECT count(*) FROM %s
WHERE %s
""" % (table, where))
- count = _cursor.fetchall()
- return int(count[0][0])
+ assert len(rows) == 1
+
+ return int(rows[0][0])
def num_cols(self):
@@ -667,7 +697,7 @@
query = "UPDATE %s SET %s = %%s WHERE id = %%s" % \
(self.__table, field)
- _cursor.execute(query, (value, self.id))
+ _db.execute(query, (value, self.id))
self.__dict__[field] = value
@@ -680,13 +710,9 @@
values = ','.join(values)
query = """INSERT INTO %s (%s) VALUES (%s)""" % \
(self.__table, columns, values)
- _cursor.execute(query)
+ _db.execute(query)
- def delete(self):
- _cursor.execute("""DELETE FROM %s WHERE id = %%s""" % \
- self.__table, (self.id,))
-
class IneligibleHostQueue(DBObject):
def __init__(self, id=None, row=None, new_record=None):
@@ -703,15 +729,15 @@
def current_task(self):
- _cursor.execute("""
+ rows = _db.execute("""
SELECT * FROM host_queue_entries WHERE host_id=%s AND NOT complete AND active
""", (self.id,))
- if not _cursor.rowcount:
+ if len(rows) == 0:
return None
else:
- assert _cursor.rowcount == 1
- results = _cursor.fetchone();
+ assert len(rows) == 1
+ results = rows[0];
# print "current = %s" % results
return HostQueueEntry(row=results)
@@ -721,7 +747,7 @@
print "%s locked, not queuing" % self.hostname
return None
# print "%s/%s looking for work" % (self.hostname, self.platform_id)
- _cursor.execute("""
+ rows = _db.execute("""
SELECT * FROM host_queue_entries
WHERE ((host_id=%s) OR (meta_host IS NOT null AND
(meta_host IN (
@@ -737,10 +763,10 @@
LIMIT 1
""", (self.id,self.id, self.id))
- if not _cursor.rowcount:
+ if len(rows) == 0:
return None
else:
- return [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+ return [HostQueueEntry(row=i) for i in rows]
def yield_work(self):
print "%s yielding work" % self.hostname
@@ -873,11 +899,11 @@
def get_host_queue_entries(self):
- _cursor.execute("""
+ rows = _db.execute("""
SELECT * FROM host_queue_entries
WHERE job_id= %s
""", (self.id,))
- entries = [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+ entries = [HostQueueEntry(row=i) for i in rows]
assert len(entries)>0