Refactor and clean up some database-joining code for TKO queries

Signed-off-by: Steve Howard <showard@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@3585 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/new_tko/tko/models.py b/new_tko/tko/models.py
index 7ee1598..2749704 100644
--- a/new_tko/tko/models.py
+++ b/new_tko/tko/models.py
@@ -295,17 +295,33 @@
 
     def _get_include_exclude_suffix(self, exclude):
         if exclude:
-            suffix = '_exclude'
-        else:
-            suffix = '_include'
-        return suffix
+            return '_exclude'
+        return '_include'
+
+
+    def _add_attribute_join(self, query_set, join_condition,
+                            suffix=None, exclude=False):
+        if suffix is None:  # default to '_include' / '_exclude'
+            suffix = self._get_include_exclude_suffix(exclude)
+        return self.add_join(query_set, 'test_attributes', join_key='test_idx',
+                             join_condition=join_condition,
+                             suffix=suffix, exclude=exclude)
+
+
+    def _add_label_pivot_table_join(self, query_set, suffix, join_condition='',
+                                    exclude=False, force_left_join=False):
+        return self.add_join(query_set, 'test_labels_tests', join_key='test_id',
+                             join_condition=join_condition,
+                             suffix=suffix, exclude=exclude,
+                             force_left_join=force_left_join)
 
 
     def _add_label_joins(self, query_set, suffix=''):
-        query_set = self.add_join(query_set, 'test_labels_tests',
-                                  join_key='test_id', suffix=suffix,
-                                  force_left_join=True)
+        query_set = self._add_label_pivot_table_join(
+                query_set, suffix=suffix, force_left_join=True)
 
+        # this join goes from the pivot table to test_labels rather than
+        # from the original table, so self.add_join() can't be used again
         second_join_alias = 'test_labels' + suffix
         second_join_condition = ('%s.id = %s.testlabel_id' %
                                  (second_join_alias,
@@ -318,22 +334,25 @@
         return self._add_customSqlQ(query_set, filter_object)
 
 
-    def _add_attribute_join(self, query_set, join_condition='', suffix=None,
-                            exclude=False):
-        join_condition = self.escape_user_sql(join_condition)
-        if suffix is None:
-            suffix = self._get_include_exclude_suffix(exclude)
-        return self.add_join(query_set, 'test_attributes',
-                              join_key='test_idx',
-                              join_condition=join_condition,
-                              suffix=suffix, exclude=exclude)
-
-
     def _get_label_ids_from_names(self, label_names):
-        if not label_names:
-            return []
-        query = TestLabel.objects.filter(name__in=label_names).values('id')
-        return [str(label['id']) for label in query]
+        assert label_names  # callers skip this for empty label lists
+        label_ids = list(  # listifying avoids a double query below
+                TestLabel.objects.filter(name__in=label_names).values('id'))
+        if len(label_ids) < len(set(label_names)):
+            raise ValueError('Not all of these labels were found: %s' %
+                             ', '.join(label_names))
+        return [str(label['id']) for label in label_ids]
+
+
+    def _include_or_exclude_labels(self, query_set, label_names, exclude=False):
+        label_ids = self._get_label_ids_from_names(label_names)
+        suffix = self._get_include_exclude_suffix(exclude)
+        condition = ('test_labels_tests%s.testlabel_id IN (%s)' %
+                     (suffix, ','.join(label_ids)))  # ids come from the DB
+        return self._add_label_pivot_table_join(query_set,
+                                                join_condition=condition,
+                                                suffix=suffix,
+                                                exclude=exclude)
 
 
     def get_query_set_with_joins(self, filter_data, include_host_labels=False):
@@ -341,32 +360,22 @@
         exclude_labels = filter_data.pop('exclude_labels', [])
         query_set = self.get_query_set()
         joined = False
-        # TODO: make this check more thorough if necessary
+
+        # TODO: retire label filtering through extra_where in favor of
+        # the include_labels and exclude_labels filters
         extra_where = filter_data.get('extra_where', '')
         if 'test_labels' in extra_where:
             query_set = self._add_label_joins(query_set)
             joined = True
 
-        include_label_ids = self._get_label_ids_from_names(include_labels)
-        if include_label_ids:
-            # TODO: Factor this out like what's done with attributes
-            condition = ('test_labels_tests_include.testlabel_id IN (%s)' %
-                         ','.join(include_label_ids))
-            query_set = self.add_join(query_set, 'test_labels_tests',
-                                       join_key='test_id',
-                                       suffix='_include',
-                                       join_condition=condition)
+        if include_labels:  # unknown label names raise ValueError
+            query_set = self._include_or_exclude_labels(query_set,
+                                                        include_labels)
             joined = True
-
-        exclude_label_ids = self._get_label_ids_from_names(exclude_labels)
-        if exclude_label_ids:
-            condition = ('test_labels_tests_exclude.testlabel_id IN (%s)' %
-                         ','.join(exclude_label_ids))
-            query_set = self.add_join(query_set, 'test_labels_tests',
-                                       join_key='test_id',
-                                       suffix='_exclude',
-                                       join_condition=condition,
-                                       exclude=True)
+        if exclude_labels:
+            query_set = self._include_or_exclude_labels(query_set,
+                                                        exclude_labels,
+                                                        exclude=True)
             joined = True
 
         include_attributes_where = filter_data.pop('include_attributes_where',
@@ -375,17 +384,20 @@
                                                    '')
         if include_attributes_where:
             query_set = self._add_attribute_join(
-                query_set, join_condition=include_attributes_where)
+                query_set,
+                join_condition=self.escape_user_sql(include_attributes_where))
             joined = True
         if exclude_attributes_where:
             query_set = self._add_attribute_join(
-                query_set, join_condition=exclude_attributes_where,
+                query_set,
+                join_condition=self.escape_user_sql(exclude_attributes_where),
                 exclude=True)
             joined = True
 
         if not joined:
             filter_data['no_distinct'] = True
 
+        # TODO: retire test_attributes_host_labels in extra_where too
         if include_host_labels or 'test_attributes_host_labels' in extra_where:
             query_set = self._add_attribute_join(
                 query_set, suffix='_host_labels',