blob: 308b1b3f8648525e64d205f5b55179ed5dbde529 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
5from django.db import models as dbmodels, backend, connection
6from django.utils import datastructures
showard09096d82008-07-07 23:20:49 +00007from frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +00008
class ValidationError(Exception):
    """\
    Raised when the data for adding or updating an object fails validation.

    The single exception argument is a dict mapping each offending field
    name to an error string describing the problem.
    """
showard7c785282008-05-29 19:45:12 +000014
15
showard09096d82008-07-07 23:20:49 +000016def _wrap_with_readonly(method):
17 def wrapper_method(*args, **kwargs):
18 readonly_connection.connection.set_django_connection()
19 try:
20 return method(*args, **kwargs)
21 finally:
22 readonly_connection.connection.unset_django_connection()
23 wrapper_method.__name__ = method.__name__
24 return wrapper_method
25
26
27def _wrap_generator_with_readonly(generator):
28 """
29 We have to wrap generators specially. Assume it performs
30 the query on the first call to next().
31 """
32 def wrapper_generator(*args, **kwargs):
33 generator_obj = generator(*args, **kwargs)
34 readonly_connection.connection.set_django_connection()
35 try:
36 first_value = generator_obj.next()
37 finally:
38 readonly_connection.connection.unset_django_connection()
39 yield first_value
40
41 while True:
42 yield generator_obj.next()
43
44 wrapper_generator.__name__ = generator.__name__
45 return wrapper_generator
46
47
def _make_queryset_readonly(queryset):
    """
    Replace, on this queryset instance, each method that performs a
    database query with a wrapper that runs it on the read-only connection.

    iterator() is a generator, so it gets the generator-specific wrapper.
    """
    # NOTE(review): 'delete' writes to the database, yet it is routed
    # through the read-only connection too -- confirm this is intended.
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        readonly_version = _wrap_with_readonly(getattr(queryset, name))
        setattr(queryset, name, readonly_version)

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
60
61
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    Django QuerySet whose query-performing methods run against the
    read-only connection. values() keeps that property by returning a
    ReadonlyValuesQuerySet.
    """
    def __init__(self, model=None):
        base_init = super(ReadonlyQuerySet, self).__init__
        base_init(model)
        # install the read-only wrappers on this instance
        _make_queryset_readonly(self)


    def values(self, *fields):
        values_klass = ReadonlyValuesQuerySet
        return self._clone(klass=values_klass, _fields=fields)
74
75
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its query-performing
    methods run against the read-only connection.
    """
    def __init__(self, model=None):
        base_init = super(ReadonlyValuesQuerySet, self).__init__
        base_init(model)
        # install the read-only wrappers on this instance
        _make_queryset_readonly(self)
80
81
class ExtendedManager(dbmodels.Manager):
    """\
    Django manager adding filtering with custom join-alias suffixes and a
    few helpers for building quoted column references.
    """

    class _CustomJoinQ(dbmodels.Q):
        """
        Django "Q" object supporting a custom suffix for join aliases. See
        filter_custom_join() for why this can be useful.
        """

        def __init__(self, join_suffix, **kwargs):
            super(ExtendedManager._CustomJoinQ, self).__init__(**kwargs)
            self._join_suffix = join_suffix


        @staticmethod
        def _substitute_aliases(renamed_aliases, condition):
            """
            Return condition with each quoted old alias replaced by the
            quoted new alias.

            renamed_aliases is a sequence of (old_alias, new_alias) pairs.
            The pairs are applied in order, by plain string replacement.
            """
            result = condition
            for old_alias, new_alias in renamed_aliases:
                quoted_old = backend.quote_name(old_alias)
                quoted_new = backend.quote_name(new_alias)
                result = result.replace(quoted_old, quoted_new)
            return result


        @staticmethod
        def _unquote_name(name):
            """
            Strip the surrounding quote characters from name if it is
            already quoted.

            This may be MySQL specific: name counts as quoted when
            backend.quote_name() returns it unchanged.
            """
            already_quoted = (backend.quote_name(name) == name)
            if not already_quoted:
                return name
            return name[1:-1]


        def get_sql(self, opts):
            """
            Build the SQL via the base Q object, then append the join
            suffix to every join alias.

            References to the renamed aliases are rewritten in the later
            join conditions and in the WHERE clauses.
            """
            base_q = super(ExtendedManager._CustomJoinQ, self)
            joins, where, params = base_q.get_sql(opts)

            renamed = []
            suffixed_joins = datastructures.SortedDict()
            # using iteritems seems to mess up the ordering here, so keep
            # items(); each condition sees every rename made so far.
            for quoted_alias, join_info in joins.items():
                table, join_type, condition = join_info
                plain_alias = self._unquote_name(quoted_alias)
                suffixed_alias = plain_alias + self._join_suffix
                renamed.append((plain_alias, suffixed_alias))
                fixed_condition = self._substitute_aliases(renamed, condition)
                suffixed_joins[backend.quote_name(suffixed_alias)] = (
                    table, join_type, fixed_condition)

            fixed_where = [self._substitute_aliases(renamed, clause)
                           for clause in where]
            return suffixed_joins, fixed_where, params


    def filter_custom_join(self, join_suffix, **kwargs):
        """
        Just like Django filter(), but lets the caller choose a suffix for
        the join aliases involved in the filter.

        This makes it possible to join against the same table several
        times, as long as each join uses a different suffix. Some queries
        require that.
        """
        return self.complex_filter(self._CustomJoinQ(join_suffix, **kwargs))


    @staticmethod
    def _get_quoted_field(table, field):
        """
        Return a column reference: table and field, each quoted by
        backend.quote_name() and joined with '.'.
        """
        return '%s.%s' % (backend.quote_name(table),
                          backend.quote_name(field))


    def _get_key_on_this_table(self, key_field=None):
        """
        Return the quoted reference to key_field on this manager's model
        table. key_field defaults to the primary key column.
        """
        opts = self.model._meta
        if key_field is None:
            key_field = opts.pk.column
        return self._get_quoted_field(opts.db_table, key_field)
161
162
163
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose default queryset is restricted to rows with
    invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        return all_objects.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000171
172
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # Lazily built per-class cache for get_field_dict().
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """\
        Return a dict mapping each field name of this model to its field
        object, taken from cls._meta.fields.

        The dict is built on first use and cached on cls itself.
        """
        # Check cls.__dict__ rather than cls.field_dict: plain attribute
        # lookup would find a dict cached on a base class and return the
        # base class's fields for a subclass.
        if cls.__dict__.get('field_dict') is None:
            cls.field_dict = dict((field.name, field)
                                  for field in cls._meta.fields)
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        -Convert foreign key fields in data from <field>_id to just
         <field>.
        -replace foreign key objects with their IDs
        Foreign key fields present in data under neither name are skipped.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data.pop(field.attname)
            if field.name not in data:
                # nothing to clean for this field
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value.id


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Return a copy of data with default values filled in for fields
        that have a default and were given no value (missing or None).

        Callable defaults are called to obtain the value.

        For CharField and TextField fields with no default, an empty
        string is filled in.

        data itself is not modified.
        """
        new_data = dict(data)
        for name, obj in cls.get_field_dict().items():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                default = obj.default
                if callable(default):
                    # store the produced value, not the callable itself
                    default = default()
                new_data[name] = default
            elif isinstance(obj, (dbmodels.CharField, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary. This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data. This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        None values are left untouched. Raises KeyError for a key in data
        that is not a field of this model.

        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        for field_name in data:
            value = data[field_name]
            if value is None:
                continue
            field_obj = field_dict[field_name]
            if field_obj.choices:
                # convert enum values; each choice is (value, name)
                for enum_value, human_name in field_obj.choices:
                    if to_human_readable:
                        from_val, to_val = enum_value, human_name
                    else:
                        from_val, to_val = human_name, enum_value
                    if from_val == value:
                        data[field_name] = to_val
                        break
            elif field_obj.rel:
                # convert foreign key values
                dest_obj = field_obj.rel.to.smart_get(value)
                if (to_human_readable and
                    dest_obj.name_field is not None):
                    data[field_name] = getattr(dest_obj, dest_obj.name_field)
                else:
                    data[field_name] = dest_obj.id


    @classmethod
    def validate_field_names(cls, data):
        """\
        Check for extraneous fields in data.

        Returns a dict mapping each key of data that is not a field of this
        model to an error string. The dict is empty when all keys are valid.
        """
        field_dict = cls.get_field_dict()
        return dict((field_name, 'No field of this name')
                    for field_name in data
                    if field_name not in field_dict)


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """\
        Common preparation for add_object and update_object.

        Merges kwargs into a copy of data, rejects unknown field names and
        converts human-readable values. Raises ValidationError on unknown
        field names.
        """
        data = dict(data) # never modify the caller's dict
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique. Django manipulators do
        this too, but they're a huge pain to use manually. Trust me.

        Only objects returned by the valid manager are considered. Returns
        a dict mapping field names to error strings.
        """
        errors = {}
        field_dict = self.get_field_dict()
        manager = type(self).get_valid_manager()
        for field_name, field_obj in field_dict.items():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            # the only match is this object itself
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value,))
        return errors


    def do_validate(self):
        """\
        Run self.validate() and validate_unique(), and raise
        ValidationError with the combined errors if there are any.

        When both report an error for the same field, the error from
        self.validate() wins.
        """
        # NOTE(review): self.validate() presumably comes from the Django
        # model class this mixin is combined with; confirm.
        errors = self.validate()
        for field_name, error in self.validate_unique().items():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data=None, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values). Merges any extra keyword args
        into data. The new object is validated and saved.
        """
        if data is None:
            data = {}
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data=None, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values). Merges any extra keyword args into
        data. Fields whose new value is None are left unchanged. The
        object is then validated and saved.
        """
        if data is None:
            data = {}
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.items():
            if value is not None:
                setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data. Optional special arguments in
        filter_data include:
        -query_start: index of first return to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on. prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append
        All other keys are passed to filter() as lookups. filter_data
        (and any extra_args dict inside it) is not modified.

        Raises ValueError if query_start is given without query_limit.
        """
        # work on copies so the caller's dict is left untouched
        filter_data = dict(filter_data)
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        extra_args = dict(filter_data.pop('extra_args', None) or {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            # build a new list; the caller's 'where' list stays unchanged
            extra_args['where'] = (list(extra_args.get('where', [])) +
                                   [extra_where])

        # filters: every remaining key is a Django field lookup
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects
        query = manager.filter(**filter_data).distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)
        query = query._clone(klass=ReadonlyQuerySet)

        # sorting + paging
        assert isinstance(sort_by, (list, tuple))
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # the slice end is an absolute index
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data):
        """\
        Like query_objects, but retrieve only the count of results.
        Paging arguments are ignored and filter_data is not modified.
        """
        filter_data = dict(filter_data)
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        return cls.query_objects(filter_data).count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user. Each dict is modified in place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls.convert_human_readable_values(field_dict,
                                              to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data):
        """\
        Like query_objects, but return a list of cleaned dictionaries.
        """
        query = cls.query_objects(filter_data)
        field_dicts = list(query.values())
        cls.clean_object_dicts(field_dicts)
        return field_dicts


    @classmethod
    def smart_get(cls, *args, **kwargs):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        smart_get(keyword args) -> normal ModelClass.objects.get()
        Pass either one positional argument or keyword args, not both.
        Raises ValueError for a positional argument of any other type.
        """
        assert bool(args) ^ bool(kwargs)
        if args:
            assert len(args) == 1
            arg = args[0]
            if isinstance(arg, (int, long)):
                return cls.objects.get(id=arg)
            if isinstance(arg, (str, unicode)):
                return cls.objects.get(**{cls.name_field : arg})
            raise ValueError(
                'Invalid positional argument: %s (%s)' % (
                    str(arg), type(arg)))
        return cls.objects.get(**kwargs)


    def get_object_dict(self):
        """\
        Return a dictionary mapping fields to this object's values,
        cleaned the same way as clean_object_dicts().
        """
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in self.get_field_dict())
        self.clean_object_dicts([object_dict])
        return object_dict


    @classmethod
    def get_valid_manager(cls):
        """\
        Return the manager used to look up valid objects. This is
        cls.objects here; ModelWithInvalid overrides it.
        """
        return cls.objects
496 return cls.objects
showard7c785282008-05-29 19:45:12 +0000497
498
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion. Subclasses must have a boolean "invalid"
    field and set name_field.
    """

    def save(self):
        """
        Save this object.

        When the object has not been saved yet (self.id is None) and a row
        with the same name_field value is marked invalid, that row's id is
        reused. The invalidated row is then overwritten instead of a second
        row being inserted.
        """
        # Only brand-new objects may take over an invalidated row. An
        # object that already has an id (e.g. being renamed, or being
        # invalidated by delete()) must keep writing to its own row.
        if self.id is None:
            # see if this object was previously added and invalidated
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
            except self.DoesNotExist:
                pass # no existing invalidated object to reuse
            else:
                self.id = old_object.id
        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted."""
        pass


    def delete(self):
        """
        Mark this object invalid and save it instead of deleting the row,
        then call clean_object(). The object must currently be valid.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        Return cls.valid_objects, which subclasses are expected to define.
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        # NOTE(review): object has no _prepare(), so the super() call below
        # presumably relies on Django combining this class with its
        # generated manipulator classes -- confirm.
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects