New, more consistent way of representing relationships in the RESTful interface.  A relationship between two entries is represented as a resource in its own right, contained in a collection.  For example,

resources/labelings?host=myhost

retrieves all labelings for host myhost, each with a URI like

resources/labelings/myhost,mylabel

which represents that the label mylabel is applied to the host myhost.  The user can POST to the former (collection) URI to relate two entries and DELETE the latter (entry) URI to break the relationship.  This makes relationships much more consistent with the rest of the interface.

This change also includes various other cleanup refactorings that came up as I worked through it.

Signed-off-by: Steve Howard <showard@google.com>


git-svn-id: http://test.kernel.org/svn/autotest/trunk@4246 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/shared/resource_lib.py b/frontend/shared/resource_lib.py
index a26acd4..5d77410 100644
--- a/frontend/shared/resource_lib.py
+++ b/frontend/shared/resource_lib.py
@@ -2,7 +2,8 @@
 from django import http
 import django.core.exceptions
 from django.core import urlresolvers
-from django.utils import simplejson
+from django.utils import datastructures
+import simplejson
 from autotest_lib.frontend.shared import exceptions, query_lib
 from autotest_lib.frontend.afe import model_logic
 
@@ -35,17 +36,28 @@
 
     def __init__(self, request):
         assert self._permitted_methods
+        # this request should be used for global environment info, like
+        # constructing absolute URIs.  it should not be used for query
+        # parameters, because the request may not have been for this particular
+        # resource.
         self._request = request
+        # this dict will contain the applicable query parameters
+        self._query_params = datastructures.MultiValueDict()
 
 
     @classmethod
     def dispatch_request(cls, request, *args, **kwargs):
         # handle a request directly
         try:
-            instance = cls.from_uri_args(request, *args, **kwargs)
-        except django.core.exceptions.ObjectDoesNotExist, exc:
-            raise http.Http404(exc)
-        return instance.handle_request()
+            try:
+                instance = cls.from_uri_args(request, **kwargs)
+            except django.core.exceptions.ObjectDoesNotExist, exc:
+                raise http.Http404(exc)
+
+            instance.read_query_parameters(request.GET)
+            return instance.handle_request()
+        except exceptions.RequestError, exc:
+            return exc.response
 
 
     def handle_request(self):
@@ -53,10 +65,7 @@
             return http.HttpResponseNotAllowed(self._permitted_methods)
 
         handler = getattr(self, self._request.method.lower())
-        try:
-            return handler()
-        except exceptions.RequestError, exc:
-            return exc.response
+        return handler()
 
 
     # the handler methods below only need to be overridden if the resource
@@ -95,7 +104,7 @@
 
 
     @classmethod
-    def from_uri_args(cls, request):
+    def from_uri_args(cls, request, **kwargs):
         """Construct an instance from URI args.
 
         Default implementation for resources with no URI args.
@@ -104,14 +113,14 @@
 
 
     def _uri_args(self):
-        """Return (args, kwargs) for a URI reference to this resource.
+        """Return kwargs for a URI reference to this resource.
 
         Default implementation for resources with no URI args.
         """
-        return (), {}
+        return {}
 
 
-    def _query_parameters(self):
+    def _query_parameters_accepted(self):
         """Return sequence of tuples (name, description) for query parameters.
 
         Documents the available query parameters for GETting this resource.
@@ -120,11 +129,29 @@
         return ()
 
 
-    def href(self):
+    def read_query_parameters(self, parameters):
+        """Read relevant query parameters from a Django MultiValueDict."""
+        for param_name, _ in self._query_parameters_accepted():
+            if param_name in parameters:
+                self._query_params.setlist(param_name,
+                                           parameters.getlist(param_name))
+
+
+    def set_query_parameters(self, **parameters):
+        """Set query parameters programmatically."""
+        self._query_params.update(parameters)
+
+
+    def href(self, query_params=None):
         """Return URI to this resource."""
-        args, kwargs = self._uri_args()
-        path = urlresolvers.reverse(self.dispatch_request, args=args,
-                                    kwargs=kwargs)
+        kwargs = self._uri_args()
+        path = urlresolvers.reverse(self.dispatch_request, kwargs=kwargs)
+        full_query_params = datastructures.MultiValueDict(self._query_params)
+        if query_params:
+            full_query_params.update(query_params)
+        if full_query_params:
+            path += '?' + urllib.urlencode(full_query_params.lists(),
+                                           doseq=True)
         return self._request.build_absolute_uri(path)
 
 
@@ -141,9 +168,12 @@
                                             % uri)
             uri = match.group('path')
 
-        view_method, args, kwargs = urlresolvers.resolve(uri)
+        try:
+            view_method, args, kwargs = urlresolvers.resolve(uri)
+        except http.Http404:
+            raise exceptions.BadRequest('Unable to resolve URI %s' % uri)
         resource_class = view_method.im_self # class owning this classmethod
-        return resource_class.from_uri_args(self._request, *args, **kwargs)
+        return resource_class.from_uri_args(self._request, **kwargs)
 
 
     def resolve_link(self, link):
@@ -156,13 +186,13 @@
         return self.resolve_uri(uri)
 
 
-    def link(self):
-        return {'href': self.href()}
+    def link(self, query_params=None):
+        return {'href': self.href(query_params=query_params)}
 
 
     def _query_parameters_response(self):
         return dict((name, description)
-                    for name, description in self._query_parameters())
+                    for name, description in self._query_parameters_accepted())
 
 
     def _basic_response(self, content):
@@ -227,34 +257,10 @@
 
 
 class Entry(Resource):
-    class NullEntry(object):
-        def link(self):
-            return None
-
-
-        def short_representation(self):
-            return None
-
-    _null_entry = NullEntry()
-
-
-    _permitted_methods = ('GET', 'PUT', 'DELETE')
-
-
-    # sublcasses must define this class to support querying
-    QueryProcessor = query_lib.BaseQueryProcessor
-
-
-    def __init__(self, request, instance):
-        super(Entry, self).__init__(request)
-        self.instance = instance
-
-
     @classmethod
-    def from_optional_instance(cls, request, instance):
-        if instance is None:
-            return cls._null_entry
-        return cls(request, instance)
+    def add_query_selectors(cls, query_processor):
+        """Subclasses may override this to support querying."""
+        pass
 
 
     def short_representation(self):
@@ -277,8 +283,12 @@
         return self._basic_response(self.full_representation())
 
 
+    def _delete_entry(self):
+        raise NotImplementedError
+
+
     def delete(self):
-        self.instance.delete()
+        self._delete_entry()
         return http.HttpResponse(status=204) # No content
 
 
@@ -290,6 +300,38 @@
         raise NotImplementedError
 
 
+class InstanceEntry(Entry):
+    class NullEntry(object):
+        def link(self):
+            return None
+
+
+        def short_representation(self):
+            return None
+
+
+    _null_entry = NullEntry()
+    _permitted_methods = ('GET', 'PUT', 'DELETE')
+    model = None # subclasses must override this with a Django model class
+
+
+    def __init__(self, request, instance):
+        assert self.model is not None
+        super(Entry, self).__init__(request)
+        self.instance = instance
+
+
+    @classmethod
+    def from_optional_instance(cls, request, instance):
+        if instance is None:
+            return cls._null_entry
+        return cls(request, instance)
+
+
+    def _delete_entry(self):
+        self.instance.delete()
+
+
 class Collection(Resource):
     _DEFAULT_ITEMS_PER_PAGE = 50
 
@@ -306,7 +348,16 @@
         if isinstance(self.entry_class, basestring):
             type(self).entry_class = _resolve_class_path(self.entry_class)
 
-        self._query_processor = self.entry_class.QueryProcessor()
+        self._query_processor = query_lib.QueryProcessor()
+        self.entry_class.add_query_selectors(self._query_processor)
+
+
+    def _query_parameters_accepted(self):
+        params = [('start_index', 'Index of first member to include'),
+                  ('items_per_page', 'Number of members to include')]
+        for selector in self._query_processor.selectors():
+            params.append((selector.name, selector.doc))
+        return params
 
 
     def _fresh_queryset(self):
@@ -315,10 +366,14 @@
         return self.queryset.all()
 
 
+    def _entry_from_instance(self, instance):
+        return self.entry_class(self._request, instance)
+
+
     def _representation(self, entry_instances):
         members = []
         for instance in entry_instances:
-            entry = self.entry_class(self._request, instance)
+            entry = self._entry_from_instance(instance)
             members.append(entry.short_representation())
 
         rep = self.link()
@@ -327,10 +382,9 @@
 
 
     def _read_int_parameter(self, name, default):
-        query_dict = self._request.GET
-        if name not in query_dict:
+        if name not in self._query_params:
             return default
-        input_value = query_dict[name]
+        input_value = self._query_params[name]
         try:
             return int(input_value)
         except ValueError:
@@ -340,7 +394,7 @@
 
     def _apply_form_query(self, queryset):
         """Apply any query selectors passed as form variables."""
-        for parameter, values in self._request.GET.lists():
+        for parameter, values in self._query_params.lists():
             if not self._query_processor.has_selector(parameter):
                 continue
             for value in values: # forms keys can have multiple values
@@ -363,13 +417,9 @@
         page = queryset[start_index:(start_index + items_per_page)]
 
         rep = self._representation(page)
-        selector_dict = dict((selector.name, selector.doc)
-                             for selector
-                             in self.entry_class.QueryProcessor.selectors())
         rep.update({'total_results': len(queryset),
                     'start_index': start_index,
-                    'items_per_page': items_per_page,
-                    'filtering_selectors': selector_dict})
+                    'items_per_page': items_per_page})
         return self._basic_response(rep)
 
 
@@ -382,7 +432,7 @@
         input_dict = self._decoded_input()
         try:
             instance = self.entry_class.create_instance(input_dict, self)
-            entry = self.entry_class(self._request, instance)
+            entry = self._entry_from_instance(instance)
             entry.update(input_dict)
         except model_logic.ValidationError, exc:
             raise exceptions.BadRequest('Invalid input: %s' % exc)
@@ -394,76 +444,172 @@
         return response
 
 
-class Relationship(Collection):
-    _permitted_methods=('GET', 'PUT')
+class Relationship(Entry):
+    _permitted_methods = ('GET', 'DELETE')
 
-    base_entry_class = None # subclasses must override this
+    # subclasses must override this with a dict mapping name to entry class
+    related_classes = None
 
 
-    def __init__(self, base_entry):
-        assert self.base_entry_class
-        if isinstance(self.base_entry_class, basestring):
-            type(self).base_entry_class = _resolve_class_path(
-                    self.base_entry_class)
-        assert isinstance(base_entry, self.base_entry_class)
-        self.base_entry = base_entry
-        super(Relationship, self).__init__(base_entry._request)
+    def __init__(self, **kwargs):
+        assert len(self.related_classes) == 2
+        self.entries = dict((name, kwargs[name])
+                            for name in self.related_classes)
+        for name in self.related_classes: # sanity check
+            assert isinstance(self.entries[name], self.related_classes[name])
 
-
-    def _fresh_queryset(self):
-        """Return a QuerySet for this relationship using self.base_entry."""
-        raise NotImplementedError
+        # just grab the request from one of the entries
+        some_entry = self.entries.itervalues().next()
+        super(Relationship, self).__init__(some_entry._request)
 
 
     @classmethod
-    def from_uri_args(cls, request, *args, **kwargs):
-        base_entry = cls.base_entry_class.from_uri_args(request, *args,
-                                                        **kwargs)
-        return cls(base_entry)
+    def from_uri_args(cls, request, **kwargs):
+        # kwargs contains URI args for each entry
+        entries = {}
+        for name, entry_class in cls.related_classes.iteritems():
+            entries[name] = entry_class.from_uri_args(request, **kwargs)
+        return cls(**entries)
 
 
     def _uri_args(self):
-        return self.base_entry._uri_args()
+        kwargs = {}
+        for name, entry in self.entries.iteritems():
+            kwargs.update(entry._uri_args())
+        return kwargs
+
+
+    def short_representation(self):
+        rep = self.link()
+        for name, entry in self.entries.iteritems():
+            rep[name] = entry.short_representation()
+        return rep
 
 
     @classmethod
-    def _input_collection_links(cls, input_data):
-        """Get the members of a user-provided collection.
+    def _get_related_manager(cls, instance):
+        """Get the related objects manager for the given instance.
 
-        Tries to be flexible about formats accepted from the user.
-        @returns a list of links, possibly only href strings (use
-                resolve_link())
+        The instance must be one of the related classes.  This method will
+        return the related manager from that instance to instances of the other
+        related class.
         """
-        if isinstance(input_data, dict) and 'members' in input_data:
-            # this mirrors the output representation for collections
-            # guard against accidental truncation of the relationship due to
-            # paging
-            is_partial_collection = ('total_results' in input_data
-                                     and 'items_per_page' in input_data
-                                     and input_data['total_results'] >
-                                         input_data['items_per_page'])
-            if is_partial_collection:
-                raise exceptions.BadRequest('You must retreive the full '
-                                            'collection to perform updates')
+        this_model = type(instance)
+        models = [entry_class.model for entry_class
+                  in cls.related_classes.values()]
+        if isinstance(instance, models[0]):
+            this_model, other_model = models
+        else:
+            other_model, this_model = models
 
-            return input_data['members']
-        if isinstance(input_data, list):
-            return input_data
-        raise exceptions.BadRequest('Cannot understand collection in input: %r'
-                                    % input_data)
+        _, field = this_model.objects.determine_relationship(other_model)
+        this_models_fields = (this_model._meta.fields
+                              + this_model._meta.many_to_many)
+        if field in this_models_fields:
+            manager_name = field.attname
+        else:
+            # related manager is on other_model, get name of reverse related
+            # manager on this_model
+            manager_name = field.related.get_accessor_name()
+
+        return getattr(instance, manager_name)
 
 
-    def put(self):
-        input_data = self._decoded_input()
-        self.update(input_data)
-        return self.get()
+    def _delete_entry(self):
+        # choose order arbitrarily
+        entry, other_entry = self.entries.itervalues()
+        related_manager = self._get_related_manager(entry.instance)
+        related_manager.remove(other_entry.instance)
 
 
-    def update(self, input_data):
-        links = self._input_collection_links(input_data)
-        instances = [self.resolve_link(link).instance for link in links]
-        self._update_relationship(instances)
+    @classmethod
+    def create_instance(cls, input_dict, containing_collection):
+        other_name = containing_collection.unfixed_name
+        cls._check_for_required_fields(input_dict, (other_name,))
+        entry = containing_collection.fixed_entry
+        other_entry = containing_collection.resolve_link(input_dict[other_name])
+        related_manager = cls._get_related_manager(entry.instance)
+        related_manager.add(other_entry.instance)
+        return other_entry.instance
 
 
-    def _update_relationship(self, related_instances):
-        raise NotImplementedError
+    def update(self, input_dict):
+        pass
+
+
+class RelationshipCollection(Collection):
+    def __init__(self, request=None, fixed_entry=None):
+        if request is None:
+            request = fixed_entry._request
+        super(RelationshipCollection, self).__init__(request)
+
+        assert issubclass(self.entry_class, Relationship)
+        self.related_classes = self.entry_class.related_classes
+        self.fixed_name = None
+        self.fixed_entry = None
+        self.unfixed_name = None
+        self.related_manager = None
+
+        if fixed_entry is not None:
+            self._set_fixed_entry(fixed_entry)
+            entry_uri_arg = self.fixed_entry._uri_args().values()[0]
+            self._query_params[self.fixed_name] = entry_uri_arg
+
+
+    def _set_fixed_entry(self, entry):
+        """Set the fixed entry for this collection.
+
+        The entry must be an instance of one of the related entry classes.  This
+        method must be called before a relationship is used.  It gets called
+        either from the constructor (when collections are instantiated from
+        other resource handling code) or from read_query_parameters() (when a
+        request is made directly for the collection).
+        """
+        names = self.related_classes.keys()
+        if isinstance(entry, self.related_classes[names[0]]):
+            self.fixed_name, self.unfixed_name = names
+        else:
+            assert isinstance(entry, self.related_classes[names[1]])
+            self.unfixed_name, self.fixed_name = names
+        self.fixed_entry = entry
+        self.unfixed_class = self.related_classes[self.unfixed_name]
+        self.related_manager = self.entry_class._get_related_manager(
+                entry.instance)
+
+
+    def _query_parameters_accepted(self):
+        return [(name, 'Show relationships for this %s' % entry_class.__name__)
+                for name, entry_class
+                in self.related_classes.iteritems()]
+
+
+    def _resolve_query_param(self, name, uri_arg):
+        entry_class = self.related_classes[name]
+        return entry_class.from_uri_args(self._request, uri_arg)
+
+
+    def read_query_parameters(self, query_params):
+        super(RelationshipCollection, self).read_query_parameters(query_params)
+        if not self._query_params:
+            raise exceptions.BadRequest(
+                    'You must specify one of the parameters %s and %s'
+                    % tuple(self.related_classes.keys()))
+        query_items = self._query_params.items()
+        fixed_entry = self._resolve_query_param(*query_items[0])
+        self._set_fixed_entry(fixed_entry)
+
+        if len(query_items) > 1:
+            other_fixed_entry = self._resolve_query_param(*query_items[1])
+            self.related_manager = self.related_manager.filter(
+                    pk=other_fixed_entry.instance.id)
+
+
+    def _entry_from_instance(self, instance):
+        unfixed_entry = self.unfixed_class(self._request, instance)
+        entries = {self.fixed_name: self.fixed_entry,
+                   self.unfixed_name: unfixed_entry}
+        return self.entry_class(**entries)
+
+
+    def _fresh_queryset(self):
+        return self.related_manager.all()