Various changes to support further high-level automation efforts.
* added a RESTful interface for TKO. Right now there's only a single, simple resource for accessing test attributes.
* extended the REST server library in a few ways, most notably to support:
  * querying on keyvals, with a query string like ?has_keyval=mykey=myvalue&...
  * operators, delimited by a colon, like ?hostname:in=host1,host2,host3
  * loading relationships over many items efficiently (see InstanceEntry.prepare_for_full_representation()). This is used to fill in keyvals when requesting a job listing, but it can (and should) be used in other places, such as listing labels for a host collection.
  * loading a collection with inlined full representations, by passing full_representations=true
* added various features to the AFE RESTful interface as necessary.
* various fixes to the rest_client library, most notably:
  * changed the HTTP client in rest_client.py to use dependency injection (DI) rather than a singleton, easing testability. To be honest, the same should be done for _get_request_headers().
  * better support for query params, including accepting a MultiValueDict and supporting URIs that already have query args
  * basic support for redirects
  * built-in support for requesting a full collection (get_full()), when clients explicitly expect the result not to be paged. I'm still considering alternative approaches to this -- it may make sense to make something like this the default, and have clients set a default page size limit rather than passing it every time.
* minor change to mock.py to provide better debugging output.
Signed-off-by: Steve Howard <showard@google.com>
git-svn-id: http://test.kernel.org/svn/autotest/trunk@4438 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/shared/query_lib.py b/frontend/shared/query_lib.py
index 63ce262..5515b60 100644
--- a/frontend/shared/query_lib.py
+++ b/frontend/shared/query_lib.py
@@ -22,7 +22,16 @@
def add_related_existence_selector(self, name, model, field, doc=None):
self.add_selector(
Selector(name, doc),
- _RelatedExistenceConstraint(model, field, self.make_alias))
+ _RelatedExistenceConstraint(model, field,
+ make_alias_fn=self.make_alias))
+
+
+ def add_keyval_selector(self, name, model, key_field, value_field,
+ doc=None):
+ self.add_selector(
+ Selector(name, doc),
+ _KeyvalConstraint(model, key_field, value_field,
+ make_alias_fn=self.make_alias))
def add_selector(self, selector, constraint):
@@ -46,7 +55,9 @@
def apply_selector(self, queryset, selector_name, value,
- comparison_type='equals', is_inverse=False):
+ comparison_type=None, is_inverse=False):
+ if comparison_type is None:
+ comparison_type = 'equals'
_, constraint = self._selectors[selector_name]
try:
return constraint.apply_constraint(queryset, value, comparison_type,
@@ -103,6 +114,8 @@
kwarg_name = str(self._field + '__' +
self._COMPARISON_MAP[comparison_type])
+ if comparison_type == 'in':
+ value = value.split(',')
if is_inverse:
return queryset.exclude(**{kwarg_name: value})
@@ -138,3 +151,38 @@
queryset = queryset.model.objects.add_where(queryset, condition)
return queryset
+
+
+class _KeyvalConstraint(Constraint):
+ def __init__(self, model, key_field, value_field, make_alias_fn):
+ self._model = model
+ self._key_field = key_field
+ self._value_field = value_field
+ self._make_alias_fn = make_alias_fn
+
+
+ def apply_constraint(self, queryset, value, comparison_type, is_inverse):
+ if comparison_type not in (None, 'equals'):
+ raise ConstraintError('Can only use equals or not equals with '
+ 'this selector')
+ if '=' not in value:
+ raise ConstraintError('You must specify a key=value pair for this '
+ 'selector')
+
+ key, actual_value = value.split('=', 1)
+ related_query = self._model.objects.filter(
+ **{self._key_field: key, self._value_field: actual_value})
+ alias = self._make_alias_fn()
+ queryset = queryset.model.objects.join_custom_field(queryset,
+ related_query,
+ alias)
+ if is_inverse:
+ condition = '%s.%s IS NULL'
+ else:
+ condition = '%s.%s IS NOT NULL'
+ condition %= (alias,
+ queryset.model.objects.key_on_joined_table(related_query))
+
+ queryset = queryset.model.objects.add_where(queryset, condition)
+
+ return queryset
diff --git a/frontend/shared/resource_lib.py b/frontend/shared/resource_lib.py
index 5d77410..fff8ff0 100644
--- a/frontend/shared/resource_lib.py
+++ b/frontend/shared/resource_lib.py
@@ -131,10 +131,12 @@
def read_query_parameters(self, parameters):
"""Read relevant query parameters from a Django MultiValueDict."""
- for param_name, _ in self._query_parameters_accepted():
- if param_name in parameters:
- self._query_params.setlist(param_name,
- parameters.getlist(param_name))
+ params_accepted = set(param_name for param_name, _
+ in self._query_parameters_accepted())
+ for name, values in parameters.iterlists():
+ base_name = name.split(':', 1)[0]
+ if base_name in params_accepted:
+ self._query_params.setlist(name, values)
def set_query_parameters(self, **parameters):
@@ -216,6 +218,9 @@
except ValueError, exc:
raise exceptions.BadRequest('Error decoding request body: '
'%s\n%r' % (exc, raw_data))
+ if not isinstance(raw_dict, dict):
+ raise exceptions.BadRequest('Expected dict input, got %s: %r' %
+ (type(raw_dict), raw_dict))
elif content_type == 'application/x-www-form-urlencoded':
cgi_dict = cgi.parse_qs(raw_data) # django won't do this for PUT
raw_dict = {}
@@ -319,6 +324,7 @@
assert self.model is not None
super(Entry, self).__init__(request)
self.instance = instance
+ self._is_prepared_for_full_representation = False
@classmethod
@@ -332,6 +338,41 @@
self.instance.delete()
+ def full_representation(self):
+ self.prepare_for_full_representation([self])
+ return super(InstanceEntry, self).full_representation()
+
+
+ @classmethod
+ def prepare_for_full_representation(cls, entries):
+ """
+ Prepare the given list of entries to generate full representations.
+
+ This method delegates to _do_prepare_for_full_representation(), which
+ subclasses may override as necessary to do the actual processing. This
+ method also marks the instance as prepared, so it's safe to call this
+ multiple times with the same instance(s) without wasting work.
+ """
+ not_prepared = [entry for entry in entries
+ if not entry._is_prepared_for_full_representation]
+ cls._do_prepare_for_full_representation([entry.instance
+ for entry in not_prepared])
+ for entry in not_prepared:
+ entry._is_prepared_for_full_representation = True
+
+
+ @classmethod
+ def _do_prepare_for_full_representation(cls, instances):
+ """
+ Subclasses may override this to gather data as needed for full
+ representations of the given model instances. Typically, this involves
+ querying over related objects, and this method offers a chance to query
+ for many instances at once, which can provide a great performance
+ benefit.
+ """
+ pass
+
+
class Collection(Resource):
_DEFAULT_ITEMS_PER_PAGE = 50
@@ -354,7 +395,9 @@
def _query_parameters_accepted(self):
params = [('start_index', 'Index of first member to include'),
- ('items_per_page', 'Number of members to include')]
+ ('items_per_page', 'Number of members to include'),
+ ('full_representations',
+ 'True to include full representations of members')]
for selector in self._query_processor.selectors():
params.append((selector.name, selector.doc))
return params
@@ -371,16 +414,33 @@
def _representation(self, entry_instances):
+ entries = [self._entry_from_instance(instance)
+ for instance in entry_instances]
+
+ want_full_representation = self._read_bool_parameter(
+ 'full_representations')
+ if want_full_representation:
+ self.entry_class.prepare_for_full_representation(entries)
+
members = []
- for instance in entry_instances:
- entry = self._entry_from_instance(instance)
- members.append(entry.short_representation())
+ for entry in entries:
+ if want_full_representation:
+ rep = entry.full_representation()
+ else:
+ rep = entry.short_representation()
+ members.append(rep)
rep = self.link()
rep.update({'members': members})
return rep
+ def _read_bool_parameter(self, name):
+ if name not in self._query_params:
+ return False
+ return (self._query_params[name].lower() == 'true')
+
+
def _read_int_parameter(self, name, default):
if name not in self._query_params:
return default
@@ -395,12 +455,17 @@
def _apply_form_query(self, queryset):
"""Apply any query selectors passed as form variables."""
for parameter, values in self._query_params.lists():
+ if ':' in parameter:
+ parameter, comparison_type = parameter.split(':', 1)
+ else:
+ comparison_type = None
+
if not self._query_processor.has_selector(parameter):
continue
for value in values: # forms keys can have multiple values
- queryset = self._query_processor.apply_selector(queryset,
- parameter,
- value)
+ queryset = self._query_processor.apply_selector(
+ queryset, parameter, value,
+ comparison_type=comparison_type)
return queryset
diff --git a/frontend/shared/resource_test_utils.py b/frontend/shared/resource_test_utils.py
new file mode 100644
index 0000000..8cb742f
--- /dev/null
+++ b/frontend/shared/resource_test_utils.py
@@ -0,0 +1,136 @@
+import operator, unittest
+import simplejson
+from django.test import client
+from autotest_lib.frontend.afe import frontend_test_utils, models as afe_models
+
+class ResourceTestCase(unittest.TestCase,
+ frontend_test_utils.FrontendTestMixin):
+ URI_PREFIX = None # subclasses may override this to use partial URIs
+
+ def setUp(self):
+ super(ResourceTestCase, self).setUp()
+ self._frontend_common_setup()
+ self._setup_debug_user()
+ self.client = client.Client()
+
+
+ def tearDown(self):
+ super(ResourceTestCase, self).tearDown()
+ self._frontend_common_teardown()
+
+
+ def _setup_debug_user(self):
+ user = afe_models.User.objects.create(login='debug_user')
+ acl = afe_models.AclGroup.objects.get(name='my_acl')
+ user.aclgroup_set.add(acl)
+
+
+ def _expected_status(self, method):
+ if method == 'post':
+ return 201
+ if method == 'delete':
+ return 204
+ return 200
+
+
+ def raw_request(self, method, uri, **kwargs):
+ method = method.lower()
+ if method == 'put':
+ # the put() implementation in Django's test client is poorly
+ # implemented and only supports url-encoded keyvals for the data.
+ # the post() implementation is correct, though, so use that, with a
+ # trick to override the method.
+ method = 'post'
+ kwargs['REQUEST_METHOD'] = 'PUT'
+
+ client_method = getattr(self.client, method)
+ return client_method(uri, **kwargs)
+
+
+ def request(self, method, uri, encode_body=True, **kwargs):
+ expected_status = self._expected_status(method)
+
+ if 'data' in kwargs:
+ kwargs.setdefault('content_type', 'application/json')
+ if encode_body and kwargs['content_type'] == 'application/json':
+ kwargs['data'] = simplejson.dumps(kwargs['data'])
+
+ if uri.startswith(('http://', 'https://')):
+ full_uri = uri
+ else:
+ assert self.URI_PREFIX
+ full_uri = self.URI_PREFIX + '/' + uri
+
+ response = self.raw_request(method, full_uri, **kwargs)
+ self.assertEquals(
+ response.status_code, expected_status,
+ 'Requesting %s\nExpected %s, got %s: %s (headers: %s)'
+ % (full_uri, expected_status, response.status_code,
+ response.content, response._headers))
+
+ if response['content-type'] != 'application/json':
+ return response.content
+
+ try:
+ return simplejson.loads(response.content)
+ except ValueError:
+ self.fail('Invalid response body: %s' % response.content)
+
+
+ def sorted_by(self, collection, attribute):
+ return sorted(collection, key=operator.itemgetter(attribute))
+
+
+ def _read_attribute(self, item, attribute_or_list):
+ if isinstance(attribute_or_list, basestring):
+ attribute_or_list = [attribute_or_list]
+ for attribute in attribute_or_list:
+ item = item[attribute]
+ return item
+
+
+ def check_collection(self, collection, attribute_or_list, expected_list,
+ length=None, check_number=None):
+ """Check the members of a collection of dicts.
+
+ @param collection: an iterable of dicts
+ @param attribute_or_list: an attribute or list of attributes to read.
+ the results will be sorted and compared with expected_list. if
+ a list of attributes is given, the attributes will be read
+ hierarchically, i.e. item[attribute1][attribute2]...
+ @param expected_list: list of expected values
+ @param check_number: if given, only check this number of entries
+ @param length: expected length of list, only necessary if check_number
+ is given
+ """
+ actual_list = sorted(self._read_attribute(item, attribute_or_list)
+ for item in collection['members'])
+ if length is None and check_number is None:
+ length = len(expected_list)
+ if length is not None:
+ self.assertEquals(len(actual_list), length,
+ 'Expected %s, got %s: %s'
+ % (length, len(actual_list),
+ ', '.join(str(item) for item in actual_list)))
+ if check_number:
+ actual_list = actual_list[:check_number]
+ self.assertEquals(actual_list, expected_list)
+
+
+ def check_relationship(self, resource_uri, relationship_name,
+ other_entry_name, field, expected_values,
+ length=None, check_number=None):
+ """Check the members of a relationship collection.
+
+ @param resource_uri: URI of base resource
+ @param relationship_name: name of relationship attribute on base
+ resource
+ @param other_entry_name: name of other entry in relationship
+ @param field: name of field to grab on other entry
+ @param expected values: list of expected values for the given field
+ """
+ response = self.request('get', resource_uri)
+ relationship_uri = response[relationship_name]['href']
+ relationships = self.request('get', relationship_uri)
+ self.check_collection(relationships, [other_entry_name, field],
+ expected_values, length, check_number)
diff --git a/frontend/shared/rest_client.py b/frontend/shared/rest_client.py
index b5b83af..9c4f5d5 100644
--- a/frontend/shared/rest_client.py
+++ b/frontend/shared/rest_client.py
@@ -1,11 +1,10 @@
-import logging, pprint, re, urllib, getpass, urlparse
+import copy, getpass, logging, pprint, re, urllib, urlparse
import httplib2
-from django.utils import simplejson
+from django.utils import datastructures, simplejson
from autotest_lib.frontend.afe import rpc_client_lib
from autotest_lib.client.common_lib import utils
-_http = httplib2.Http()
_request_headers = {}
@@ -59,7 +58,8 @@
class Resource(object):
- def __init__(self, representation_dict):
+ def __init__(self, representation_dict, http):
+ self._http = http
assert 'href' in representation_dict
for key, value in representation_dict.iteritems():
setattr(self, str(key), value)
@@ -75,8 +75,10 @@
@classmethod
- def load(cls, uri):
- directory = cls({'href': uri})
+ def load(cls, uri, http=None):
+ if not http:
+ http = httplib2.Http()
+ directory = cls({'href': uri}, http)
return directory.get()
@@ -88,7 +90,7 @@
converted_dict = dict((key, self._read_representation(sub_value))
for key, sub_value in value.iteritems())
if 'href' in converted_dict:
- return type(self)(converted_dict)
+ return type(self)(converted_dict, http=self._http)
return converted_dict
return value
@@ -113,11 +115,14 @@
def _do_request(self, method, uri, query_parameters, encoded_body):
+ uri_parts = [uri]
if query_parameters:
- query_string = '?' + urllib.urlencode(query_parameters)
- else:
- query_string = ''
- full_uri = uri + query_string
+ if '?' in uri:
+ uri_parts += '&'
+ else:
+ uri_parts += '?'
+ uri_parts += urllib.urlencode(query_parameters, doseq=True)
+ full_uri = ''.join(uri_parts)
if encoded_body:
entity_body = simplejson.dumps(encoded_body)
@@ -131,7 +136,7 @@
site_verify = utils.import_site_function(
__file__, 'autotest_lib.frontend.shared.site_rest_client',
'site_verify_response', _site_verify_response_default)
- headers, response_body = _http.request(
+ headers, response_body = self._http.request(
full_uri, method, body=entity_body,
headers=_get_request_headers(uri))
if not site_verify(headers, response_body):
@@ -155,7 +160,8 @@
encoded_body)
if 300 <= response.status < 400: # redirection
- raise NotImplementedError(str(response)) # TODO
+ return self._do_request(method, response.headers['location'],
+ query_parameters, encoded_body)
if 400 <= response.status < 500:
raise ClientError(str(response))
if 500 <= response.status < 600:
@@ -165,19 +171,59 @@
def _stringify_query_parameter(self, value):
if isinstance(value, (list, tuple)):
- return ','.join(value)
+ return ','.join(self._stringify_query_parameter(item)
+ for item in value)
return str(value)
- def get(self, **query_parameters):
- string_parameters = dict((key, self._stringify_query_parameter(value))
- for key, value in query_parameters.iteritems()
- if value is not None)
- response = self._request('GET', query_parameters=string_parameters)
+ def _iterlists(self, mapping):
+ """This effectively lets us treat dicts as MultiValueDicts."""
+ if hasattr(mapping, 'iterlists'): # mapping is already a MultiValueDict
+ return mapping.iterlists()
+ return ((key, (value,)) for key, value in mapping.iteritems())
+
+
+ def get(self, query_parameters=None, **kwarg_query_parameters):
+ """
+ @param query_parameters: a dict or MultiValueDict; None values are omitted
+ """
+ query_parameters = copy.copy(query_parameters or {}) # avoid mutating original
+ query_parameters.update(kwarg_query_parameters)
+
+ string_parameters = datastructures.MultiValueDict()
+ for key, values in self._iterlists(query_parameters):
+ # as the previous get() did, drop parameters whose value is None
+ string_values = [self._stringify_query_parameter(value)
+ for value in values if value is not None]
+ if string_values:
+ string_parameters.setlist(key, string_values)
+
+ response = self._request('GET',
+ query_parameters=string_parameters.lists())
assert response.status == 200
return self._read_representation(response.decoded_body())
+ def get_full(self, results_limit, query_parameters=None,
+ **kwarg_query_parameters):
+ """
+ Like get(), for collections expected to fit entirely in one page.
+
+ @param results_limit: maximum number of results to allow
+ @raises ClientError if there are more than results_limit results.
+ """
+ result = self.get(query_parameters=query_parameters,
+ items_per_page=results_limit,
+ **kwarg_query_parameters)
+ if result.total_results > results_limit:
+ raise ClientError(
+ 'Too many results (%s > %s) for request %s (%s %s)'
+ % (result.total_results, results_limit, self.href,
+ query_parameters, kwarg_query_parameters))
+ return result
+
+
+
def put(self):
response = self._request('PUT', encoded_body=self._representation())
assert response.status == 200