-Add easy test invalidation functionality:
 -Add a "show invalidated tests" option to the common panel, disabled by default.
 -Make the client submit an "exclude_labels" option to exclude invalidated tests. This required fairly widespread changes, because the global condition is no longer just a SQL string but a collection of parameters.
 -Add an "invalidate tests" option to the spreadsheet/table context menu, and a corresponding button to the test detail view. Both are really just shortcuts for adding the "invalidated" label.
 -Add server-side logic to handle the "exclude_labels" option. It is implemented generically because I plan to add a UI for excluding any label or labels in the future.
-Force test label names to be unique.
-Fix a bug in the logic that determines all labels assigned to a set of tests.
-Remove the auto-refresh when switching between the spreadsheet and table views after the condition has changed.

Signed-off-by: Steve Howard <showard@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@2099 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/new_tko/tko/models.py b/new_tko/tko/models.py
index 1e40240..c893f8b 100644
--- a/new_tko/tko/models.py
+++ b/new_tko/tko/models.py
@@ -26,23 +26,20 @@
                          [count_sql + ' AS ' + self._GROUP_COUNT_NAME] +
                          extra_select_fields)
 
-        # add the count field and all group fields to the query selects, so
-        # they'll be sortable and Django won't mess with any of them
-        #for field in group_fields + [self._GROUP_COUNT_NAME]:
-        #    query._select[field] = ''
+        # add the count field to the query selects, so they'll be sortable and
+        # Django won't mess with any of them
         query._select[self._GROUP_COUNT_NAME] = count_sql
 
-        # Inject the GROUP_BY clause into the query by adding it to the end of
-        # the queries WHERE clauses. We need it to come before the ORDER BY and
-        # LIMIT clauses.
-        num_real_where_clauses = len(query._where)
-        query._where.append('GROUP BY ' + ', '.join(group_fields))
         _, where, params = query._get_sql_clause()
-        if num_real_where_clauses == 0:
-            # handle the special case where there were no actual WHERE clauses
-            where = where.replace('WHERE GROUP BY', 'GROUP BY')
-        else:
-            where = where.replace('AND GROUP BY', 'GROUP BY')
+
+        # insert GROUP BY clause into query
+        group_by_clause = 'GROUP BY ' + ', '.join(group_fields)
+        group_by_position = where.rfind('ORDER BY')
+        if group_by_position == -1:
+            group_by_position = len(where)
+        where = (where[:group_by_position] +
+                 group_by_clause + ' ' +
+                 where[group_by_position:])
 
         return ('SELECT ' + ', '.join(select_fields) + where), params
 
@@ -179,7 +176,7 @@
 
 
 class TestLabel(dbmodels.Model, model_logic.ModelExtensions):
-    name = dbmodels.CharField(maxlength=80)
+    name = dbmodels.CharField(maxlength=80, unique=True)
     description = dbmodels.TextField(blank=True)
     tests = dbmodels.ManyToManyField(Test, blank=True,
                                      filter_interface=dbmodels.HORIZONTAL)
@@ -203,9 +200,10 @@
 # views
 
 class TestViewManager(TempManager):
-    class _JoinQ(dbmodels.Q):
+    class _CustomSqlQ(dbmodels.Q):
         def __init__(self):
             self._joins = datastructures.SortedDict()
+            self._where, self._params = [], []
 
 
         def add_join(self, table, condition, join_type, alias=None):
@@ -214,8 +212,13 @@
             self._joins[alias] = (table, join_type, condition)
 
 
+        def add_where(self, where, params=[]):
+            self._where.append(where)
+            self._params.extend(params)
+
+
         def get_sql(self, opts):
-            return self._joins, [], []
+            return self._joins, self._where, self._params
 
 
     def get_query_set(self):
@@ -227,27 +230,59 @@
         return query.extra(select=extra_select)
 
 
-    def _add_label_joins(self, query_set):
+    def _add_label_joins(self, query_set, suffix = '', join_condition='',
+                         exclude=False):
         table_name = self.model._meta.db_table
-        filter_object = self._JoinQ()
-        filter_object.add_join(
-            'test_labels_tests',
-            'test_labels_tests.test_id = %s.test_idx' % table_name,
-            'LEFT JOIN')
-        filter_object.add_join(
-            'test_labels',
-            'test_labels.id = test_labels_tests.testlabel_id',
-            'LEFT JOIN')
-        return query_set.complex_filter(filter_object).distinct()
+        first_join_alias = 'test_labels_tests' + suffix
+        first_join_condition = '%s.test_id = %s.test_idx' % (first_join_alias,
+                                                             table_name)
+        if join_condition:
+            first_join_condition += ' AND ' + join_condition
+        filter_object = self._CustomSqlQ()
+        filter_object.add_join('test_labels_tests',
+                               first_join_condition,
+                               'LEFT JOIN',
+                               alias=first_join_alias)
+
+        second_join_alias = 'test_labels' + suffix
+        second_join_condition = ('%s.id = %s.testlabel_id' %
+                                 (second_join_alias, first_join_alias))
+        filter_object.add_join('test_labels',
+                               second_join_condition,
+                               'LEFT JOIN',
+                               alias=second_join_alias)
+
+        if exclude:
+            filter_object.add_where(first_join_alias + '.testlabel_id IS NULL')
+        return query_set.filter(filter_object).distinct()
+
+
+    def _get_label_ids_from_names(self, label_names):
+        query = TestLabel.objects.filter(name__in=label_names).values('id')
+        return [label['id'] for label in query]
 
 
     def get_query_set_with_labels(self, filter_data):
+        exclude_labels = filter_data.pop('exclude_labels', [])
         query_set = self.get_query_set()
+        joined = False
         # TODO: make this check more thorough if necessary
         if 'test_labels' in filter_data.get('extra_where', ''):
             query_set = self._add_label_joins(query_set)
-        else:
+            joined = True
+
+        if exclude_labels:
+            label_ids = self._get_label_ids_from_names(exclude_labels)
+            condition = ('test_labels_tests_exclude.testlabel_id IN (%s)' %
+                         ','.join(str(label_id) for label_id in label_ids))
+            query_set = self._add_label_joins(query_set, suffix='_exclude',
+                                              join_condition=condition,
+                                              exclude=True)
+            joined = True
+
+        if not joined:
             filter_data['no_distinct'] = True
+
         return query_set
 
 
@@ -269,9 +304,10 @@
 
 
     def query_test_label_ids(self, filter_data):
-        query_set = self._add_label_joins(self.get_query_set()).distinct()
-        rows = self._custom_select_query(query_set, ['test_labels.id'])
-        return [row[0] for row in rows] # flatten rows to a list of ids
+        query_set = self.model.query_objects(filter_data)
+        query_set = self._add_label_joins(query_set, suffix='_list')
+        rows = self._custom_select_query(query_set, ['test_labels_list.id'])
+        return [row[0] for row in rows if row[0] is not None]
 
 
 class TestView(dbmodels.Model, model_logic.ModelExtensions):