Refactor monitor_db.py to clean up the code and make it more testable.

DatabaseConn.__init__() no longer connects implicitly; the scheduler now
calls _db.connect() explicitly at startup, and connect() takes an optional
db_name so callers can point at a different database. Host.next_queue_entries()
now returns an empty list instead of None when there is no work (or the host
is locked), which simplifies the scheduling loop. Also drop some dead
commented-out code and trailing whitespace.
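For example, a test could now set up its own connection along these lines
(a minimal sketch; the helper name and scratch schema name are illustrative,
and host/user/password are still read from the global config):

    import monitor_db

    def make_test_connection():
        # Construction no longer opens a connection, so a test can build
        # the object first and only then point it at a scratch schema.
        db = monitor_db.DatabaseConn()
        db.connect(db_name='scratch_autotest_web')  # hypothetical test schema
        return db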


git-svn-id: http://test.kernel.org/svn/autotest/trunk@1558 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/scheduler/monitor_db.py b/scheduler/monitor_db.py
index e059e48..b88eff0 100644
--- a/scheduler/monitor_db.py
+++ b/scheduler/monitor_db.py
@@ -97,6 +97,7 @@
 	os.environ['PATH'] = AUTOTEST_SERVER_DIR + ':' + os.environ['PATH']
 	global _db
 	_db = DatabaseConn()
+	_db.connect()
 
 	print "Setting signal handler"
 	signal.signal(signal.SIGINT, handle_sigint)
@@ -158,8 +159,6 @@
 		self.convert_dict = MySQLdb.converters.conversions
 		self.convert_dict.setdefault(bool, self.convert_boolean)
 
-		self.connect()
-
 
 	@staticmethod
 	def convert_boolean(boolean, conversion_dict):
@@ -167,27 +166,27 @@
 		return str(int(boolean))
 
 
-	def connect(self):
+	def connect(self, db_name=None):
 		self.disconnect()
 
 		# get global config and parse for info
 		c = global_config.global_config
 		dbase = "AUTOTEST_WEB"
-		DB_HOST = c.get_config_value(dbase, "host")
-		DB_SCHEMA = c.get_config_value(dbase, "database")
-		
-		global _testing_mode
-		if _testing_mode:
-			DB_SCHEMA = 'stresstest_autotest_web'
+		db_host = c.get_config_value(dbase, "host")
+		if db_name is None:
+			db_name = c.get_config_value(dbase, "database")
 
-		DB_USER = c.get_config_value(dbase, "user")
-		DB_PASS = c.get_config_value(dbase, "password")
+		if _testing_mode:
+			db_name = 'stresstest_autotest_web'
+
+		db_user = c.get_config_value(dbase, "user")
+		db_pass = c.get_config_value(dbase, "password")
 
 		while not self.conn:
 			try:
 				self.conn = MySQLdb.connect(
-				    host=DB_HOST, user=DB_USER, passwd=DB_PASS,
-				    db=DB_SCHEMA, conv=self.convert_dict)
+				    host=db_host, user=db_user, passwd=db_pass,
+				    db=db_name, conv=self.convert_dict)
 
 				self.conn.autocommit(True)
 				self.cur = self.conn.cursor()
@@ -518,23 +517,19 @@
 
 		for host in idle_hosts():
 			tasks = host.next_queue_entries()
-			if tasks:
-				for next in tasks:
-					try:
-						agent = next.run(assigned_host=host)
-						if agent:							
-							self.add_agent(agent)
-							break
-					except:
-						next.set_status('Failed')
-						
-#						if next.host:
-#							next.host.set_status('Ready')
+			for next in tasks:
+				try:
+					agent = next.run(assigned_host=host)
+					if agent:
+						self.add_agent(agent)
+						break
+				except:
+					next.set_status('Failed')
 
-						log_stacktrace("task_id = %d" % next.id)
+					log_stacktrace("task_id = %d" % next.id)
 
 
-	def _find_aborting(self):	
+	def _find_aborting(self):
 		num_aborted = 0
 		# Find jobs that are aborting
 		for entry in queue_entries_to_abort():
@@ -1348,7 +1343,7 @@
 	def next_queue_entries(self):
 		if self.locked:
 			print "%s locked, not queuing" % self.hostname
-			return None
+			return []
 #		print "%s/%s looking for work" % (self.hostname, self.platform_id)
 		rows = _db.execute("""
 			SELECT * FROM host_queue_entries
@@ -1357,7 +1352,7 @@
 				SELECT label_id FROM hosts_labels WHERE host_id=%s
 				)
 			)
-			AND job_id NOT IN ( 
+			AND job_id NOT IN (
 				SELECT job_id FROM ineligible_host_queues
 				WHERE host_id=%s
 			)))
@@ -1366,11 +1361,8 @@
 			LIMIT 1
 		""", (self.id,self.id, self.id))
 
-		if len(rows) == 0:
-			return None
-		else:
-			return [HostQueueEntry(row=i) for i in rows]
-	
+		return [HostQueueEntry(row=i) for i in rows]
+
 	def yield_work(self):
 		print "%s yielding work" % self.hostname
 		if self.current_task():