Various changes to support further high-level automation efforts.

* added a RESTful interface for TKO.  right now there's only a single, simple resource for accessing test attributes.
* extended the REST server library in a few ways, most notably to support
  * querying on keyvals, with something like ?has_keyval=mykey=myvalue&...
  * operators, delimited by a colon, like ?hostname:in=host1,host2,host3
  * loading relationships over many items efficiently (see InstanceEntry.prepare_for_full_representation()).  this is used to fill in keyvals when requesting a job listing, but it can (and should) be used in other places, such as listing labels for a host collection.
  * loading a collection with inlined full representations, by passing full_representations=true
* added various features to the AFE RESTful interface as necessary.
* various fixes to the rest_client library, most notably
  * changed the HTTP client in rest_client.py to use dependency injection (DI) rather than a singleton, easing testability.  the same should be done for _get_request_headers(), to be honest.
  * better support for query params, including accepting a MultiValueDict and supporting URIs that already have query args
  * basic support for redirects
  * built-in support for requesting a full collection (get_full()), when clients explicitly expect the result not to be paged.  i'm still considering alternative approaches to this -- it may make sense to have something like this be the default, and have clients set a default page size limit rather than passing it every time.
* minor change to mock.py to provide better debugging output.

Signed-off-by: Steve Howard <showard@google.com>



git-svn-id: http://test.kernel.org/svn/autotest/trunk@4438 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/shared/query_lib.py b/frontend/shared/query_lib.py
index 63ce262..5515b60 100644
--- a/frontend/shared/query_lib.py
+++ b/frontend/shared/query_lib.py
@@ -22,7 +22,26 @@
     def add_related_existence_selector(self, name, model, field, doc=None):
         self.add_selector(
                 Selector(name, doc),
-                _RelatedExistenceConstraint(model, field, self.make_alias))
+                _RelatedExistenceConstraint(model, field,
+                                            make_alias_fn=self.make_alias))
+
+
+    def add_keyval_selector(self, name, model, key_field, value_field,
+                            doc=None):
+        """Add a selector matching items that have a given keyval.
+
+        The selector's value must be of the form "key=value".
+
+        @param name: name of the selector.
+        @param model: model holding the keyval rows.
+        @param key_field: lookup on model compared against the key.
+        @param value_field: lookup on model compared against the value.
+        @param doc: optional documentation for the selector.
+        """
+        self.add_selector(
+                Selector(name, doc),
+                _KeyvalConstraint(model, key_field, value_field,
+                                  make_alias_fn=self.make_alias))
 
 
     def add_selector(self, selector, constraint):
@@ -46,7 +55,11 @@
 
 
     def apply_selector(self, queryset, selector_name, value,
-                       comparison_type='equals', is_inverse=False):
+                       comparison_type=None, is_inverse=False):
+        # None means no explicit operator was given (e.g. ?name=value rather
+        # than ?name:op=value); treat that as a plain equality comparison.
+        if comparison_type is None:
+            comparison_type = 'equals'
         _, constraint = self._selectors[selector_name]
         try:
             return constraint.apply_constraint(queryset, value, comparison_type,
@@ -103,6 +114,9 @@
 
         kwarg_name = str(self._field + '__' +
                          self._COMPARISON_MAP[comparison_type])
+        # 'in' takes a comma-separated list, e.g. ?hostname:in=host1,host2
+        if comparison_type == 'in':
+            value = value.split(',')
 
         if is_inverse:
             return queryset.exclude(**{kwarg_name: value})
@@ -138,3 +151,59 @@
         queryset = queryset.model.objects.add_where(queryset, condition)
 
         return queryset
+
+
+class _KeyvalConstraint(Constraint):
+    """Constrain items to those having a matching related key/value row.
+
+    Selector values take the form "key=value".  Matching rows of the keyval
+    model are joined onto the queryset under a fresh alias, and the result
+    is filtered with IS NULL / IS NOT NULL on the joined table's key.
+    """
+    def __init__(self, model, key_field, value_field, make_alias_fn):
+        # model: model holding the keyval rows
+        # key_field, value_field: lookups on model matched against the key
+        #         and value parsed from the selector value
+        # make_alias_fn: callable returning a fresh table alias for the join
+        self._model = model
+        self._key_field = key_field
+        self._value_field = value_field
+        self._make_alias_fn = make_alias_fn
+
+
+    def apply_constraint(self, queryset, value, comparison_type, is_inverse):
+        """Filter queryset on a "key=value" keyval.
+
+        Only equality is supported.  is_inverse switches the condition from
+        IS NOT NULL to IS NULL on the joined key.
+
+        @raises ConstraintError: if comparison_type is not equality, or if
+                value contains no '='.
+        """
+        if comparison_type not in (None, 'equals'):
+            raise ConstraintError('Can only use equals or not equals with '
+                                  'this selector')
+        if '=' not in value:
+            raise ConstraintError('You must specify a key=value pair for this '
+                                  'selector')
+
+        # split on the first '=' only, so the value itself may contain '='
+        key, actual_value = value.split('=', 1)
+        related_query = self._model.objects.filter(
+                **{self._key_field: key, self._value_field: actual_value})
+        alias = self._make_alias_fn()
+        queryset = queryset.model.objects.join_custom_field(queryset,
+                                                            related_query,
+                                                            alias)
+        # NOTE(review): relies on join_custom_field() doing an outer join, so
+        # that items without the keyval have NULL in the joined key column.
+        if is_inverse:
+            condition = '%s.%s IS NULL'
+        else:
+            condition = '%s.%s IS NOT NULL'
+        condition %= (alias,
+                      queryset.model.objects.key_on_joined_table(related_query))
+
+        queryset = queryset.model.objects.add_where(queryset, condition)
+
+        return queryset
diff --git a/frontend/shared/resource_lib.py b/frontend/shared/resource_lib.py
index 5d77410..fff8ff0 100644
--- a/frontend/shared/resource_lib.py
+++ b/frontend/shared/resource_lib.py
@@ -131,10 +131,14 @@
 
     def read_query_parameters(self, parameters):
         """Read relevant query parameters from a Django MultiValueDict."""
-        for param_name, _ in self._query_parameters_accepted():
-            if param_name in parameters:
-                self._query_params.setlist(param_name,
-                                           parameters.getlist(param_name))
+        params_accepted = set(param_name for param_name, _
+                              in self._query_parameters_accepted())
+        for name, values in parameters.iterlists():
+            # names may carry an operator suffix, e.g. "hostname:in"; only the
+            # part before the colon has to be an accepted parameter
+            base_name = name.split(':', 1)[0]
+            if base_name in params_accepted:
+                self._query_params.setlist(name, values)
 
 
     def set_query_parameters(self, **parameters):
@@ -216,6 +218,11 @@
             except ValueError, exc:
                 raise exceptions.BadRequest('Error decoding request body: '
                                             '%s\n%r' % (exc, raw_data))
+            # valid JSON that isn't an object (e.g. a list or a string) is
+            # rejected too, since the body must decode to a dict
+            if not isinstance(raw_dict, dict):
+                raise exceptions.BadRequest('Expected dict input, got %s: %r' %
+                                            (type(raw_dict), raw_dict))
         elif content_type == 'application/x-www-form-urlencoded':
             cgi_dict = cgi.parse_qs(raw_data) # django won't do this for PUT
             raw_dict = {}
@@ -319,6 +324,8 @@
         assert self.model is not None
         super(Entry, self).__init__(request)
         self.instance = instance
+        # set by prepare_for_full_representation() so each entry is prepared
+        self._is_prepared_for_full_representation = False
 
 
     @classmethod
@@ -332,6 +338,52 @@
         self.instance.delete()
 
 
+    def full_representation(self):
+        """Return the full representation, preparing this entry first.
+
+        Preparation is skipped if it was already done, e.g. in bulk by a
+        collection.
+        """
+        self.prepare_for_full_representation([self])
+        return super(InstanceEntry, self).full_representation()
+
+
+    @classmethod
+    def prepare_for_full_representation(cls, entries):
+        """
+        Prepare the given list of entries to generate full representations.
+
+        This method delegates to _do_prepare_for_full_representation(), which
+        subclasses may override as necessary to do the actual processing.  This
+        method also marks the entries as prepared, so it's safe to call this
+        multiple times with the same entries without wasting work: entries
+        already prepared are skipped, and if none remain, the subclass hook
+        is not called at all.
+        """
+        not_prepared = [entry for entry in entries
+                        if not entry._is_prepared_for_full_representation]
+        if not not_prepared:
+            # don't invoke the hook (which typically queries related objects)
+            # with an empty list
+            return
+        cls._do_prepare_for_full_representation([entry.instance
+                                                 for entry in not_prepared])
+        for entry in not_prepared:
+            entry._is_prepared_for_full_representation = True
+
+
+    @classmethod
+    def _do_prepare_for_full_representation(cls, instances):
+        """
+        Subclasses may override this to gather data as needed for full
+        representations of the given model instances.  Typically, this involves
+        querying over related objects, and this method offers a chance to query
+        for many instances at once, which can provide a great performance
+        benefit.
+        """
+        pass
+
+
 class Collection(Resource):
     _DEFAULT_ITEMS_PER_PAGE = 50
 
@@ -354,7 +395,10 @@
 
     def _query_parameters_accepted(self):
         params = [('start_index', 'Index of first member to include'),
-                  ('items_per_page', 'Number of members to include')]
+                  ('items_per_page', 'Number of members to include'),
+                  # read as a boolean by _read_bool_parameter()
+                  ('full_representations',
+                   'True to include full representations of members')]
         for selector in self._query_processor.selectors():
             params.append((selector.name, selector.doc))
         return params
@@ -371,16 +414,40 @@
 
 
     def _representation(self, entry_instances):
+        # build all entries up front so full representations can be prepared
+        # in bulk (a single prepare call covering every entry)
+        entries = [self._entry_from_instance(instance)
+                   for instance in entry_instances]
+
+        want_full_representation = self._read_bool_parameter(
+                'full_representations')
+        if want_full_representation:
+            self.entry_class.prepare_for_full_representation(entries)
+
         members = []
-        for instance in entry_instances:
-            entry = self._entry_from_instance(instance)
-            members.append(entry.short_representation())
+        for entry in entries:
+            if want_full_representation:
+                member_rep = entry.full_representation()
+            else:
+                member_rep = entry.short_representation()
+            members.append(member_rep)
 
         rep = self.link()
         rep.update({'members': members})
         return rep
 
 
+    def _read_bool_parameter(self, name):
+        """Return True iff query parameter name has the value "true".
+
+        The comparison is case-insensitive; a missing parameter, or any other
+        value, reads as False.
+        """
+        if name not in self._query_params:
+            return False
+        return self._query_params[name].lower() == 'true'
+
+
     def _read_int_parameter(self, name, default):
         if name not in self._query_params:
             return default
@@ -395,12 +455,19 @@
     def _apply_form_query(self, queryset):
         """Apply any query selectors passed as form variables."""
         for parameter, values in self._query_params.lists():
+            # a "name:operator" parameter (e.g. hostname:in) selects a
+            # comparison type; a bare name uses the default comparison
+            if ':' in parameter:
+                parameter, comparison_type = parameter.split(':', 1)
+            else:
+                comparison_type = None
+
             if not self._query_processor.has_selector(parameter):
                 continue
             for value in values: # forms keys can have multiple values
-                queryset = self._query_processor.apply_selector(queryset,
-                                                                parameter,
-                                                                value)
+                queryset = self._query_processor.apply_selector(
+                        queryset, parameter, value,
+                        comparison_type=comparison_type)
         return queryset
 
 
diff --git a/frontend/shared/resource_test_utils.py b/frontend/shared/resource_test_utils.py
new file mode 100644
index 0000000..8cb742f
--- /dev/null
+++ b/frontend/shared/resource_test_utils.py
@@ -0,0 +1,164 @@
+import operator, unittest
+import simplejson
+from django.test import client
+from autotest_lib.frontend.afe import frontend_test_utils, models as afe_models
+
+
+class ResourceTestCase(unittest.TestCase,
+                       frontend_test_utils.FrontendTestMixin):
+    """Base class for REST resource tests using Django's test client.
+
+    Each test runs with the common frontend fixtures plus a 'debug_user'
+    that belongs to the 'my_acl' ACL group.
+    """
+    URI_PREFIX = None # subclasses may override this to use partial URIs
+
+    def setUp(self):
+        super(ResourceTestCase, self).setUp()
+        self._frontend_common_setup()
+        self._setup_debug_user()
+        self.client = client.Client()
+
+
+    def tearDown(self):
+        super(ResourceTestCase, self).tearDown()
+        self._frontend_common_teardown()
+
+
+    def _setup_debug_user(self):
+        user = afe_models.User.objects.create(login='debug_user')
+        acl = afe_models.AclGroup.objects.get(name='my_acl')
+        user.aclgroup_set.add(acl)
+
+
+    def _expected_status(self, method):
+        """Return the success status code expected for an HTTP method."""
+        # lowercase first, as raw_request() does, so e.g. 'POST' maps to 201
+        method = method.lower()
+        if method == 'post':
+            return 201
+        if method == 'delete':
+            return 204
+        return 200
+
+
+    def raw_request(self, method, uri, **kwargs):
+        """Make a request with the test client and return the raw response.
+
+        Unlike request(), this does no status checking or body decoding.
+        """
+        method = method.lower()
+        if method == 'put':
+            # the put() implementation in Django's test client is poorly
+            # implemented and only supports url-encoded keyvals for the data.
+            # the post() implementation is correct, though, so use that, with a
+            # trick to override the method.
+            method = 'post'
+            kwargs['REQUEST_METHOD'] = 'PUT'
+
+        client_method = getattr(self.client, method)
+        return client_method(uri, **kwargs)
+
+
+    def request(self, method, uri, encode_body=True, **kwargs):
+        """Make a request, check its status and decode the response.
+
+        @param method: HTTP method name, e.g. 'get' or 'post' (any case).
+        @param uri: a full 'http://' URI, or one relative to URI_PREFIX.
+        @param encode_body: if true, 'data' sent as application/json (the
+                default content type) is JSON-encoded first; pass False if
+                the data is already encoded.
+        @param kwargs: passed through to the Django test client.
+        @returns the decoded body for JSON responses, else the raw content.
+        """
+        expected_status = self._expected_status(method)
+
+        if 'data' in kwargs:
+            kwargs.setdefault('content_type', 'application/json')
+            if encode_body and kwargs['content_type'] == 'application/json':
+                kwargs['data'] = simplejson.dumps(kwargs['data'])
+
+        if uri.startswith('http://'):
+            full_uri = uri
+        else:
+            assert self.URI_PREFIX
+            full_uri = self.URI_PREFIX + '/' + uri
+
+        response = self.raw_request(method, full_uri, **kwargs)
+        self.assertEquals(
+                response.status_code, expected_status,
+                'Requesting %s\nExpected %s, got %s: %s (headers: %s)'
+                % (full_uri, expected_status, response.status_code,
+                   response.content, response._headers))
+
+        if response['content-type'] != 'application/json':
+            return response.content
+
+        try:
+            return simplejson.loads(response.content)
+        except ValueError:
+            self.fail('Invalid response body: %s' % response.content)
+
+
+    def sorted_by(self, collection, attribute):
+        """Return the items of collection sorted by item[attribute]."""
+        return sorted(collection, key=operator.itemgetter(attribute))
+
+
+    def _read_attribute(self, item, attribute_or_list):
+        """Return item[a1][a2]... for one attribute name or a list of them."""
+        if isinstance(attribute_or_list, basestring):
+            attribute_or_list = [attribute_or_list]
+        for attribute in attribute_or_list:
+            item = item[attribute]
+        return item
+
+
+    def check_collection(self, collection, attribute_or_list, expected_list,
+                         length=None, check_number=None):
+        """Check the members of a collection representation.
+
+        @param collection: a collection representation; its 'members' (an
+                iterable of dicts) are checked
+        @param attribute_or_list: an attribute or list of attributes to read.
+                the results will be sorted and compared with expected_list. if
+                a list of attributes is given, the attributes will be read
+                hierarchically, i.e. item[attribute1][attribute2]...
+        @param expected_list: list of expected values
+        @param length: expected length of list, only necessary if check_number
+                is given
+        @param check_number: if given, only check this number of entries
+                (the first ones after sorting)
+        """
+        actual_list = sorted(self._read_attribute(item, attribute_or_list)
+                             for item in collection['members'])
+        if length is None and check_number is None:
+            length = len(expected_list)
+        if length is not None:
+            self.assertEquals(len(actual_list), length,
+                              'Expected %s, got %s: %s'
+                              % (length, len(actual_list),
+                                 ', '.join(str(item) for item in actual_list)))
+        if check_number:
+            actual_list = actual_list[:check_number]
+        self.assertEquals(actual_list, expected_list)
+
+
+    def check_relationship(self, resource_uri, relationship_name,
+                           other_entry_name, field, expected_values,
+                           length=None, check_number=None):
+        """Check the members of a relationship collection.
+
+        @param resource_uri: URI of base resource
+        @param relationship_name: name of relationship attribute on base
+                resource
+        @param other_entry_name: name of other entry in relationship
+        @param field: name of field to grab on other entry
+        @param expected_values: list of expected values for the given field
+        @param length, check_number: as for check_collection()
+        """
+        response = self.request('get', resource_uri)
+        relationship_uri = response[relationship_name]['href']
+        relationships = self.request('get', relationship_uri)
+        self.check_collection(relationships, [other_entry_name, field],
+                              expected_values, length, check_number)
diff --git a/frontend/shared/rest_client.py b/frontend/shared/rest_client.py
index b5b83af..9c4f5d5 100644
--- a/frontend/shared/rest_client.py
+++ b/frontend/shared/rest_client.py
@@ -1,11 +1,10 @@
-import logging, pprint, re, urllib, getpass, urlparse
+import copy, getpass, logging, pprint, re, urllib, urlparse
 import httplib2
-from django.utils import simplejson
+from django.utils import datastructures, simplejson
 from autotest_lib.frontend.afe import rpc_client_lib
 from autotest_lib.client.common_lib import utils
 
 
-_http = httplib2.Http()
 _request_headers = {}
 
 
@@ -59,7 +58,10 @@
 
 
 class Resource(object):
-    def __init__(self, representation_dict):
+    def __init__(self, representation_dict, http):
+        # http: httplib2.Http-compatible client used for all requests; passed
+        # in (rather than a module-level singleton) to ease testing
+        self._http = http
         assert 'href' in representation_dict
         for key, value in representation_dict.iteritems():
             setattr(self, str(key), value)
@@ -75,8 +75,11 @@
 
 
     @classmethod
-    def load(cls, uri):
-        directory = cls({'href': uri})
+    def load(cls, uri, http=None):
+        # default to a fresh client; tests may inject their own
+        if http is None:
+            http = httplib2.Http()
+        directory = cls({'href': uri}, http)
         return directory.get()
 
 
@@ -88,7 +90,8 @@
             converted_dict = dict((key, self._read_representation(sub_value))
                                   for key, sub_value in value.iteritems())
             if 'href' in converted_dict:
-                return type(self)(converted_dict)
+                # linked resources share this resource's HTTP client
+                return type(self)(converted_dict, http=self._http)
             return converted_dict
         return value
 
@@ -113,11 +115,15 @@
 
 
     def _do_request(self, method, uri, query_parameters, encoded_body):
+        # append the query string, using '&' if uri already has a query part
+        uri_parts = [uri]
         if query_parameters:
-            query_string = '?' + urllib.urlencode(query_parameters)
-        else:
-            query_string = ''
-        full_uri = uri + query_string
+            if '?' in uri:
+                uri_parts.append('&')
+            else:
+                uri_parts.append('?')
+            uri_parts.append(urllib.urlencode(query_parameters, doseq=True))
+        full_uri = ''.join(uri_parts)
 
         if encoded_body:
             entity_body = simplejson.dumps(encoded_body)
@@ -131,7 +136,8 @@
         site_verify = utils.import_site_function(
                 __file__, 'autotest_lib.frontend.shared.site_rest_client',
                 'site_verify_response', _site_verify_response_default)
-        headers, response_body = _http.request(
+        # use the injected client (see Resource.__init__) instead of a global
+        headers, response_body = self._http.request(
                 full_uri, method, body=entity_body,
                 headers=_get_request_headers(uri))
         if not site_verify(headers, response_body):
@@ -155,7 +160,12 @@
                                     encoded_body)
 
         if 300 <= response.status < 400: # redirection
-            raise NotImplementedError(str(response)) # TODO
+            # NOTE(review): the redirected response is returned directly, so
+            # it bypasses the 4xx/5xx checks below, query_parameters are
+            # appended to the Location URI again (possibly duplicating them),
+            # and redirect chains are not bounded.
+            return self._do_request(method, response.headers['location'],
+                                    query_parameters, encoded_body)
         if 400 <= response.status < 500:
             raise ClientError(str(response))
         if 500 <= response.status < 600:
@@ -165,19 +171,68 @@
 
     def _stringify_query_parameter(self, value):
         if isinstance(value, (list, tuple)):
-            return ','.join(value)
+            return ','.join(self._stringify_query_parameter(item)
+                            for item in value)
         return str(value)
 
 
-    def get(self, **query_parameters):
-        string_parameters = dict((key, self._stringify_query_parameter(value))
-                                 for key, value in query_parameters.iteritems()
-                                 if value is not None)
-        response = self._request('GET', query_parameters=string_parameters)
+    def _iterlists(self, mapping):
+        """This effectively lets us treat dicts as MultiValueDicts."""
+        if hasattr(mapping, 'iterlists'): # mapping is already a MultiValueDict
+            return mapping.iterlists()
+        return ((key, (value,)) for key, value in mapping.iteritems())
+
+
+    def get(self, query_parameters=None, **kwarg_query_parameters):
+        """
+        GET this resource and return its representation.
+
+        @param query_parameters: a dict or MultiValueDict
+        @param kwarg_query_parameters: further query parameters, merged with
+                query_parameters (the caller's object is not modified)
+
+        None values are dropped, so passing None means "not given".  Lists
+        and tuples are sent as comma-separated strings.
+        """
+        if query_parameters is None:
+            query_parameters = {}
+        else:
+            query_parameters = copy.copy(query_parameters) # avoid mutating
+        query_parameters.update(kwarg_query_parameters)
+
+        string_parameters = datastructures.MultiValueDict()
+        for key, values in self._iterlists(query_parameters):
+            present_values = [value for value in values if value is not None]
+            if present_values:
+                string_parameters.setlist(
+                        key, [self._stringify_query_parameter(value)
+                              for value in present_values])
+
+        response = self._request('GET',
+                                 query_parameters=string_parameters.lists())
         assert response.status == 200
         return self._read_representation(response.decoded_body())
 
 
+    def get_full(self, results_limit, query_parameters=None,
+                 **kwarg_query_parameters):
+        """
+        Like get() for collections, when the full collection is expected.
+
+        @param results_limit: maximum number of results to allow
+        @raises ClientError if there are more than results_limit results.
+        """
+        result = self.get(query_parameters=query_parameters,
+                          items_per_page=results_limit,
+                          **kwarg_query_parameters)
+        if result.total_results > results_limit:
+            raise ClientError(
+                    'Too many results (%s > %s) for request %s (%s %s)'
+                    % (result.total_results, results_limit, self.href,
+                       query_parameters, kwarg_query_parameters))
+        return result
+
+
     def put(self):
         response = self._request('PUT', encoded_body=self._representation())
         assert response.status == 200