Convert the parser into an embeddable library. This includes a major
re-organization of the existing code.
Signed-off-by: John Admanski <jadmanski@google.com>
git-svn-id: http://test.kernel.org/svn/autotest/trunk@1447 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/tko/db.py b/tko/db.py
index 14e7c1b..7d2345b 100644
--- a/tko/db.py
+++ b/tko/db.py
@@ -1,5 +1,8 @@
-import re, os, sys, types
-from common import global_config
+import re, os, sys, types, time
+
+import common
+from autotest_lib.client.common_lib import global_config
+
class MySQLTooManyRows(Exception):
pass
@@ -10,26 +13,28 @@
database = None, user = None, password = None):
self.debug = debug
self.autocommit = autocommit
-
- path = os.path.dirname(__file__)
-
+
+ self.host = host
+ self.database = database
+ self.user = user
+ self.password = password
+
# grab the global config
c = global_config.global_config
-
+
# grab the host, database
- if not host:
- host = c.get_config_value("TKO", "host")
- if not database:
- database = c.get_config_value("TKO", "database")
-
+ if not self.host:
+ self.host = c.get_config_value("TKO", "host")
+ if not self.database:
+ self.database = c.get_config_value("TKO", "database")
+
# grab the user and password
- if not user:
- user = c.get_config_value("TKO", "user")
- if not password:
- password = c.get_config_value("TKO", "password")
-
- self.con = self.connect(host, database, user, password)
- self.cur = self.con.cursor()
+ if not self.user:
+ self.user = c.get_config_value("TKO", "user")
+ if not self.password:
+ self.password = c.get_config_value("TKO", "password")
+
+ self._init_db()
# if not present, insert statuses
self.status_idx = {}
@@ -39,8 +44,8 @@
self.status_idx[s[1]] = s[0]
self.status_word[s[0]] = s[1]
- dir = os.path.dirname(__file__)
- machine_map = os.path.join(dir, 'machines')
+ machine_map = os.path.join(os.path.dirname(__file__),
+ 'machines')
if os.path.exists(machine_map):
self.machine_map = machine_map
else:
@@ -48,6 +53,41 @@
self.machine_group = {}
+ def _init_db(self):
+ # create the db connection and cursor
+ self.con = self.connect(self.host, self.database,
+ self.user, self.password)
+ self.cur = self.con.cursor()
+
+
+ def _run_with_retry(self, function, *args, **dargs):
+ """Call function(*args, **dargs) until either it passes
+ without an operational error, or a timeout is reached. This
+ is intended for internal use with database functions, not
+ for generic use."""
+ OperationalError = _get_error_class("OperationalError")
+ # TODO: make this configurable
+ TIMEOUT = 3600 # one hour
+ success = False
+ start_time = time.time()
+ while not success:
+ try:
+ result = function(*args, **dargs)
+ except OperationalError:
+ stop_time = time.time()
+ elapsed_time = stop_time - start_time
+ if elapsed_time > TIMEOUT:
+ raise
+ else:
+ try:
+ self._init_db()
+ except OperationalError:
+ pass
+ else:
+ success = True
+ return result
+
+
def dprint(self, value):
if self.debug:
sys.stdout.write('SQL: ' + str(value) + '\n')
@@ -103,20 +143,26 @@
# TODO: this assumes there's a where clause...bad
if wherein and isinstance(wherein, types.DictionaryType):
- keys_in = []
- for field_in in wherein.keys():
- keys_in += [field_in + ' in (' + ','.join(wherein[field_in])+') ']
-
+ keys_in = ["%s in (%s) " % (field, ','.join(where))
+ for field, where in wherein.iteritems()]
cmd.append(' and '+' and '.join(keys_in))
+
if group_by:
cmd.append(' GROUP BY ' + group_by)
self.dprint('%s %s' % (' '.join(cmd), values))
- numRec = self.cur.execute(' '.join(cmd), values)
- if max_rows != None and numRec > max_rows:
- msg = 'Exceeded allowed number of records'
- raise MySQLTooManyRows(msg)
- return self.cur.fetchall()
+
+ # create a re-runnable function for executing the query
+ def exec_sql():
+ sql = ' '.join(cmd)
+ numRec = self.cur.execute(sql, values)
+ if max_rows != None and numRec > max_rows:
+ msg = 'Exceeded allowed number of records'
+ raise MySQLTooManyRows(msg)
+ return self.cur.fetchall()
+
+ # run the query, re-trying after operational errors
+ return self._run_with_retry(exec_sql)
def select_sql(self, fields, table, sql, values):
@@ -125,8 +171,28 @@
"""
cmd = 'select %s from %s %s' % (fields, table, sql)
self.dprint(cmd)
- self.cur.execute(cmd, values)
- return self.cur.fetchall()
+
+ # create a re-runnable function for executing the query
+ def exec_sql():
+ self.cur.execute(cmd, values)
+ return self.cur.fetchall()
+
+ # run the query, re-trying after operational errors
+ return self._run_with_retry(exec_sql)
+
+
+ def _exec_sql_with_commit(self, sql, values, commit):
+ if self.autocommit:
+ # re-run the query until it succeeds
+ def exec_sql():
+ self.cur.execute(sql, values)
+ self.con.commit()
+ self._run_with_retry(exec_sql)
+ else:
+ # take one shot at running the query
+ self.cur.execute(sql, values)
+ if commit:
+ self.con.commit()
def insert(self, table, data, commit = None):
@@ -136,18 +202,14 @@
data:
dictionary of fields and data
"""
- if commit == None:
- commit = self.autocommit
fields = data.keys()
refs = ['%s' for field in fields]
values = [data[field] for field in fields]
cmd = 'insert into %s (%s) values (%s)' % \
(table, ','.join(fields), ','.join(refs))
+ self.dprint('%s %s' % (cmd, values))
- self.dprint('%s %s' % (cmd,values))
- self.cur.execute(cmd, values)
- if commit:
- self.con.commit()
+ self._exec_sql_with_commit(cmd, values, commit)
def delete(self, table, where, commit = None):
@@ -158,11 +220,11 @@
keys = [field + '=%s' for field in where.keys()]
values = [where[field] for field in where.keys()]
cmd += ['where', ' and '.join(keys)]
- self.dprint('%s %s' % (' '.join(cmd),values))
- self.cur.execute(' '.join(cmd), values)
- if commit:
- self.con.commit()
-
+ sql = ' '.join(cmd)
+ self.dprint('%s %s' % (sql, values))
+
+ self._exec_sql_with_commit(sql, values, commit)
+
def update(self, table, data, where, commit = None):
"""\
@@ -183,10 +245,10 @@
where_values = [where[field] for field in where.keys()]
cmd += ' where ' + ' and '.join(where_keys)
- print '%s %s' % (cmd, data_values + where_values)
- self.cur.execute(cmd, data_values + where_values)
- if commit:
- self.con.commit()
+ values = data_values + where_values
+ print '%s %s' % (cmd, values)
+
+ self._exec_sql_with_commit(cmd, values, commit)
def delete_job(self, tag, commit = None):
@@ -237,8 +299,10 @@
self.insert('iteration_result',
data,
commit=commit)
- data = {'test_idx':test_idx, 'attribute':'version', 'value':test.version}
- self.insert('test_attributes', data, commit=commit)
+
+ for key, value in test.attributes.iteritems():
+ data = {'test_idx': test_idx, 'attribute': key, 'value': value}
+ self.insert('test_attributes', data, commit=commit)
def read_machine_map(self):
@@ -354,16 +418,26 @@
return None
-# Use a class method as a class factory, generating a relevant database object.
-def db(*args, **dargs):
- path = os.path.dirname(__file__)
- db_type = None
-
- # read db_type from global config
- c = global_config.global_config
- db_type = c.get_config_value("TKO", "db_type", default="mysql")
- db_type = 'db_' + db_type
- exec ('import %s; db = %s.%s(*args, **dargs)'
- % (db_type, db_type, db_type))
+def _get_db_type():
+ """Get the database type name to use from the global config."""
+ get_value = global_config.global_config.get_config_value
+ return "db_" + get_value("TKO", "db_type", default="mysql")
+
+def _get_error_class(class_name):
+ """Retrieves the appropriate error class by name from the database
+ module."""
+ db_module = __import__("autotest_lib.tko." + _get_db_type(),
+ globals(), locals(), ["driver"])
+ return getattr(db_module.driver, class_name)
+
+
+def db(*args, **dargs):
+ """Creates an instance of the database class with the arguments
+ provided in args and dargs, using the database type specified by
+ the global configuration (defaulting to mysql)."""
+ db_type = _get_db_type()
+ db_module = __import__("autotest_lib.tko." + db_type, globals(),
+ locals(), [db_type])
+ db = getattr(db_module, db_type)(*args, **dargs)
return db