Convert the parser into an embeddable library. This includes a major
re-organization of the existing code.

Signed-off-by: John Admanski <jadmanski@google.com>



git-svn-id: http://test.kernel.org/svn/autotest/trunk@1447 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/tko/db.py b/tko/db.py
index 14e7c1b..7d2345b 100644
--- a/tko/db.py
+++ b/tko/db.py
@@ -1,5 +1,8 @@
-import re, os, sys, types
-from common import global_config
+import re, os, sys, types, time
+
+import common
+from autotest_lib.client.common_lib import global_config
+
 
 class MySQLTooManyRows(Exception):
 	pass
@@ -10,26 +13,28 @@
 				database = None, user = None, password = None):
 		self.debug = debug
 		self.autocommit = autocommit
-		
-		path = os.path.dirname(__file__)
-		
+
+		self.host = host
+		self.database = database
+		self.user = user
+		self.password = password
+
 		# grab the global config
 		c = global_config.global_config
-		
+
 		# grab the host, database
-		if not host:
-			host = c.get_config_value("TKO", "host")
-		if not database:
-			database = c.get_config_value("TKO", "database")
-		
+		if not self.host:
+			self.host = c.get_config_value("TKO", "host")
+		if not self.database:
+			self.database = c.get_config_value("TKO", "database")
+
 		# grab the user and password
-		if not user:
-			user = c.get_config_value("TKO", "user")
-		if not password:
-			password = c.get_config_value("TKO", "password")
-			
-		self.con = self.connect(host, database, user, password)
-		self.cur = self.con.cursor()
+		if not self.user:
+			self.user = c.get_config_value("TKO", "user")
+		if not self.password:
+			self.password = c.get_config_value("TKO", "password")
+
+		self._init_db()
 
 		# if not present, insert statuses
 		self.status_idx = {}
@@ -39,8 +44,8 @@
 			self.status_idx[s[1]] = s[0]
 			self.status_word[s[0]] = s[1]
 
-		dir = os.path.dirname(__file__)
-		machine_map = os.path.join(dir, 'machines')
+		machine_map = os.path.join(os.path.dirname(__file__),
+					   'machines')
 		if os.path.exists(machine_map):
 			self.machine_map = machine_map
 		else:
@@ -48,6 +53,41 @@
 		self.machine_group = {}
 
 
+	def _init_db(self):
+		# (re)connect with the stored host/db/user/password and open a fresh cursor
+		self.con = self.connect(self.host, self.database,
+					self.user, self.password)
+		self.cur = self.con.cursor()
+
+
+	def _run_with_retry(self, function, *args, **dargs):
+		"""Call function(*args, **dargs) until either it passes
+		without an operational error, or a timeout is reached. This
+		is intended for internal use with database functions, not
+		for generic use. Between attempts we pause RETRY_DELAY
+		seconds and then try to re-open the connection, so a
+		server outage does not turn into a tight reconnect loop."""
+		OperationalError = _get_error_class("OperationalError")
+		# TODO: make these configurable
+		TIMEOUT = 3600 # one hour
+		RETRY_DELAY = 5 # seconds to wait between attempts
+		start_time = time.time()
+		while True:
+			try:
+				return function(*args, **dargs)
+			except OperationalError:
+				elapsed_time = time.time() - start_time
+				if elapsed_time > TIMEOUT:
+					raise
+			# back off before reconnecting; a failed reconnect is
+			# simply retried on the next pass through the loop
+			time.sleep(RETRY_DELAY)
+			try:
+				self._init_db()
+			except OperationalError:
+				pass
+
+
 	def dprint(self, value):
 		if self.debug:
 			sys.stdout.write('SQL: ' + str(value) + '\n')
@@ -103,20 +143,26 @@
 
 		# TODO: this assumes there's a where clause...bad
 		if wherein and isinstance(wherein, types.DictionaryType):
-			keys_in = []
-			for field_in in wherein.keys():
-				keys_in += [field_in + ' in (' + ','.join(wherein[field_in])+') '] 
-			
+			keys_in = ["%s in (%s) " % (field, ','.join(where))
+				   for field, where in wherein.iteritems()]
 			cmd.append(' and '+' and '.join(keys_in))
+
 		if group_by:
 			cmd.append(' GROUP BY ' + group_by)
 
 		self.dprint('%s %s' % (' '.join(cmd), values))
-		numRec = self.cur.execute(' '.join(cmd), values)
-		if max_rows != None and numRec > max_rows:
-			msg = 'Exceeded allowed number of records'
-			raise MySQLTooManyRows(msg)
-		return self.cur.fetchall()
+
+		# create a re-runable function for executing the query
+		def exec_sql():
+			sql = ' '.join(cmd)
+			numRec = self.cur.execute(sql, values)
+			if max_rows != None and numRec > max_rows:
+				msg = 'Exceeded allowed number of records'
+				raise MySQLTooManyRows(msg)
+			return self.cur.fetchall()
+
+		# run the query, re-trying after operational errors
+		return self._run_with_retry(exec_sql)
 
 
 	def select_sql(self, fields, table, sql, values):
@@ -125,8 +171,28 @@
 		"""
 		cmd = 'select %s from %s %s' % (fields, table, sql)
 		self.dprint(cmd)
-		self.cur.execute(cmd, values)
-		return self.cur.fetchall()
+
+		# create a -re-runable function for executing the query
+		def exec_sql():
+			self.cur.execute(cmd, values)
+			return self.cur.fetchall()
+
+		# run the query, re-trying after operational errors
+		return self._run_with_retry(exec_sql)
+
+
+	def _exec_sql_with_commit(self, sql, values, commit):
+		if self.autocommit:
+			# autocommit: always commit (commit arg ignored), retrying on errors
+			def exec_sql():
+				self.cur.execute(sql, values)
+				self.con.commit()
+			self._run_with_retry(exec_sql)
+		else:
+			# no autocommit: single attempt, commit only if the caller asked
+			self.cur.execute(sql, values)
+			if commit:
+				self.con.commit()
 
 
 	def insert(self, table, data, commit = None):
@@ -136,18 +202,14 @@
 			data:
 				dictionary of fields and data
 		"""
-		if commit == None:
-			commit = self.autocommit
 		fields = data.keys()
 		refs = ['%s' for field in fields]
 		values = [data[field] for field in fields]
 		cmd = 'insert into %s (%s) values (%s)' % \
 				(table, ','.join(fields), ','.join(refs))
+		self.dprint('%s %s' % (cmd, values))
 
-		self.dprint('%s %s' % (cmd,values))
-		self.cur.execute(cmd, values)
-		if commit:
-			self.con.commit()
+		self._exec_sql_with_commit(cmd, values, commit)
 
 
 	def delete(self, table, where, commit = None):
@@ -158,11 +220,11 @@
 			keys = [field + '=%s' for field in where.keys()]
 			values = [where[field] for field in where.keys()]
 			cmd += ['where', ' and '.join(keys)]
-		self.dprint('%s %s' % (' '.join(cmd),values))
-		self.cur.execute(' '.join(cmd), values)
-		if commit:
-			self.con.commit()
-		
+		sql = ' '.join(cmd)
+		self.dprint('%s %s' % (sql, values))
+
+		self._exec_sql_with_commit(sql, values, commit)
+
 
 	def update(self, table, data, where, commit = None):
 		"""\
@@ -183,10 +245,10 @@
 		where_values = [where[field] for field in where.keys()]
 		cmd += ' where ' + ' and '.join(where_keys)
 
-		print '%s %s' % (cmd, data_values + where_values)
-		self.cur.execute(cmd, data_values + where_values)
-		if commit:
-			self.con.commit()
+		values = data_values + where_values
+		print '%s %s' % (cmd, values)
+
+		self._exec_sql_with_commit(cmd, values, commit)
 
 
 	def delete_job(self, tag, commit = None):
@@ -237,8 +299,10 @@
 				self.insert('iteration_result',
                                             data,
                                             commit=commit)
-		data = {'test_idx':test_idx, 'attribute':'version', 'value':test.version}
-		self.insert('test_attributes', data, commit=commit)
+
+		for key, value in test.attributes.iteritems():
+			data = {'test_idx': test_idx, 'attribute': key, 'value': value}
+			self.insert('test_attributes', data, commit=commit)
 
 
 	def read_machine_map(self):
@@ -354,16 +418,26 @@
 			return None
 
 
-# Use a class method as a class factory, generating a relevant database object.
-def db(*args, **dargs):
-	path = os.path.dirname(__file__)
-	db_type = None
-	
-	# read db_type from global config
-	c = global_config.global_config
-	db_type = c.get_config_value("TKO", "db_type", default="mysql")
-	db_type = 'db_' + db_type
-	exec ('import %s; db = %s.%s(*args, **dargs)'
-	      % (db_type, db_type, db_type))
+def _get_db_type():
+	"""Return "db_" + the TKO db_type config value (default "mysql")."""
+	get_value = global_config.global_config.get_config_value
+	return "db_" + get_value("TKO", "db_type", default="mysql")
 
+
+def _get_error_class(class_name):
+	"""Return the exception class named class_name (e.g. "OperationalError")
+	from the driver attribute of the configured tko database module."""
+	db_module = __import__("autotest_lib.tko." + _get_db_type(),
+			       globals(), locals(), ["driver"])
+	return getattr(db_module.driver, class_name)
+
+
+def db(*args, **dargs):
+	"""Create an instance of the tko database class named by the
+	global config's TKO db_type (default "mysql"), passing args and
+	dargs straight through to that class's constructor."""
+	db_type = _get_db_type()
+	db_module = __import__("autotest_lib.tko." + db_type, globals(),
+			       locals(), [db_type])
+	db = getattr(db_module, db_type)(*args, **dargs)
 	return db