blob: c683dcacf10b68e99ff4a0792c515cbb3e899703 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
5from django.db import models as dbmodels, backend, connection
6from django.utils import datastructures
showard09096d82008-07-07 23:20:49 +00007from frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +00008
class ValidationError(Exception):
    """\
    Raised when the data for adding or updating an object fails validation.

    The single constructor argument is a dict mapping each offending field
    name to an error string describing the problem.
    """
    pass
showard7c785282008-05-29 19:45:12 +000014
15
def _wrap_with_readonly(method):
    """
    Return a callable that runs ``method`` with Django switched onto the
    read-only connection.

    The read-only connection is set immediately before the call and unset
    afterwards, including when ``method`` raises. The returned callable
    carries the same __name__ as ``method``.
    """
    def readonly_call(*args, **kwargs):
        readonly_connection.connection.set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            # Always restore the normal connection, even on error.
            readonly_connection.connection.unset_django_connection()
        return result

    readonly_call.__name__ = method.__name__
    return readonly_call
25
26
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first step runs on the read-only
    connection.

    Generators have to be wrapped specially: calling the generator function
    does no work, so the call itself cannot simply be bracketed. Instead we
    assume the query is performed on the first step of iteration, switch to
    the read-only connection only around that step, and then yield the
    remaining items straight from the underlying generator.

    An empty underlying generator yields nothing. This is detected
    explicitly rather than by letting StopIteration escape from this
    generator's body, which PEP 479 turns into a RuntimeError.
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = generator(*args, **kwargs)
        got_first = False
        readonly_connection.connection.set_django_connection()
        try:
            # Take exactly one item (the step that runs the query) while the
            # read-only connection is active. A for/break is used instead of
            # .next() so it works on any iterator and never leaks
            # StopIteration.
            for first_value in generator_obj:
                got_first = True
                break
        finally:
            readonly_connection.connection.unset_django_connection()
        if not got_first:
            return
        yield first_value

        # The generator is its own iterator, so this resumes after the
        # first item.
        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
46
47
def _make_queryset_readonly(queryset):
    """
    Route the query-issuing methods of ``queryset`` through the read-only
    connection.

    Each listed method is replaced on the instance by a read-only wrapper.
    ``iterator`` gets the generator-aware wrapper instead.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is treated as a generator, so it needs the special wrapper.
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
60
61
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries are all issued on the read-only
    connection.
    """
    def __init__(self, model=None):
        super(ReadonlyQuerySet, self).__init__(model)
        # Swap this instance's query methods for read-only wrappers.
        _make_queryset_readonly(self)


    def values(self, *fields):
        # Like QuerySet.values(), but clone into the read-only values class.
        return self._clone(_fields=fields, klass=ReadonlyValuesQuerySet)
74
75
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose query methods are rewired onto the read-only
    connection, as in ReadonlyQuerySet.
    """
    def __init__(self, model=None):
        super(ReadonlyValuesQuerySet, self).__init__(model)
        _make_queryset_readonly(self)
80
81
class ExtendedManager(dbmodels.Manager):
    """\
    Manager adding custom-join filtering and SQL field-quoting helpers on
    top of the standard Django manager.
    """

    class _CustomJoinQ(dbmodels.Q):
        """
        Django "Q" object that appends a caller-chosen suffix to every join
        alias it generates. See filter_custom_join() for why this is useful.
        """

        def __init__(self, join_suffix, **kwargs):
            super(ExtendedManager._CustomJoinQ, self).__init__(**kwargs)
            self._join_suffix = join_suffix


        @staticmethod
        def _substitute_aliases(renamed_aliases, condition):
            """
            Replace each quoted old alias in ``condition`` with its quoted
            new alias, applying the (old, new) pairs in order.
            """
            result = condition
            for old_name, new_name in renamed_aliases:
                result = result.replace(backend.quote_name(old_name),
                                        backend.quote_name(new_name))
            return result


        @staticmethod
        def _unquote_name(name):
            """
            Strip the quote characters from ``name`` if it is already quoted.

            This may be MySQL specific: it presumably relies on quote_name()
            returning an already-quoted name unchanged (TODO confirm for
            other backends).
            """
            if backend.quote_name(name) != name:
                return name
            return name[1:-1]


        def get_sql(self, opts):
            """
            Build the usual (joins, where, params) triple, then give every join
            alias the configured suffix and rewrite references to those aliases
            in the join conditions and WHERE clauses.
            """
            base = super(ExtendedManager._CustomJoinQ, self)
            joins, where, params = base.get_sql(opts)

            suffix = self._join_suffix
            renamed = []
            suffixed_joins = datastructures.SortedDict()
            # items() rather than iteritems(): iteritems() reportedly did not
            # preserve the SortedDict ordering here.
            for quoted_alias, (table, join_type, condition) in joins.items():
                bare_alias = self._unquote_name(quoted_alias)
                renamed.append((bare_alias, bare_alias + suffix))
                # `renamed` now holds this alias and every earlier one, so
                # references to any of them in the condition get rewritten.
                fixed_condition = self._substitute_aliases(renamed, condition)
                suffixed_joins[backend.quote_name(bare_alias + suffix)] = (
                    table, join_type, fixed_condition)

            suffixed_where = [self._substitute_aliases(renamed, clause)
                              for clause in where]
            return suffixed_joins, suffixed_where, params


    def filter_custom_join(self, join_suffix, **kwargs):
        """
        Like Django's filter(), but appends ``join_suffix`` to the join
        aliases involved in the filter. Using a different suffix for each
        call makes it possible to join against the same table several
        times, which certain queries require.
        """
        return self.complex_filter(self._CustomJoinQ(join_suffix, **kwargs))


    @staticmethod
    def _get_quoted_field(table, field):
        """Return "<quoted table>.<quoted field>" for use in raw SQL."""
        return '.'.join((backend.quote_name(table),
                         backend.quote_name(field)))


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted "table.column" reference for ``key_field`` on this
        manager's model table. When ``key_field`` is None, the model's
        primary-key column is used.
        """
        meta = self.model._meta
        if key_field is None:
            key_field = meta.pk.column
        return self._get_quoted_field(meta.db_table, key_field)
161
162
163
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose default query set contains only objects with
    invalid=False.
    """
    def get_query_set(self):
        return super(ValidObjectsManager, self).get_query_set().filter(
            invalid=False)
showard7c785282008-05-29 19:45:12 +0000171
172
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # Lazily-built cache of field name -> field object; see get_field_dict().
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """\
        Return a dict mapping each of the model's field names to its field
        object.

        The dict is built from cls._meta.fields on the first call and cached
        in cls.field_dict afterwards.
        """
        if cls.field_dict is None:
            cls.field_dict = {}
            for field in cls._meta.fields:
                cls.field_dict[field.name] = field
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        -Convert foreign key fields in data from <field>_id to just
         <field>.
        -Replace foreign key objects with their IDs.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data[field.attname]
                del data[field.attname]
            if field.name not in data:
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value.id


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Return a copy of data with defaults filled in for every field that
        is missing from data or set to None.

        A field with a declared default gets that default. A callable
        default is called to produce the value, as Django's
        Field.get_default() does, rather than storing the function object
        itself. A CharField or TextField with no declared default gets an
        empty string.

        data itself is not modified.
        """
        new_data = dict(data)
        field_dict = cls.get_field_dict()
        for name, obj in field_dict.iteritems():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                default = obj.default
                if callable(default):
                    default = default()
                new_data[name] = default
            elif (isinstance(obj, dbmodels.CharField) or
                  isinstance(obj, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary. This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data. This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        Keys that are not model fields, and None values, are left alone.
        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict or data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(data[field_name])
                if (to_human_readable and
                    dest_obj.name_field is not None):
                    data[field_name] = getattr(dest_obj, dest_obj.name_field)
                else:
                    data[field_name] = dest_obj._get_pk_val()


    @classmethod
    def validate_field_names(cls, data):
        """\
        Check data for keys that are not model fields.

        Returns a dict mapping each unknown key to an error string (empty
        if every key is a field).
        """
        errors = {}
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict:
                errors[field_name] = 'No field of this name'
        return errors


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """\
        Common preparation for add_object and update_object.

        Merges kwargs into a copy of data, rejects unknown field names and
        converts human-readable values. data may be None, which is treated
        as an empty dict. The caller's dict is never modified.

        Raises ValidationError if any key is not a model field.
        """
        if data is None:
            data = {}
        data = dict(data) # copy so the caller's dict is left untouched
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique. Django manipulators do
        this too, but they're a huge pain to use manually. Trust me.

        Returns a dict mapping each offending field name to an error
        string. Lookups go through get_valid_manager(), and a match that is
        this very object (same id) does not count as a clash.
        """
        errors = {}
        cls = type(self)
        field_dict = self.get_field_dict()
        manager = cls.get_valid_manager()
        for field_name, field_obj in field_dict.iteritems():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value))
        return errors


    def do_validate(self):
        """\
        Run the model's validate() plus validate_unique().

        Errors from validate() take precedence for a given field. Raises
        ValidationError with the combined error dict if anything failed.
        """
        errors = self.validate()
        unique_errors = self.validate_unique()
        for field_name, error in unique_errors.iteritems():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data=None, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values). Merges any extra keyword args
        into data. data defaults to an empty dict and is not modified.
        The object is validated and saved before being returned.
        """
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data=None, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values). Merges any extra keyword args into
        data. data defaults to an empty dict and is not modified.
        Fields whose new value is None are left unchanged. The object is
        validated and saved.
        """
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.iteritems():
            if value is not None:
                setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @staticmethod
    def escape_user_sql(sql):
        """\
        Double every '%' in a user-supplied SQL fragment, presumably so it
        survives the DB layer's %-style parameter substitution.
        """
        return sql.replace('%', '%%')


    @classmethod
    def query_objects(cls, filter_data, valid_only=True, initial_query=None):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data. Optional special arguments in
        filter_data include:
        -query_start: index of the first result to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on. prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append
        -no_distinct: if true, do not apply distinct() to the query
        These special keys are popped from filter_data (which is modified
        in place); the remaining keys are passed to filter().

        If initial_query is None, queries start from get_valid_manager()
        when valid_only is true, else from cls.objects.

        Raises ValueError if query_start is given without query_limit.
        """
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        # Copy so that adding the extra WHERE clause below never mutates a
        # dict or list the caller owns.
        extra_args = dict(filter_data.pop('extra_args', None) or {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            # escape %'s
            extra_where = cls.escape_user_sql(extra_where)
            extra_args['where'] = (list(extra_args.get('where', [])) +
                                   [extra_where])
        use_distinct = not filter_data.pop('no_distinct', False)

        if initial_query is None:
            if valid_only:
                initial_query = cls.get_valid_manager()
            else:
                initial_query = cls.objects
        query = initial_query.filter(**filter_data)
        if use_distinct:
            query = query.distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)
        query = query._clone(klass=ReadonlyQuerySet)

        # sorting + paging
        assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # the slice stop is an absolute index, not a count
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data, initial_query=None):
        """\
        Like query_objects, but retrieve only the count of results.
        Paging arguments (query_start, query_limit) are ignored.
        """
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        query = cls.query_objects(filter_data, initial_query=initial_query)
        return query.count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user. Each dict is modified in-place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls.convert_human_readable_values(field_dict,
                                              to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data, initial_query=None, fields=None):
        """\
        Like query_objects, but return a list of dictionaries, one per
        object, as produced by get_object_dict(fields).
        """
        query = cls.query_objects(filter_data, initial_query=initial_query)
        field_dicts = [model_object.get_object_dict(fields)
                       for model_object in query]
        return field_dicts


    @classmethod
    def smart_get(cls, *args, **kwargs):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        smart_get(keyword args) -> normal ModelClass.objects.get()

        Exactly one of a single positional argument or keyword args must be
        given. Raises ValueError for an unsupported positional argument
        type, or for a string argument on a model without a name_field.
        """
        assert bool(args) ^ bool(kwargs)
        if args:
            assert len(args) == 1
            arg = args[0]
            if isinstance(arg, str) or isinstance(arg, unicode):
                if cls.name_field is None:
                    raise ValueError(
                        '%s does not support lookup by name (got %r)' % (
                            cls.__name__, arg))
                return cls.objects.get(**{cls.name_field : arg})
            if isinstance(arg, int) or isinstance(arg, long):
                return cls.objects.get(pk=arg)
            raise ValueError(
                'Invalid positional argument: %s (%s)' % (
                    str(arg), type(arg)))
        return cls.objects.get(**kwargs)


    def get_object_dict(self, fields=None):
        """\
        Return a dictionary mapping fields to this object's values,
        cleaned with clean_object_dicts(). If fields is None, all model
        fields are included.
        """
        if fields is None:
            fields = self.get_field_dict().iterkeys()
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in fields)
        self.clean_object_dicts([object_dict])
        return object_dict


    @classmethod
    def get_valid_manager(cls):
        """\
        Return the manager used for validity-aware lookups (cls.objects
        here; subclasses may override).
        """
        return cls.objects
showard7c785282008-05-29 19:45:12 +0000508
509
class ModelWithInvalid(ModelExtensions):
    """
    Mixin that replaces deletion with invalidation.

    delete() marks the object invalid and saves it rather than removing
    it. save() reuses the id of an existing invalid object with the same
    name, so re-adding a "deleted" object revives its old row. Subclasses
    must have a boolean "invalid" field.
    """

    def save(self):
        # Was an object with this name previously added and invalidated?
        # If so, take over its id so the save updates that row.
        lookup = {self.name_field : getattr(self, self.name_field),
                  'invalid' : True}
        try:
            previous = self.__class__.objects.get(**lookup)
        except self.DoesNotExist:
            pass # no existing invalidated object; save normally
        else:
            self.id = previous.id
        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        Hook called after an object has been marked invalid.

        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted. The default does
        nothing.
        """
        pass


    def delete(self):
        # Invalidate instead of deleting; deleting twice is a programming
        # error.
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        # Uniqueness checks and default queries look only at valid objects.
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects