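Add unit tests for the TKO RPC interface

Also add a TempManager._get_column_names() hook so the tests can stub out
column-name extraction for sqlite3, and remove the get_test_logs_urls RPC.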

Risk: low
Visibility: low

Signed-off-by: James Ren <jamesren@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@3332 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/new_tko/tko/models.py b/new_tko/tko/models.py
index dd02fa8..2c1133d 100644
--- a/new_tko/tko/models.py
+++ b/new_tko/tko/models.py
@@ -47,6 +47,14 @@
         return ('SELECT ' + ', '.join(select_fields) + where), params
 
 
+    def _get_column_names(self, cursor):
+        """\
+        Returns the column names from the cursor description. This is a method
+        so that unit tests can stub it out for sqlite3 compatibility.
+        """
+        return [column_info[0] for column_info in cursor.description]
+
+
     def execute_group_query(self, query, group_by, extra_select_fields=[]):
         """
         Performs the given query grouped by the fields in group_by with the
@@ -60,7 +68,7 @@
                                                 extra_select_fields)
         cursor = readonly_connection.connection().cursor()
         cursor.execute(sql, params)
-        field_names = [column_info[0] for column_info in cursor.description]
+        field_names = self._get_column_names(cursor)
         row_dicts = [dict(zip(field_names, row)) for row in cursor.fetchall()]
         return row_dicts
 
diff --git a/new_tko/tko/rpc_interface.py b/new_tko/tko/rpc_interface.py
index ac7e0b7..75ed00b 100644
--- a/new_tko/tko/rpc_interface.py
+++ b/new_tko/tko/rpc_interface.py
@@ -133,19 +133,6 @@
     return rpc_utils.prepare_for_serialization(info)
 
 
-
-def get_test_logs_urls(**filter_data):
-    """
-    Return URLs to test logs for all tests matching the filter data.
-    """
-    query = models.TestView.query_objects(filter_data)
-    tests = set((test_view.job_tag, test_view.test) for test_view in query)
-    links = []
-    for job_tag, test in tests:
-        links.append('/results/' + job_tag + '/' + test)
-    return links
-
-
 def get_job_ids(**filter_data):
     """
     Returns AFE job IDs for all tests matching the filters.
diff --git a/new_tko/tko/rpc_interface_unittest.py b/new_tko/tko/rpc_interface_unittest.py
index f035363..aab54a1 100644
--- a/new_tko/tko/rpc_interface_unittest.py
+++ b/new_tko/tko/rpc_interface_unittest.py
@@ -1,9 +1,10 @@
-#!/usr/bin/python2.4
+#!/usr/bin/python
 
 import unittest
 import common
 from autotest_lib.new_tko import setup_django_environment
 from autotest_lib.frontend import setup_test_environment
+from autotest_lib.client.common_lib.test_utils import mock
 from django.db import connection
 from autotest_lib.new_tko.tko import models, rpc_interface
 
@@ -87,6 +88,7 @@
 
 class RpcInterfaceTest(unittest.TestCase):
     def setUp(self):
+        self._god = mock.mock_god()
         setup_test_environment.set_up()
         fix_iteration_tables()
         setup_test_view()
@@ -95,52 +97,68 @@
 
     def tearDown(self):
+        self._god.unstub_all()
         setup_test_environment.tear_down()
 
 
     def _create_initial_data(self):
-        machine = models.Machine(hostname='host1')
-        machine.save()
+        machine = models.Machine.objects.create(hostname='myhost')
 
-        kernel_name = 'mykernel'
-        kernel = models.Kernel(kernel_hash=kernel_name, base=kernel_name,
-                               printable=kernel_name)
-        kernel.save()
+        kernel_name = 'mykernel1'
+        kernel1 = models.Kernel.objects.create(kernel_hash=kernel_name,
+                                               base=kernel_name,
+                                               printable=kernel_name)
 
-        status = models.Status(word='GOOD')
-        status.save()
+        kernel_name = 'mykernel2'
+        kernel2 = models.Kernel.objects.create(kernel_hash=kernel_name,
+                                               base=kernel_name,
+                                               printable=kernel_name)
 
-        job = models.Job(tag='myjobtag', label='myjob', username='myuser',
-                         machine=machine)
-        job.save()
+        good_status = models.Status.objects.create(word='GOOD')
+        failed_status = models.Status.objects.create(word='FAILED')
 
-        test = models.Test(job=job, test='mytest', kernel=kernel,
-                                status=status, machine=machine)
-        test.save()
+        job1 = models.Job.objects.create(tag='1-myjobtag1', label='myjob1',
+                                         username='myuser', machine=machine)
+        job2 = models.Job.objects.create(tag='2-myjobtag2', label='myjob2',
+                                         username='myuser', machine=machine)
+
+        job1_test1 = models.Test.objects.create(job=job1, test='mytest1',
+                                                kernel=kernel1,
+                                                status=good_status,
+                                                machine=machine)
+        job1_test2 = models.Test.objects.create(job=job1, test='mytest2',
+                                                kernel=kernel1,
+                                                status=failed_status,
+                                                machine=machine)
+        job2_test1 = models.Test.objects.create(job=job2, test='kernbench',
+                                                kernel=kernel2,
+                                                status=good_status,
+                                                machine=machine)
 
         # like Noah's Ark, include two of each...just in case there's a bug with
         # multiple related items
+        models.TestAttribute.objects.create(test=job1_test1, attribute='myattr',
+                                            value='myval')
+        models.TestAttribute.objects.create(test=job1_test1,
+                                            attribute='myattr2', value='myval2')
 
-        (models.TestAttribute(test=test, attribute='myattr', value='myval')
-         .save())
-        (models.TestAttribute(test=test, attribute='myattr2', value='myval2')
-         .save())
-
-        # can't use models to add these, since they don't have real primary keys
-        self._add_iteration_keyval('iteration_attributes', test=test,
+        self._add_iteration_keyval('iteration_attributes', test=job1_test1,
                                    iteration=1, attribute='iattr',
                                    value='ival')
-        self._add_iteration_keyval('iteration_attributes', test=test,
+        self._add_iteration_keyval('iteration_attributes', test=job1_test1,
                                    iteration=1, attribute='iattr2',
                                    value='ival2')
-        self._add_iteration_keyval('iteration_result', test=test,
+        self._add_iteration_keyval('iteration_result', test=job1_test1,
                                    iteration=1, attribute='iresult',
                                    value=1)
-        self._add_iteration_keyval('iteration_result', test=test,
+        self._add_iteration_keyval('iteration_result', test=job1_test1,
                                    iteration=1, attribute='iresult2',
                                    value=2)
 
-        self._add_test_label(test, 'testlabel')
-        self._add_test_label(test, 'testlabel2')
+        label1 = models.TestLabel.objects.create(name='testlabel1')
+        label2 = models.TestLabel.objects.create(name='testlabel2')
+
+        label1.tests.add(job1_test1)
+        label2.tests.add(job1_test1)
 
 
     def _add_iteration_keyval(self, table, test, iteration, attribute, value):
@@ -149,22 +167,20 @@
                        (test.test_idx, iteration, attribute, value))
 
 
-    def _add_test_label(self, test, label_name):
-        test_label = models.TestLabel(name=label_name)
-        test_label.save()
-        test_label.tests.add(test)
+    def _check_for_get_test_views(self, test):
+        self.assertEquals(test['test_name'], 'mytest1')
+        self.assertEquals(test['job_tag'], '1-myjobtag1')
+        self.assertEquals(test['job_name'], 'myjob1')
+        self.assertEquals(test['job_owner'], 'myuser')
+        self.assertEquals(test['status'], 'GOOD')
+        self.assertEquals(test['hostname'], 'myhost')
+        self.assertEquals(test['kernel'], 'mykernel1')
 
 
     def test_get_detailed_test_views(self):
         test = rpc_interface.get_detailed_test_views()[0]
 
-        self.assertEquals(test['test_name'], 'mytest')
-        self.assertEquals(test['job_tag'], 'myjobtag')
-        self.assertEquals(test['job_name'], 'myjob')
-        self.assertEquals(test['job_owner'], 'myuser')
-        self.assertEquals(test['status'], 'GOOD')
-        self.assertEquals(test['hostname'], 'host1')
-        self.assertEquals(test['kernel'], 'mykernel')
+        self._check_for_get_test_views(test)
 
         self.assertEquals(test['attributes'], {'myattr': 'myval',
                                                'myattr2': 'myval2'})
@@ -172,23 +188,23 @@
                                                          'iattr2': 'ival2'},
                                                 'perf': {'iresult': 1,
                                                          'iresult2': 2}}])
-        self.assertEquals(test['labels'], ['testlabel', 'testlabel2'])
+        self.assertEquals(test['labels'], ['testlabel1', 'testlabel2'])
 
 
     def test_test_attributes(self):
-        rpc_interface.set_test_attribute('foo', 'bar', test_name='mytest')
+        rpc_interface.set_test_attribute('foo', 'bar', test_name='mytest1')
         test = rpc_interface.get_detailed_test_views()[0]
         self.assertEquals(test['attributes'], {'foo': 'bar',
                                                'myattr': 'myval',
                                                'myattr2': 'myval2'})
 
-        rpc_interface.set_test_attribute('foo', 'goo', test_name='mytest')
+        rpc_interface.set_test_attribute('foo', 'goo', test_name='mytest1')
         test = rpc_interface.get_detailed_test_views()[0]
         self.assertEquals(test['attributes'], {'foo': 'goo',
                                                'myattr': 'myval',
                                                'myattr2': 'myval2'})
 
-        rpc_interface.set_test_attribute('foo', None, test_name='mytest')
+        rpc_interface.set_test_attribute('foo', None, test_name='mytest1')
         test = rpc_interface.get_detailed_test_views()[0]
         self.assertEquals(test['attributes'], {'myattr': 'myval',
                                                'myattr2': 'myval2'})
@@ -196,7 +212,136 @@
 
     def test_immutable_attributes(self):
         self.assertRaises(ValueError, rpc_interface.set_test_attribute,
-                          'myattr', 'foo', test_name='mytest')
+                          'myattr', 'foo', test_name='mytest1')
+
+
+    def test_get_test_views(self):
+        tests = rpc_interface.get_test_views()
+
+        self.assertEquals(len(tests), 3)
+        test = rpc_interface.get_test_views(
+            job_name='myjob1', test_name='mytest1')[0]
+        self.assertEquals(tests[0], test)
+
+        self._check_for_get_test_views(test)
+
+        self.assertEquals(
+            [], rpc_interface.get_test_views(hostname='fakehost'))
+
+
+    def test_get_num_test_views(self):
+        self.assertEquals(rpc_interface.get_num_test_views(), 3)
+        self.assertEquals(rpc_interface.get_num_test_views(
+            job_name='myjob1', test_name='mytest1'), 1)
+
+
+    def _get_column_names_for_sqlite3(self, cursor):
+        names = [column_info[0] for column_info in cursor.description]
+
+        # replace all "table_name"."column_name" constructs with just
+        # column_name
+        for i, name in enumerate(names):
+            if '.' in name:
+                field_name = name.split('.', 1)[1]
+                names[i] = field_name.strip('"')
+
+        return names
+
+
+    def test_get_group_counts(self):
+        self._god.stub_with(models.TempManager, '_get_column_names',
+                            self._get_column_names_for_sqlite3)
+
+        self.assertEquals(rpc_interface.get_num_groups(['job_name']), 2)
+
+        counts = rpc_interface.get_group_counts(['job_name'])
+        groups = counts['groups']
+        self.assertEquals(len(groups), 2)
+        group1 = groups[0]
+        group2 = groups[1]
+
+        self.assertEquals(group1['group_count'], 2)
+        self.assertEquals(group1['job_name'], 'myjob1')
+        self.assertEquals(group2['group_count'], 1)
+        self.assertEquals(group2['job_name'], 'myjob2')
+
+        extra = {'extra': 'kernel_hash'}
+        counts = rpc_interface.get_group_counts(['job_name'],
+                                                header_groups=[('job_name',)],
+                                                extra_select_fields=extra)
+        groups = counts['groups']
+        self.assertEquals(len(groups), 2)
+        group1 = groups[0]
+        group2 = groups[1]
+
+        self.assertEquals(group1['group_count'], 2)
+        self.assertEquals(group1['header_indices'], [0])
+        self.assertEquals(group1['extra'], 'mykernel1')
+        self.assertEquals(group2['group_count'], 1)
+        self.assertEquals(group2['header_indices'], [1])
+        self.assertEquals(group2['extra'], 'mykernel2')
+
+
+    def test_get_status_counts(self):
+        """\
+        Placeholder (asserts nothing): rpc_interface.get_status_counts() relies
+        on the IF function, which sqlite3 lacks, so it is not tested here.
+        """
+
+
+    def test_get_latest_tests(self):
+        """\
+        Placeholder (asserts nothing): rpc_interface.get_latest_tests() relies
+        on the IF function, which sqlite3 lacks, so it is not tested here.
+        """
+
+
+    def test_get_job_ids(self):
+        self.assertEquals([1, 2], rpc_interface.get_job_ids())
+        self.assertEquals([1], rpc_interface.get_job_ids(test_name='mytest2'))
+
+
+    def test_get_hosts_and_tests(self):
+        host_info = rpc_interface.get_hosts_and_tests()
+        self.assertEquals(len(host_info), 1)
+        info = host_info['myhost']
+
+        self.assertEquals(info['tests'], ['kernbench'])
+        self.assertEquals(info['id'], 1)
+
+
+    def _check_for_get_test_labels(self, label, label_num):
+        self.assertEquals(label['id'], label_num)
+        self.assertEquals(label['description'], '')
+        self.assertEquals(label['name'], 'testlabel%d' % label_num)
+
+
+    def test_test_labels(self):
+        labels = rpc_interface.get_test_labels_for_tests(test_name='mytest1')
+        self.assertEquals(len(labels), 2)
+        label1 = labels[0]
+        label2 = labels[1]
+
+        self._check_for_get_test_labels(label1, 1)
+        self._check_for_get_test_labels(label2, 2)
+
+        rpc_interface.test_label_remove_tests(label1['id'], test_name='mytest1')
+
+        labels = rpc_interface.get_test_labels_for_tests(test_name='mytest1')
+        self.assertEquals(len(labels), 1)
+        label = labels[0]
+
+        self._check_for_get_test_labels(label, 2)
+
+        rpc_interface.test_label_add_tests(label1['id'], test_name='mytest1')
+
+        labels = rpc_interface.get_test_labels_for_tests(test_name='mytest1')
+        self.assertEquals(len(labels), 2)
+        label1 = labels[0]
+        label2 = labels[1]
+
+        self._check_for_get_test_labels(label1, 1)
+        self._check_for_get_test_labels(label2, 2)
 
 
 if __name__ == '__main__':