Catch errors caused by MySQL losing its connection.  If the connection is
lost, wait, reconnect, and retry the query.

From: Jeremy Orlow <jorlow@google.com>
Signed-off-by: Steve Howard <showard@google.com>



git-svn-id: http://test.kernel.org/svn/autotest/trunk@1280 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/scheduler/monitor_db b/scheduler/monitor_db
index 475f849..f89c846 100755
--- a/scheduler/monitor_db
+++ b/scheduler/monitor_db
@@ -22,8 +22,7 @@
 if AUTOTEST_SERVER_DIR not in sys.path:
         sys.path.insert(0, AUTOTEST_SERVER_DIR)
 
-_connection = None
-_cursor = None
+_db = None
 _shutdown = False
 _notify_email = None
 _autoserv_path = 'autoserv'
@@ -73,7 +72,7 @@
 	except:
 		log_stacktrace("Uncaught exception; terminating monitor_db")
 
-	disconnect()
+	_db.disconnect()
 
 
 def handle_sigint(signum, frame):
@@ -89,7 +88,8 @@
 	print "My PID is %d" % os.getpid()
 
 	os.environ['PATH'] = AUTOTEST_SERVER_DIR + ':' + os.environ['PATH']
-	connect()
+	global _db
+	_db = DatabaseConn()
 
 	print "Setting signal handler"
 	signal.signal(signal.SIGINT, handle_sigint)
@@ -112,7 +112,7 @@
 
 
 def idle_hosts():
-	_cursor.execute("""
+	rows = _db.execute("""
 		SELECT * FROM hosts h WHERE 
 		id NOT IN (SELECT host_id FROM host_queue_entries WHERE active) AND (
 				(id IN (SELECT host_id FROM host_queue_entries WHERE not complete AND not active)) 
@@ -121,39 +121,68 @@
 				INNER JOIN hosts_labels hl ON hqe.meta_host=hl.label_id WHERE not hqe.complete AND not hqe.active))
 		)
 		AND locked=false AND (h.status IS null OR h.status='Ready')	""")
-	hosts = [Host(row=i) for i in _cursor.fetchall()]
+	hosts = [Host(row=i) for i in rows]
 	return hosts
 
 
-def connect():
-	path = os.path.dirname(os.path.abspath(sys.argv[0]))
-	# get global config and parse for info
-	c = global_config.global_config
-	dbase = "AUTOTEST_WEB"
-	DB_HOST = c.get_config_value(dbase, "host", "localhost")
-	DB_SCHEMA = c.get_config_value(dbase, "database", "autotest_web")
-	if _testing_mode:
-		DB_SCHEMA = 'stresstest_autotest_web'
+class DatabaseConn:
+	def __init__(self):
+		self.reconnect_wait = 20
+		self.conn = None
+		self.cur = None
 
-	DB_USER = c.get_config_value(dbase, "user", "autotest")
-	DB_PASS = c.get_config_value(dbase, "password")
-
-	global _connection, _cursor
-	_connection = MySQLdb.connect(
-		host=DB_HOST,
-		user=DB_USER,
-		passwd=DB_PASS,
-		db=DB_SCHEMA
-	)
-	_connection.autocommit(True)
-	_cursor = _connection.cursor()
+		self.connect()
 
 
-def disconnect():
-	global _connection, _cursor
-	_connection.close()
-	_connection = None
-	_cursor = None
+	def connect(self):
+		self.disconnect()
+
+		# get global config and parse for info
+		c = global_config.global_config
+		dbase = "AUTOTEST_WEB"
+		DB_HOST = c.get_config_value(dbase, "host", "localhost")
+		DB_SCHEMA = c.get_config_value(dbase, "database",
+					       "autotest_web")
+		
+		global _testing_mode
+		if _testing_mode:
+			DB_SCHEMA = 'stresstest_autotest_web'
+
+		DB_USER = c.get_config_value(dbase, "user", "autotest")
+		DB_PASS = c.get_config_value(dbase, "password", "google")
+
+		while not self.conn:
+			try:
+				self.conn = MySQLdb.connect(host=DB_HOST,
+				                            user=DB_USER,
+				                            passwd=DB_PASS,
+				                            db=DB_SCHEMA)
+
+				self.conn.autocommit(True)
+				self.cur = self.conn.cursor()
+			except MySQLdb.OperationalError:
+				#traceback.print_exc()
+				print "Can't connect to MYSQL; reconnecting"
+				time.sleep(self.reconnect_wait)
+				self.disconnect()
+
+
+	def disconnect(self):
+		if self.conn:
+			self.conn.close()
+		self.conn = None
+		self.cur = None
+
+
+	def execute(self, *args, **dargs):
+		while (True):
+			try:
+				self.cur.execute(*args, **dargs)
+				return self.cur.fetchall()
+			except MySQLdb.OperationalError:
+				print "MYSQL connection died; reconnecting"
+				time.sleep(self.reconnect_wait)
+				self.connect()
 
 
 def parse_results(results_dir, flags=""):
@@ -213,9 +242,9 @@
 
 
 	def _recover_lost(self):
-		_cursor.execute("""SELECT * FROM host_queue_entries WHERE active AND NOT complete""")
-		if _cursor.rowcount:
-			queue_entries = [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+		rows = _db.execute("""SELECT * FROM host_queue_entries WHERE active AND NOT complete""")
+		if len(rows) > 0:
+			queue_entries = [HostQueueEntry(row=i) for i in rows]
 			for queue_entry in queue_entries:
 				job = queue_entry.job
 				if job.is_synchronous():
@@ -225,10 +254,10 @@
 					queue_entry.requeue()
 				queue_entry.clear_results_dir()
 			
-		_cursor.execute("""SELECT * FROM hosts
+		rows = _db.execute("""SELECT * FROM hosts
 		                   WHERE status != 'Ready' AND NOT locked""")
-		if _cursor.rowcount:
-			hosts = [Host(row=i) for i in _cursor.fetchall()]
+		if len(rows) > 0:
+			hosts = [Host(row=i) for i in rows]
 			for host in hosts:
 				verify_task = VerifyTask(host = host)
 				self.add_agent(Agent(tasks = [verify_task]))
@@ -627,11 +656,11 @@
 
 		if row is None:
 			sql = 'SELECT * FROM %s WHERE ID=%%s' % self.__table
-			_cursor.execute(sql, (id,))
-			if not _cursor.rowcount:
+			rows = _db.execute(sql, (id,))
+			if len(rows) == 0:
 				raise "row not found (table=%s, id=%s)" % \
 							(self.__table, id)
-			row = _cursor.fetchone()
+			row = rows[0]
 
 		assert len(row)==len(fields), "table = %s, row = %s/%d, fields = %s/%d" % (table, row, len(row), fields, len(fields))
 
@@ -646,13 +675,14 @@
 		if not table:
 			table = self.__table
 	
-		_cursor.execute("""
+		rows = _db.execute("""
 			SELECT count(*) FROM %s
 			WHERE %s
 		""" % (table, where))
-		count = _cursor.fetchall()
 
-		return int(count[0][0])
+		assert len(rows) == 1
+
+		return int(rows[0][0])
 
 
 	def num_cols(self):
@@ -667,7 +697,7 @@
 
 		query = "UPDATE %s SET %s = %%s WHERE id = %%s" % \
 							(self.__table, field)
-		_cursor.execute(query, (value, self.id))
+		_db.execute(query, (value, self.id))
 
 		self.__dict__[field] = value
 
@@ -680,13 +710,9 @@
 			values = ','.join(values)
 			query = """INSERT INTO %s (%s) VALUES (%s)""" % \
 						(self.__table, columns, values)
-			_cursor.execute(query)
+			_db.execute(query)
 
 
-	def delete(self):
-		_cursor.execute("""DELETE FROM %s WHERE id = %%s""" % \
-						self.__table, (self.id,))
-
 
 class IneligibleHostQueue(DBObject):
 	def __init__(self, id=None, row=None, new_record=None):
@@ -703,15 +729,15 @@
 
 
 	def current_task(self):
-		_cursor.execute("""
+		rows = _db.execute("""
 			SELECT * FROM host_queue_entries WHERE host_id=%s AND NOT complete AND active
 			""", (self.id,))
 		
-		if not _cursor.rowcount:
+		if len(rows) == 0:
 			return None
 		else:
-			assert _cursor.rowcount == 1
-			results = _cursor.fetchone();
+			assert len(rows) == 1
+			results = rows[0];
 #			print "current = %s" % results
 			return HostQueueEntry(row=results)
 
@@ -721,7 +747,7 @@
 			print "%s locked, not queuing" % self.hostname
 			return None
 #		print "%s/%s looking for work" % (self.hostname, self.platform_id)
-		_cursor.execute("""
+		rows = _db.execute("""
 			SELECT * FROM host_queue_entries
 			WHERE ((host_id=%s) OR (meta_host IS NOT null AND
 			(meta_host IN (
@@ -737,10 +763,10 @@
 			LIMIT 1
 		""", (self.id,self.id, self.id))
 
-		if not _cursor.rowcount:
+		if len(rows) == 0:
 			return None
 		else:
-			return [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+			return [HostQueueEntry(row=i) for i in rows]
 	
 	def yield_work(self):
 		print "%s yielding work" % self.hostname
@@ -873,11 +899,11 @@
 
 
 	def get_host_queue_entries(self):
-		_cursor.execute("""
+		rows = _db.execute("""
 			SELECT * FROM host_queue_entries
 			WHERE job_id= %s
 		""", (self.id,))
-		entries = [HostQueueEntry(row=i) for i in _cursor.fetchall()]
+		entries = [HostQueueEntry(row=i) for i in rows]
 
 		assert len(entries)>0