Add an exclude_atomic_group_hosts option to the get_hosts and get_num_hosts RPCs so that hosts carrying one or more atomic group labels can be excluded from the results.

Signed-off-by: Gregory Smith <gps@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@3581 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/afe/rpc_interface.py b/frontend/afe/rpc_interface.py
index dc13ea0..bfd0810 100644
--- a/frontend/afe/rpc_interface.py
+++ b/frontend/afe/rpc_interface.py
@@ -165,16 +165,19 @@
     models.Host.smart_get(id).delete()
 
 
-def get_hosts(multiple_labels=[], exclude_only_if_needed_labels=False,
-              **filter_data):
-    """\
-    multiple_labels: match hosts in all of the labels given.  Should be a
-    list of label names.
-    exclude_only_if_needed_labels: exclude hosts with at least one
-    "only_if_needed" label applied.
+def get_hosts(multiple_labels=(), exclude_only_if_needed_labels=False,
+              exclude_atomic_group_hosts=False, **filter_data):
+    """
+    @param multiple_labels: match hosts in all of the labels given.  Should
+            be a list of label names.
+    @param exclude_only_if_needed_labels: Exclude hosts with at least one
+            "only_if_needed" label applied.
+    @param exclude_atomic_group_hosts: Exclude hosts that have one or more
+            atomic group labels associated with them.
     """
     hosts = rpc_utils.get_host_query(multiple_labels,
                                      exclude_only_if_needed_labels,
+                                     exclude_atomic_group_hosts,
                                      filter_data)
     hosts = list(hosts)
     models.Host.objects.populate_relationships(hosts, models.Label,
@@ -196,10 +199,16 @@
     return rpc_utils.prepare_for_serialization(host_dicts)
 
 
-def get_num_hosts(multiple_labels=[], exclude_only_if_needed_labels=False,
-                  **filter_data):
+def get_num_hosts(multiple_labels=(), exclude_only_if_needed_labels=False,
+                  exclude_atomic_group_hosts=False, **filter_data):
+    """
+    Same parameters as get_hosts().
+
+    @returns The number of matching hosts.
+    """
     hosts = rpc_utils.get_host_query(multiple_labels,
                                      exclude_only_if_needed_labels,
+                                     exclude_atomic_group_hosts,
                                      filter_data)
     return hosts.count()
 
diff --git a/frontend/afe/rpc_interface_unittest.py b/frontend/afe/rpc_interface_unittest.py
index b298029..c87bad6 100644
--- a/frontend/afe/rpc_interface_unittest.py
+++ b/frontend/afe/rpc_interface_unittest.py
@@ -76,6 +76,23 @@
         self._check_hostnames(hosts, ['host2'])
 
 
+    def test_get_hosts_exclude_atomic_group_hosts(self):
+        hosts = rpc_interface.get_hosts(
+                exclude_atomic_group_hosts=True,
+                hostname__in=['host4', 'host5', 'host6'])
+        self._check_hostnames(hosts, ['host4'])
+
+
+    def test_get_hosts_exclude_both(self):
+        self.hosts[0].labels.add(self.label3)
+
+        hosts = rpc_interface.get_hosts(
+                hostname__in=['host1', 'host2', 'host5'],
+                exclude_only_if_needed_labels=True,
+                exclude_atomic_group_hosts=True)
+        self._check_hostnames(hosts, ['host2'])
+
+
     def test_get_jobs_summary(self):
         job = self._create_job(hosts=xrange(1, 4))
         entries = list(job.hostqueueentry_set.all())
diff --git a/frontend/afe/rpc_utils.py b/frontend/afe/rpc_utils.py
index 7c58ec5..2b0fa83 100644
--- a/frontend/afe/rpc_utils.py
+++ b/frontend/afe/rpc_utils.py
@@ -112,7 +112,7 @@
     return {'where': where}
 
 
-def extra_host_filters(multiple_labels=[]):
+def extra_host_filters(multiple_labels=()):
     """\
     Generate SQL WHERE clauses for matching hosts in an intersection of
     labels.
@@ -126,19 +126,33 @@
     return extra_args
 
 
-def get_host_query(multiple_labels, exclude_only_if_needed_labels, filter_data):
+def get_host_query(multiple_labels, exclude_only_if_needed_labels,
+                   exclude_atomic_group_hosts, filter_data):
     query = models.Host.valid_objects.all()
     if exclude_only_if_needed_labels:
         only_if_needed_labels = models.Label.valid_objects.filter(
             only_if_needed=True)
         if only_if_needed_labels.count() > 0:
-            only_if_needed_ids = ','.join(str(label['id']) for label
-                                          in only_if_needed_labels.values('id'))
+            only_if_needed_ids = ','.join(
+                    str(label['id'])
+                    for label in only_if_needed_labels.values('id'))
             query = models.Host.objects.add_join(
                 query, 'hosts_labels', join_key='host_id',
-                join_condition='hosts_labels_exclude.label_id IN (%s)'
-                               % only_if_needed_ids,
-                suffix='_exclude', exclude=True)
+                join_condition=('hosts_labels_exclude_OIN.label_id IN (%s)'
+                                % only_if_needed_ids),
+                suffix='_exclude_OIN', exclude=True)
+    if exclude_atomic_group_hosts:
+        atomic_group_labels = models.Label.valid_objects.filter(
+                atomic_group__isnull=False)
+        if atomic_group_labels.count() > 0:
+            atomic_group_label_ids = ','.join(
+                    str(atomic_group['id'])
+                    for atomic_group in atomic_group_labels.values('id'))
+            query = models.Host.objects.add_join(
+                    query, 'hosts_labels', join_key='host_id',
+                    join_condition=('hosts_labels_exclude_AG.label_id IN (%s)'
+                                    % atomic_group_label_ids),
+                    suffix='_exclude_AG', exclude=True)
     filter_data['extra_args'] = (extra_host_filters(multiple_labels))
     return models.Host.query_objects(filter_data, initial_query=query)