blob: 254e92f7e8ef9b88b8af00499cf718fe89190e5d [file] [log] [blame]
showardf46ad4c2010-02-03 20:28:59 +00001import cgi, datetime, re, time, urllib
showardf828c772010-01-25 21:49:42 +00002from django import http
showardf46ad4c2010-02-03 20:28:59 +00003import django.core.exceptions
4from django.core import urlresolvers
jamesren3981f442010-02-16 19:27:59 +00005from django.utils import datastructures
Alex Miller4f341702013-03-25 12:39:12 -07006import json
showardf828c772010-01-25 21:49:42 +00007from autotest_lib.frontend.shared import exceptions, query_lib
8from autotest_lib.frontend.afe import model_logic
9
10
# Media type used for JSON bodies: it is the content type set on responses
# built by Resource._basic_response() and the default assumed for request
# bodies in Resource._decoded_input().
_JSON_CONTENT_TYPE = 'application/json'
12
13
def _resolve_class_path(class_path):
    """Return the object named by a dotted path such as 'package.module.Name'.

    @param class_path: dotted path; everything before the last dot is the
            module to import, the final component is the attribute to fetch
            from that module.
    @returns the attribute found on the imported module.
    """
    path_parts = class_path.rsplit('.', 1)
    module_path, class_name = path_parts
    # a non-empty fromlist makes __import__ return the innermost module
    # rather than the top-level package
    containing_module = __import__(module_path, {}, {}, [''])
    return getattr(containing_module, class_name)
18
19
# Sentinel marking "the client did not supply this field"; it is distinct from
# None so that an explicitly supplied null can be told apart from an omission.
_NO_VALUE_SPECIFIED = object()


class _InputDict(dict):
    """Dict of decoded request input whose get() defaults to the sentinel."""

    def get(self, key, default=_NO_VALUE_SPECIFIED):
        """Like dict.get(), but a missing key yields _NO_VALUE_SPECIFIED."""
        return super(_InputDict, self).get(key, default)


    @classmethod
    def remove_unspecified_fields(cls, field_dict):
        """Return a plain dict copy of field_dict without sentinel values.

        @param field_dict: mapping of field name to value.
        @returns a new dict holding only the items whose value is not
                _NO_VALUE_SPECIFIED.
        """
        specified = {}
        for key, value in field_dict.items():
            if value is _NO_VALUE_SPECIFIED:
                continue
            specified[key] = value
        return specified
31
32
class Resource(object):
    """Base class for a REST resource bound to one HTTP request.

    dispatch_request() is the entry point used as the view: it builds an
    instance from the URI args, loads the accepted query parameters and hands
    over to the handler named after the HTTP method (get/post/put/delete).
    """
    _permitted_methods = None # subclasses must override this


    def __init__(self, request):
        """
        @param request: the request being served; see the note below on what
                it may be used for.
        """
        assert self._permitted_methods
        # this request should be used for global environment info, like
        # constructing absolute URIs.  it should not be used for query
        # parameters, because the request may not have been for this particular
        # resource.
        self._request = request
        # this dict will contain the applicable query parameters
        self._query_params = datastructures.MultiValueDict()


    @classmethod
    def dispatch_request(cls, request, *args, **kwargs):
        """Handle a request made directly for this resource.

        @param request: the incoming request.
        @param kwargs: URI args, forwarded to from_uri_args().
        @returns an HttpResponse; a RequestError raised while handling is
                converted into its own response.
        @raises http.Http404 if from_uri_args() reports ObjectDoesNotExist.
        """
        try:
            try:
                instance = cls.from_uri_args(request, **kwargs)
            except django.core.exceptions.ObjectDoesNotExist as exc:
                raise http.Http404(exc)

            instance.read_query_parameters(request.GET)
            return instance.handle_request()
        except exceptions.RequestError as exc:
            return exc.response


    def handle_request(self):
        """Call the handler matching the request's HTTP method.

        @returns the handler's HttpResponse, or HttpResponseNotAllowed if the
                method is not listed in _permitted_methods.
        """
        if self._request.method.upper() not in self._permitted_methods:
            return http.HttpResponseNotAllowed(self._permitted_methods)

        handler = getattr(self, self._request.method.lower())
        return handler()


    # the handler methods below only need to be overridden if the resource
    # supports the method

    def get(self):
        """Handle a GET request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def post(self):
        """Handle a POST request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def put(self):
        """Handle a PUT request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    def delete(self):
        """Handle a DELETE request.

        @returns an HttpResponse
        """
        raise NotImplementedError


    @classmethod
    def from_uri_args(cls, request, **kwargs):
        """Construct an instance from URI args.

        Default implementation for resources with no URI args.
        """
        return cls(request)


    def _uri_args(self):
        """Return kwargs for a URI reference to this resource.

        Default implementation for resources with no URI args.
        """
        return {}


    def _query_parameters_accepted(self):
        """Return sequence of tuples (name, description) for query parameters.

        Documents the available query parameters for GETting this resource.
        Default implementation for resources with no parameters.
        """
        return ()


    def read_query_parameters(self, parameters):
        """Read relevant query parameters from a Django MultiValueDict.

        Only parameters whose base name (the part before any ':') is listed
        by _query_parameters_accepted() are kept; the rest are ignored.
        """
        params_accepted = set(param_name for param_name, _
                              in self._query_parameters_accepted())
        for name, values in parameters.iterlists():
            base_name = name.split(':', 1)[0]
            if base_name in params_accepted:
                self._query_params.setlist(name, values)


    def set_query_parameters(self, **parameters):
        """Set query parameters programmatically."""
        self._query_params.update(parameters)


    def href(self, query_params=None):
        """Return URI to this resource.

        @param query_params: optional extra query parameters merged on top of
                the ones already held by this resource.
        @returns the absolute URI, including a query string if there are any
                parameters.
        """
        kwargs = self._uri_args()
        path = urlresolvers.reverse(self.dispatch_request, kwargs=kwargs)
        full_query_params = datastructures.MultiValueDict(self._query_params)
        if query_params:
            full_query_params.update(query_params)
        if full_query_params:
            path += '?' + urllib.urlencode(full_query_params.lists(),
                                           doseq=True)
        return self._request.build_absolute_uri(path)


    def resolve_uri(self, uri):
        """Return the resource instance that the given URI refers to.

        @param uri: absolute URI on this host, or a path.
        @raises exceptions.BadRequest if the URI names another host or cannot
                be resolved.
        """
        # check for absolute URIs
        match = re.match(r'(?P<root>https?://[^/]+)(?P<path>/.*)', uri)
        if match:
            # is this URI for a different host?
            my_root = self._request.build_absolute_uri('/')
            request_root = match.group('root') + '/'
            if my_root != request_root:
                # might support this in the future, but not now
                raise exceptions.BadRequest('Unable to resolve remote URI %s'
                                            % uri)
            uri = match.group('path')

        try:
            view_method, args, kwargs = urlresolvers.resolve(uri)
        except http.Http404:
            raise exceptions.BadRequest('Unable to resolve URI %s' % uri)
        resource_class = view_method.im_self # class owning this classmethod
        return resource_class.from_uri_args(self._request, **kwargs)


    def resolve_link(self, link):
        """Return the resource a link refers to.

        @param link: either a dict with an 'href' key or a URI string.
        @raises exceptions.BadRequest if link is neither of those.
        """
        if isinstance(link, dict):
            uri = link['href']
        elif isinstance(link, basestring):
            uri = link
        else:
            raise exceptions.BadRequest('Unable to understand link %s' % link)
        return self.resolve_uri(uri)


    def link(self, query_params=None):
        """Return a link dict ({'href': ...}) pointing at this resource."""
        return {'href': self.href(query_params=query_params)}


    def _query_parameters_response(self):
        """Return the accepted query parameters as a name->description dict."""
        return dict((name, description)
                    for name, description in self._query_parameters_accepted())


    def _basic_response(self, content):
        """Construct and return a simple 200 response.

        @param content: dict to encode as JSON; a 'query_parameters' key is
                added to it in place when this resource accepts any.
        """
        assert isinstance(content, dict)
        query_parameters = self._query_parameters_response()
        if query_parameters:
            content['query_parameters'] = query_parameters
        encoded_content = json.dumps(content)
        return http.HttpResponse(encoded_content,
                                 content_type=_JSON_CONTENT_TYPE)


    def _decoded_input(self):
        """Decode the request body into an _InputDict.

        JSON and urlencoded form bodies are supported; a missing Content-Type
        header is treated as JSON.

        @raises exceptions.BadRequest if a JSON body cannot be decoded or is
                not a dict.
        @raises exceptions.RequestError (415) for any other content type.
        """
        content_type = self._request.META.get('CONTENT_TYPE',
                                              _JSON_CONTENT_TYPE)
        # compare on the bare media type only: headers commonly carry
        # parameters (e.g. "application/json; charset=utf-8") and media type
        # names are case-insensitive
        media_type = content_type.split(';', 1)[0].strip().lower()
        raw_data = self._request.raw_post_data
        if media_type == _JSON_CONTENT_TYPE:
            try:
                raw_dict = json.loads(raw_data)
            except ValueError as exc:
                raise exceptions.BadRequest('Error decoding request body: '
                                            '%s\n%r' % (exc, raw_data))
            if not isinstance(raw_dict, dict):
                raise exceptions.BadRequest('Expected dict input, got %s: %r' %
                                            (type(raw_dict), raw_dict))
        elif media_type == 'application/x-www-form-urlencoded':
            cgi_dict = cgi.parse_qs(raw_data) # django won't do this for PUT
            raw_dict = {}
            for key, values in cgi_dict.items():
                value = values[-1] # take last value if multiple were given
                try:
                    # attempt to parse numbers, booleans and nulls
                    raw_dict[key] = json.loads(value)
                except ValueError:
                    # otherwise, leave it as a string
                    raw_dict[key] = value
        else:
            raise exceptions.RequestError(415, 'Unsupported media type: %s'
                                          % content_type)

        return _InputDict(raw_dict)


    @staticmethod
    def _format_utc_offset(seconds_west):
        """Return an ISO 8601 UTC offset such as '-08:00' or '+05:30'.

        @param seconds_west: offset in seconds *west* of UTC, i.e. the
                convention used by time.timezone (positive in the Americas).
        """
        seconds_east = -seconds_west
        sign = '-' if seconds_east < 0 else '+'
        hours, remainder = divmod(abs(seconds_east), 3600)
        return '%s%02d:%02d' % (sign, hours, remainder // 60)


    def _format_datetime(self, date_time):
        """Return ISO 8601 string for the given datetime.

        @param date_time: a datetime, or None.
        @returns the formatted string with the local (non-DST) UTC offset
                appended, or None if date_time is None.
        """
        if date_time is None:
            return None
        # NOTE(review): time.timezone is the non-DST offset, so timestamps
        # that fall in daylight saving time still get the standard offset --
        # same as before; confirm whether DST needs handling here.
        timezone_spec = self._format_utc_offset(time.timezone)
        return date_time.strftime('%Y-%m-%dT%H:%M:%S') + timezone_spec


    @classmethod
    def _check_for_required_fields(cls, input_dict, fields):
        """Raise BadRequest naming any of fields that input_dict lacks.

        @param input_dict: the decoded input.
        @param fields: list or tuple of required field names.
        """
        assert isinstance(fields, (list, tuple)), fields
        missing_fields = ', '.join(field for field in fields
                                   if field not in input_dict)
        if missing_fields:
            raise exceptions.BadRequest('Missing input: ' + missing_fields)
262
263
class Entry(Resource):
    """Resource for a single item.

    Subclasses supply the representations plus whichever of update(),
    create_instance() and _delete_entry() they support.
    """

    @classmethod
    def add_query_selectors(cls, query_processor):
        """Subclasses may override this to support querying."""
        pass


    def short_representation(self):
        """Return the brief representation: just a link to this entry."""
        return self.link()


    def full_representation(self):
        """Return the complete representation; defaults to the short one."""
        return self.short_representation()


    def get(self):
        """Handle GET by returning the full representation."""
        return self._basic_response(self.full_representation())


    def put(self):
        """Handle PUT: apply the decoded input via update().

        @returns a response carrying the full representation after the
                update.
        @raises exceptions.BadRequest if update() raises a ValidationError.
        """
        try:
            self.update(self._decoded_input())
        except model_logic.ValidationError as exc:
            raise exceptions.BadRequest('Invalid input: %s' % exc)
        return self._basic_response(self.full_representation())


    def _delete_entry(self):
        """Delete whatever this entry represents; subclasses implement."""
        raise NotImplementedError


    def delete(self):
        """Handle DELETE by calling _delete_entry()."""
        self._delete_entry()
        return http.HttpResponse(status=204) # No content


    # a classmethod because Collection.post() calls it on the entry class,
    # before any entry instance exists
    @classmethod
    def create_instance(cls, input_dict, containing_collection):
        """Create and return a new underlying instance; subclasses implement.

        @param input_dict: the decoded request input.
        @param containing_collection: the collection receiving the POST.
        """
        raise NotImplementedError


    def update(self, input_dict):
        """Apply input_dict to this entry; subclasses implement."""
        raise NotImplementedError
306
307
class InstanceEntry(Entry):
    """Entry wrapping a single model instance, held in self.instance."""

    class NullEntry(object):
        """Placeholder used when there is no instance to wrap.

        Its link and short representation are both None.
        """
        def link(self):
            return None


        def short_representation(self):
            return None


    _null_entry = NullEntry()
    _permitted_methods = ('GET', 'PUT', 'DELETE')
    model = None # subclasses must override this with a Django model class


    def __init__(self, request, instance):
        """
        @param request: request passed on to the Resource constructor.
        @param instance: the model instance this entry represents.
        """
        assert self.model is not None
        super(InstanceEntry, self).__init__(request)
        self.instance = instance
        # flipped to True by prepare_for_full_representation()
        self._is_prepared_for_full_representation = False


    @classmethod
    def from_optional_instance(cls, request, instance):
        """Return an entry for instance, or the shared null entry for None."""
        if instance is not None:
            return cls(request, instance)
        return cls._null_entry


    def _delete_entry(self):
        """Delete the wrapped instance."""
        self.instance.delete()


    def full_representation(self):
        """Prepare this entry if needed, then return its full representation."""
        self.prepare_for_full_representation([self])
        return super(InstanceEntry, self).full_representation()


    @classmethod
    def prepare_for_full_representation(cls, entries):
        """
        Prepare the given list of entries to generate full representations.

        The actual work is delegated to _do_prepare_for_full_representation(),
        which subclasses may override.  Entries are marked as prepared
        afterwards, and only entries not yet marked are handed on, so calling
        this repeatedly with the same entries does not repeat the work for
        them.
        """
        pending = []
        for entry in entries:
            if entry._is_prepared_for_full_representation:
                continue
            pending.append(entry)

        pending_instances = [entry.instance for entry in pending]
        cls._do_prepare_for_full_representation(pending_instances)

        for entry in pending:
            entry._is_prepared_for_full_representation = True


    @classmethod
    def _do_prepare_for_full_representation(cls, instances):
        """
        Hook for subclasses to gather whatever data the full representations
        of the given model instances need.  Typically this means querying
        related objects; receiving many instances at once allows a single
        query to cover all of them, which can be a large performance win.
        """
        pass
374
375
class Collection(Resource):
    """Resource for a pageable, filterable list of entries.

    GET returns one page of members; POST creates a new member through the
    entry class.
    """
    _DEFAULT_ITEMS_PER_PAGE = 50

    _permitted_methods=('GET', 'POST')

    # subclasses must override these
    queryset = None # or override _fresh_queryset() directly
    entry_class = None


    def __init__(self, request):
        super(Collection, self).__init__(request)
        assert self.entry_class is not None
        if isinstance(self.entry_class, basestring):
            # entry_class may be given as a dotted path; resolve it once and
            # store the result on the class so later instances reuse it
            type(self).entry_class = _resolve_class_path(self.entry_class)

        self._query_processor = query_lib.QueryProcessor()
        self.entry_class.add_query_selectors(self._query_processor)


    def _query_parameters_accepted(self):
        """Return paging parameters plus one per registered query selector."""
        params = [('start_index', 'Index of first member to include'),
                  ('items_per_page', 'Number of members to include'),
                  ('full_representations',
                   'True to include full representations of members')]
        for selector in self._query_processor.selectors():
            params.append((selector.name, selector.doc))
        return params


    def _fresh_queryset(self):
        """Return a new, unfiltered queryset over all members."""
        assert self.queryset is not None
        # always copy the queryset before using it to avoid caching
        return self.queryset.all()


    def _entry_from_instance(self, instance):
        """Wrap a model instance in this collection's entry class."""
        return self.entry_class(self._request, instance)


    def _representation(self, entry_instances):
        """Return this collection's link dict with a 'members' list added.

        Members are full representations when the 'full_representations'
        query parameter is true, short ones otherwise.

        @param entry_instances: iterable of model instances to include.
        """
        entries = [self._entry_from_instance(instance)
                   for instance in entry_instances]

        want_full_representation = self._read_bool_parameter(
                'full_representations')
        if want_full_representation:
            # prepare all entries together so the entry class can batch work
            self.entry_class.prepare_for_full_representation(entries)

        members = []
        for entry in entries:
            if want_full_representation:
                rep = entry.full_representation()
            else:
                rep = entry.short_representation()
            members.append(rep)

        rep = self.link()
        rep.update({'members': members})
        return rep


    def _read_bool_parameter(self, name):
        """Return True only if the named query parameter equals 'true'.

        The comparison ignores case; an absent parameter counts as False.
        """
        if name not in self._query_params:
            return False
        return (self._query_params[name].lower() == 'true')


    def _read_int_parameter(self, name, default):
        """Return the named query parameter as an int.

        @param default: returned when the parameter is absent.
        @raises exceptions.BadRequest if the value is not an integer.
        """
        if name not in self._query_params:
            return default
        input_value = self._query_params[name]
        try:
            return int(input_value)
        except ValueError:
            raise exceptions.BadRequest('Invalid non-numeric value for %s: %r'
                                        % (name, input_value))


    def _apply_form_query(self, queryset):
        """Apply any query selectors passed as form variables."""
        for parameter, values in self._query_params.lists():
            if ':' in parameter:
                parameter, comparison_type = parameter.split(':', 1)
            else:
                comparison_type = None

            if not self._query_processor.has_selector(parameter):
                continue
            for value in values: # forms keys can have multiple values
                queryset = self._query_processor.apply_selector(
                        queryset, parameter, value,
                        comparison_type=comparison_type)
        return queryset


    def _filtered_queryset(self):
        """Return a fresh queryset narrowed by the current query parameters."""
        return self._apply_form_query(self._fresh_queryset())


    def get(self):
        """Handle GET by returning one page of the filtered collection.

        @returns a response whose body holds the members of the page along
                with total_results, start_index and items_per_page.
        @raises exceptions.BadRequest if start_index or items_per_page is not
                a non-negative integer.
        """
        queryset = self._filtered_queryset()

        items_per_page = self._read_int_parameter('items_per_page',
                                                  self._DEFAULT_ITEMS_PER_PAGE)
        start_index = self._read_int_parameter('start_index', 0)
        # reject negative paging values here: they would otherwise reach the
        # slice below, where Django querysets refuse negative indexing and
        # the client would get a server error instead of a 400
        for name, value in (('start_index', start_index),
                            ('items_per_page', items_per_page)):
            if value < 0:
                raise exceptions.BadRequest('Invalid negative value for %s: %r'
                                            % (name, value))
        page = queryset[start_index:(start_index + items_per_page)]

        rep = self._representation(page)
        # NOTE(review): len() on a Django queryset loads every row; count()
        # would be cheaper if _fresh_queryset() always returns a queryset --
        # confirm against subclasses before changing.
        rep.update({'total_results': len(queryset),
                    'start_index': start_index,
                    'items_per_page': items_per_page})
        return self._basic_response(rep)


    def full_representation(self):
        """Return the representation of the entire, unfiltered collection."""
        # careful, this rep can be huge for large collections
        return self._representation(self._fresh_queryset())


    def post(self):
        """Handle POST by creating a new member from the request body.

        @returns a 201 response carrying the new entry's URI.
        @raises exceptions.BadRequest if creation or update raises a
                ValidationError.
        """
        input_dict = self._decoded_input()
        try:
            instance = self.entry_class.create_instance(input_dict, self)
            entry = self._entry_from_instance(instance)
            entry.update(input_dict)
        except model_logic.ValidationError as exc:
            raise exceptions.BadRequest('Invalid input: %s' % exc)
        # RFC 2616 specifies that we provide the new URI in both the Location
        # header and the body
        response = http.HttpResponse(status=201, # Created
                                     content=entry.href())
        response['Location'] = entry.href()
        return response
510
511
class Relationship(Entry):
    """Entry for the association between two related entries.

    self.entries maps each name in related_classes to the entry on that side.
    """
    _permitted_methods = ('GET', 'DELETE')

    # subclasses must override this with a dict mapping name to entry class
    related_classes = None


    def __init__(self, **kwargs):
        """
        @param kwargs: one entry per name in related_classes, keyed by that
                name; each must be an instance of the corresponding class.
        """
        assert len(self.related_classes) == 2
        entries = {}
        for name in self.related_classes:
            entries[name] = kwargs[name]
        self.entries = entries
        # sanity check
        for name, entry_class in self.related_classes.items():
            assert isinstance(self.entries[name], entry_class)

        # just grab the request from one of the entries
        any_entry = next(iter(self.entries.values()))
        super(Relationship, self).__init__(any_entry._request)


    @classmethod
    def from_uri_args(cls, request, **kwargs):
        """Build the relationship by building each side from the URI args."""
        # kwargs contains URI args for each entry
        entries = dict(
                (name, entry_class.from_uri_args(request, **kwargs))
                for name, entry_class in cls.related_classes.items())
        return cls(**entries)


    def _uri_args(self):
        """Return the URI args of both entries merged into one dict."""
        merged_args = {}
        for entry in self.entries.values():
            merged_args.update(entry._uri_args())
        return merged_args


    def short_representation(self):
        """Return this relationship's link plus each side's short rep."""
        rep = self.link()
        rep.update((name, entry.short_representation())
                   for name, entry in self.entries.items())
        return rep


    @classmethod
    def _get_related_manager(cls, instance):
        """Get the related objects manager for the given instance.

        The instance must be one of the related classes.  The manager
        returned leads from that instance to instances of the other related
        class.
        """
        models = [entry_class.model
                  for entry_class in cls.related_classes.values()]
        if isinstance(instance, models[0]):
            this_model, other_model = models
        else:
            other_model, this_model = models

        _, field = this_model.objects.determine_relationship(other_model)
        own_fields = this_model._meta.fields + this_model._meta.many_to_many
        if field not in own_fields:
            # related manager is on other_model, get name of reverse related
            # manager on this_model
            manager_name = field.related.get_accessor_name()
        else:
            manager_name = field.attname

        return getattr(instance, manager_name)


    def _delete_entry(self):
        """Remove the association between the two entries' instances."""
        # choose order arbitrarily
        first_entry, second_entry = self.entries.values()
        manager = self._get_related_manager(first_entry.instance)
        manager.remove(second_entry.instance)


    @classmethod
    def create_instance(cls, input_dict, containing_collection):
        """Associate the collection's fixed entry with the linked entry.

        @param input_dict: must contain a link under the collection's
                unfixed name.
        @param containing_collection: collection supplying fixed_entry,
                unfixed_name and resolve_link().
        @returns the instance of the newly related (other) entry.
        """
        other_name = containing_collection.unfixed_name
        cls._check_for_required_fields(input_dict, (other_name,))
        fixed_entry = containing_collection.fixed_entry
        linked_entry = containing_collection.resolve_link(
                input_dict[other_name])
        manager = cls._get_related_manager(fixed_entry.instance)
        manager.add(linked_entry.instance)
        return linked_entry.instance


    def update(self, input_dict):
        """Nothing to update on a relationship."""
        pass
603
604
class RelationshipCollection(Collection):
    """Collection of Relationship entries sharing one fixed entry.

    The fixed entry comes either from the constructor or from the query
    parameters of a direct request.
    """

    def __init__(self, request=None, fixed_entry=None):
        """
        @param request: the request; when None it is taken from fixed_entry.
        @param fixed_entry: optional entry whose relationships to list.
        """
        if request is None:
            request = fixed_entry._request
        super(RelationshipCollection, self).__init__(request)

        assert issubclass(self.entry_class, Relationship)
        self.related_classes = self.entry_class.related_classes
        self.fixed_name = None
        self.fixed_entry = None
        self.unfixed_name = None
        self.related_manager = None

        if fixed_entry is None:
            return
        self._set_fixed_entry(fixed_entry)
        # record the fixed entry as a query parameter, using the first of
        # its URI arg values
        uri_arg_values = list(self.fixed_entry._uri_args().values())
        self._query_params[self.fixed_name] = uri_arg_values[0]


    def _set_fixed_entry(self, entry):
        """Set the fixed entry for this collection.

        The entry must be an instance of one of the related entry classes.
        This method must be called before a relationship is used.  It gets
        called either from the constructor (when collections are instantiated
        from other resource handling code) or from read_query_parameters()
        (when a request is made directly for the collection).
        """
        first_name, second_name = list(self.related_classes.keys())
        if isinstance(entry, self.related_classes[first_name]):
            self.fixed_name = first_name
            self.unfixed_name = second_name
        else:
            assert isinstance(entry, self.related_classes[second_name])
            self.unfixed_name = first_name
            self.fixed_name = second_name
        self.fixed_entry = entry
        self.unfixed_class = self.related_classes[self.unfixed_name]
        self.related_manager = self.entry_class._get_related_manager(
                entry.instance)


    def _query_parameters_accepted(self):
        """Return one accepted query parameter per related class name."""
        accepted = []
        for name, entry_class in self.related_classes.items():
            description = ('Show relationships for this %s'
                           % entry_class.__name__)
            accepted.append((name, description))
        return accepted


    def _resolve_query_param(self, name, uri_arg):
        """Return the entry of the class registered under name for uri_arg."""
        return self.related_classes[name].from_uri_args(self._request,
                                                        uri_arg)


    def read_query_parameters(self, query_params):
        """Read the query parameters and fix this collection's entry.

        The first parameter read selects the fixed entry.  If a second one is
        present, the related manager is narrowed to that entry's instance.

        @raises exceptions.BadRequest if no accepted parameter was given.
        """
        super(RelationshipCollection, self).read_query_parameters(query_params)
        if not self._query_params:
            raise exceptions.BadRequest(
                    'You must specify one of the parameters %s and %s'
                    % tuple(self.related_classes.keys()))
        query_items = list(self._query_params.items())
        first_name, first_arg = query_items[0]
        self._set_fixed_entry(self._resolve_query_param(first_name, first_arg))

        if len(query_items) > 1:
            second_name, second_arg = query_items[1]
            other_fixed_entry = self._resolve_query_param(second_name,
                                                          second_arg)
            self.related_manager = self.related_manager.filter(
                    pk=other_fixed_entry.instance.id)


    def _entry_from_instance(self, instance):
        """Return the relationship pairing the fixed entry with instance."""
        unfixed_entry = self.unfixed_class(self._request, instance)
        return self.entry_class(**{self.fixed_name: self.fixed_entry,
                                   self.unfixed_name: unfixed_entry})


    def _fresh_queryset(self):
        """Return the instances related to the fixed entry."""
        return self.related_manager.all()