blob: a79841644e8de59b7d573f463ead7c07b61bcb6f [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
showarda5288b42009-07-28 20:06:08 +000012
class ValidationError(Exception):
    """\
    Raised when data for adding or updating an object fails validation.

    The exception's argument is a dict mapping each offending field name to
    an error string.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Return a wrapper that calls method with Django switched onto the
    read-only connection (via set_django_connection()).

    The switch is undone with unset_django_connection() whether method
    returns normally or raises.  The wrapper keeps method's __name__.
    """
    def run_readonly(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    run_readonly.__name__ = method.__name__
    return run_readonly
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Quote name using the database backend's quote_name().

    Shorthand for connection.ops.quote_name(name).
    """
    quote = connection.ops.quote_name
    return quote(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first step runs with the
    read-only connection set.

    We have to wrap generators specially: we assume the query is performed
    on the first call to next(), so only that step is bracketed by
    set_django_connection()/unset_django_connection().  Remaining values are
    passed through unchanged.

    Exhaustion is handled with ordinary for-loops rather than by letting
    StopIteration escape from this generator.  Under PEP 479 such an escaping
    StopIteration becomes a RuntimeError, for example when the wrapped
    generator yields nothing at all.
    """
    def wrapper_generator(*args, **kwargs):
        values = iter(generator(*args, **kwargs))
        readonly_connection.connection().set_django_connection()
        try:
            # Pull just the first value; this is what runs the query.
            for first_value in values:
                break
            else:
                # The wrapped generator produced nothing.
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        for value in values:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Patch queryset in place so its database-touching methods run on the
    read-only connection.

    count/get/get_or_create/latest/in_bulk/delete are wrapped with
    _wrap_with_readonly(); iterator is wrapped with the generator-aware
    _wrap_generator_with_readonly().
    """
    query_method_names = ('count', 'get', 'get_or_create', 'latest',
                          'in_bulk', 'delete')
    for name in query_method_names:
        original = getattr(queryset, name)
        setattr(queryset, name, _wrap_with_readonly(original))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries all go through the read-only
    connection.
    """
    def __init__(self, model=None, *positional, **named):
        super(ReadonlyQuerySet, self).__init__(model, *positional, **named)
        # Re-route this instance's query methods (see _make_queryset_readonly)
        _make_queryset_readonly(self)


    def values(self, *fields):
        # values() results stay on the read-only connection too.
        return self._clone(klass=ReadonlyValuesQuerySet, setup=True,
                           _fields=fields)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its query methods use the
    read-only connection.
    """
    def __init__(self, model=None, *positional, **named):
        super(ReadonlyValuesQuerySet, self).__init__(model, *positional,
                                                     **named)
        _make_queryset_readonly(self)
90
91
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Besides the standard Manager API it offers hand-written SQL joins
    (add_join), raw selects over an existing query set
    (_custom_select_query) and bulk loading of related objects
    (populate_relationships).
    """

    class _CustomQuery(query.Query):
        """
        Query carrying a _CustomSqlQ (in self._customSqlQ) whose joins are
        appended to the FROM clause.
        """
        def clone(self, klass=None, **kwargs):
            """
            Clone this query, giving the clone its own _CustomSqlQ.

            The clone's _CustomSqlQ is a fresh copy of this query's, merged
            with any _customSqlQ passed in kwargs.  Copying, rather than
            sharing and mutating a single object, keeps joins added to a
            derived query set from leaking back into the query set it was
            derived from.  Any other kwargs are forwarded to Query.clone().
            """
            custom_q = ExtendedManager._CustomSqlQ()
            for source in (self._customSqlQ,
                           kwargs.get('_customSqlQ', None)):
                if source is None:
                    continue
                custom_q._joins.update(source._joins)
                custom_q._where.extend(source._where)
                custom_q._params.extend(source._params)
            kwargs['_customSqlQ'] = custom_q
            return super(ExtendedManager._CustomQuery, self).clone(
                    klass, **kwargs)


        def get_from_clause(self):
            """
            Return Django's FROM clause pieces plus, if there are custom
            joins, one extra piece of the form
            ' <join_type> <table> AS <alias> ON (<condition>)' per join.
            """
            from_, params = super(
                    ExtendedManager._CustomQuery, self).get_from_clause()

            join_clause = ''
            for join_alias, join in self._customSqlQ._joins.iteritems():
                join_table, join_type, condition = join
                join_clause += ' %s %s AS %s ON (%s)' % (
                        join_type, join_table, join_alias, condition)

            if join_clause:
                from_.append(join_clause)

            return from_, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Q-like object holding raw SQL joins, WHERE fragments and their
        parameters.

        Joins are rendered by _CustomQuery.get_from_clause(); WHERE fragments
        are added to a query by add_to_query().
        """
        def __init__(self):
            # NOTE(review): Q.__init__ is not called here.
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """Record a join_type join of table (as alias) ON condition.

            alias defaults to the table name.
            """
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=()):
            """Record a WHERE fragment and the parameters it uses."""
            self._where.append(where)
            self._params.extend(params)


        def add_to_query(self, query, aliases):
            """Add the recorded WHERE fragments, ANDed, to query as extra SQL."""
            if self._where:
                where = ' AND '.join(self._where)
                query.add_extra(None, None, (where,), self._params, None, None)


    def _add_customSqlQ(self, query_set, filter_object):
        """\
        Add a _CustomSqlQ to the query set.

        @returns a copy of query_set whose query is a _CustomQuery carrying
                filter_object, filtered by filter_object.
        """
        # Make a copy of the query set
        query_set = query_set.all()

        query_set.query = query_set.query.clone(
                ExtendedManager._CustomQuery, _customSqlQ=filter_object)
        return query_set.filter(filter_object)


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field on join_table referencing back to this model's
                primary key, used for the join
        @param join_condition extra condition for the ON clause of the join
        @param suffix suffix to add to join_table for the join alias
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new query set with the join applied
        """
        join_from_table = self.model._meta.db_table
        # This is raw SQL, so use the pk's column name, not its attribute
        # name (the same choice get_key_on_this_table() makes).
        join_from_key = self.model._meta.pk.column
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # Keep only rows for which the outer join found no match.
            filter_object.add_where(full_join_key + ' IS NULL')

        return self._add_customSqlQ(query_set, filter_object)


    def _get_quoted_field(self, table, field):
        """Return 'table.field' with both parts quoted for the backend."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted 'table.column' for key_field on this model's table.

        key_field defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' in sql so it survives %-parameter substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run a raw SELECT of the given expressions over query_set's SQL from
        ' FROM' onwards, on the read-only connection.

        @param query_set: query set supplying the FROM/WHERE/... part.
        @param selects: list of SQL select expressions.
        @returns all rows, as returned by cursor.fetchall().
        """
        sql, params = query_set.query.as_sql()
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Return a true value if field is a relation to model_class."""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
                their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
                related_object) for each relationship between a base object
                and a related object. all base_object instances come from
                base_objects_by_id.
        @raises ValueError if related_model has no relation to this model.
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model,
                        field.m2m_db_table(), field.m2m_reverse_name(),
                        field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model,
                        field.m2m_db_table(), field.m2m_column_name(),
                        field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object. we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; only
                the keys (IDs) are used, to restrict pivot_from_field
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id.iterkeys()))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type
        related_model.  related_model must be in a many-to-one or
        many-to-many relationship with this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the
                related object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000358
359
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose default query set contains only objects with
    invalid=False.
    """
    def get_query_set(self):
        everything = super(ValidObjectsManager, self).get_query_set()
        valid_only = everything.filter(invalid=False)
        return valid_only
showard7c785282008-05-29 19:45:12 +0000367
368
369class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000370 """\
371 Mixin with convenience functions for models, built on top of the
372 default Django model functions.
373 """
374 # TODO: at least some of these functions really belong in a custom
375 # Manager class
showard7c785282008-05-29 19:45:12 +0000376
jadmanski0afbb632008-06-06 21:10:57 +0000377 field_dict = None
378 # subclasses should override if they want to support smart_get() by name
379 name_field = None
showard7c785282008-05-29 19:45:12 +0000380
381
jadmanski0afbb632008-06-06 21:10:57 +0000382 @classmethod
383 def get_field_dict(cls):
384 if cls.field_dict is None:
385 cls.field_dict = {}
386 for field in cls._meta.fields:
387 cls.field_dict[field.name] = field
388 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000389
390
jadmanski0afbb632008-06-06 21:10:57 +0000391 @classmethod
392 def clean_foreign_keys(cls, data):
393 """\
394 -Convert foreign key fields in data from <field>_id to just
395 <field>.
396 -replace foreign key objects with their IDs
397 This method modifies data in-place.
398 """
399 for field in cls._meta.fields:
400 if not field.rel:
401 continue
402 if (field.attname != field.name and
403 field.attname in data):
404 data[field.name] = data[field.attname]
405 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000406 if field.name not in data:
407 continue
jadmanski0afbb632008-06-06 21:10:57 +0000408 value = data[field.name]
409 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000410 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000411
412
showard21baa452008-10-21 00:08:39 +0000413 @classmethod
414 def _convert_booleans(cls, data):
415 """
416 Ensure BooleanFields actually get bool values. The Django MySQL
417 backend returns ints for BooleanFields, which is almost always not
418 a problem, but it can be annoying in certain situations.
419 """
420 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000421 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000422 data[field.name] = bool(data[field.name])
423
424
jadmanski0afbb632008-06-06 21:10:57 +0000425 # TODO(showard) - is there a way to not have to do this?
426 @classmethod
427 def provide_default_values(cls, data):
428 """\
429 Provide default values for fields with default values which have
430 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000431
jadmanski0afbb632008-06-06 21:10:57 +0000432 For CharField and TextField fields with "blank=True", if nothing
433 is passed, we fill in an empty string value, even if there's no
434 default set.
435 """
436 new_data = dict(data)
437 field_dict = cls.get_field_dict()
438 for name, obj in field_dict.iteritems():
439 if data.get(name) is not None:
440 continue
441 if obj.default is not dbmodels.fields.NOT_PROVIDED:
442 new_data[name] = obj.default
443 elif (isinstance(obj, dbmodels.CharField) or
444 isinstance(obj, dbmodels.TextField)):
445 new_data[name] = ''
446 return new_data
showard7c785282008-05-29 19:45:12 +0000447
448
jadmanski0afbb632008-06-06 21:10:57 +0000449 @classmethod
450 def convert_human_readable_values(cls, data, to_human_readable=False):
451 """\
452 Performs conversions on user-supplied field data, to make it
453 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000454
jadmanski0afbb632008-06-06 21:10:57 +0000455 For all fields that have choice sets, convert their values
456 from human-readable strings to enum values, if necessary. This
457 allows users to pass strings instead of the corresponding
458 integer values.
showard7c785282008-05-29 19:45:12 +0000459
jadmanski0afbb632008-06-06 21:10:57 +0000460 For all foreign key fields, call smart_get with the supplied
461 data. This allows the user to pass either an ID value or
462 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000463
jadmanski0afbb632008-06-06 21:10:57 +0000464 If to_human_readable=True, perform the inverse - i.e. convert
465 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000466
jadmanski0afbb632008-06-06 21:10:57 +0000467 This method modifies data in-place.
468 """
469 field_dict = cls.get_field_dict()
470 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000471 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000472 continue
473 field_obj = field_dict[field_name]
474 # convert enum values
475 if field_obj.choices:
476 for choice_data in field_obj.choices:
477 # choice_data is (value, name)
478 if to_human_readable:
479 from_val, to_val = choice_data
480 else:
481 to_val, from_val = choice_data
482 if from_val == data[field_name]:
483 data[field_name] = to_val
484 break
485 # convert foreign key values
486 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000487 dest_obj = field_obj.rel.to.smart_get(data[field_name],
488 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000489 if to_human_readable:
490 if dest_obj.name_field is not None:
491 data[field_name] = getattr(dest_obj,
492 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000493 else:
showardb0a73032009-03-27 18:35:41 +0000494 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000495
496
jadmanski0afbb632008-06-06 21:10:57 +0000497 @classmethod
498 def validate_field_names(cls, data):
499 'Checks for extraneous fields in data.'
500 errors = {}
501 field_dict = cls.get_field_dict()
502 for field_name in data:
503 if field_name not in field_dict:
504 errors[field_name] = 'No field of this name'
505 return errors
showard7c785282008-05-29 19:45:12 +0000506
507
jadmanski0afbb632008-06-06 21:10:57 +0000508 @classmethod
509 def prepare_data_args(cls, data, kwargs):
510 'Common preparation for add_object and update_object'
511 data = dict(data) # don't modify the default keyword arg
512 data.update(kwargs)
513 # must check for extraneous field names here, while we have the
514 # data in a dict
515 errors = cls.validate_field_names(data)
516 if errors:
517 raise ValidationError(errors)
518 cls.convert_human_readable_values(data)
519 return data
showard7c785282008-05-29 19:45:12 +0000520
521
jadmanski0afbb632008-06-06 21:10:57 +0000522 def validate_unique(self):
523 """\
524 Validate that unique fields are unique. Django manipulators do
525 this too, but they're a huge pain to use manually. Trust me.
526 """
527 errors = {}
528 cls = type(self)
529 field_dict = self.get_field_dict()
530 manager = cls.get_valid_manager()
531 for field_name, field_obj in field_dict.iteritems():
532 if not field_obj.unique:
533 continue
showard7c785282008-05-29 19:45:12 +0000534
jadmanski0afbb632008-06-06 21:10:57 +0000535 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000536 if value is None and field_obj.auto_created:
537 # don't bother checking autoincrement fields about to be
538 # generated
539 continue
540
jadmanski0afbb632008-06-06 21:10:57 +0000541 existing_objs = manager.filter(**{field_name : value})
542 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000543
jadmanski0afbb632008-06-06 21:10:57 +0000544 if num_existing == 0:
545 continue
546 if num_existing == 1 and existing_objs[0].id == self.id:
547 continue
548 errors[field_name] = (
549 'This value must be unique (%s)' % (value))
550 return errors
showard7c785282008-05-29 19:45:12 +0000551
552
showarda5288b42009-07-28 20:06:08 +0000553 def _validate(self):
554 """
555 First coerces all fields on this instance to their proper Python types.
556 Then runs validation on every field. Returns a dictionary of
557 field_name -> error_list.
558
559 Based on validate() from django.db.models.Model in Django 0.96, which
560 was removed in Django 1.0. It should reappear in a later version. See:
561 http://code.djangoproject.com/ticket/6845
562 """
563 error_dict = {}
564 for f in self._meta.fields:
565 try:
566 python_value = f.to_python(
567 getattr(self, f.attname, f.get_default()))
568 except django.core.exceptions.ValidationError, e:
569 error_dict[f.name] = str(e.message)
570 continue
571
572 if not f.blank and not python_value:
573 error_dict[f.name] = 'This field is required.'
574 continue
575
576 setattr(self, f.attname, python_value)
577
578 return error_dict
579
580
jadmanski0afbb632008-06-06 21:10:57 +0000581 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000582 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000583 unique_errors = self.validate_unique()
584 for field_name, error in unique_errors.iteritems():
585 errors.setdefault(field_name, error)
586 if errors:
587 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000588
589
jadmanski0afbb632008-06-06 21:10:57 +0000590 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000591
jadmanski0afbb632008-06-06 21:10:57 +0000592 @classmethod
593 def add_object(cls, data={}, **kwargs):
594 """\
595 Returns a new object created with the given data (a dictionary
596 mapping field names to values). Merges any extra keyword args
597 into data.
598 """
599 data = cls.prepare_data_args(data, kwargs)
600 data = cls.provide_default_values(data)
601 obj = cls(**data)
602 obj.do_validate()
603 obj.save()
604 return obj
showard7c785282008-05-29 19:45:12 +0000605
606
jadmanski0afbb632008-06-06 21:10:57 +0000607 def update_object(self, data={}, **kwargs):
608 """\
609 Updates the object with the given data (a dictionary mapping
610 field names to values). Merges any extra keyword args into
611 data.
612 """
613 data = self.prepare_data_args(data, kwargs)
614 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000615 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000616 self.do_validate()
617 self.save()
showard7c785282008-05-29 19:45:12 +0000618
619
showard8bfb5cb2009-10-07 20:49:15 +0000620 # see query_objects()
621 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
622 'extra_args', 'extra_where', 'no_distinct')
623
624
jadmanski0afbb632008-06-06 21:10:57 +0000625 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000626 def _extract_special_params(cls, filter_data):
627 """
628 @returns a tuple of dicts (special_params, regular_filters), where
629 special_params contains the parameters we handle specially and
630 regular_filters is the remaining data to be handled by Django.
631 """
632 regular_filters = dict(filter_data)
633 special_params = {}
634 for key in cls._SPECIAL_FILTER_KEYS:
635 if key in regular_filters:
636 special_params[key] = regular_filters.pop(key)
637 return special_params, regular_filters
638
639
640 @classmethod
641 def apply_presentation(cls, query, filter_data):
642 """
643 Apply presentation parameters -- sorting and paging -- to the given
644 query.
645 @returns new query with presentation applied
646 """
647 special_params, _ = cls._extract_special_params(filter_data)
648 sort_by = special_params.get('sort_by', None)
649 if sort_by:
650 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
651 query = query.order_by(*sort_by)
652
653 query_start = special_params.get('query_start', None)
654 query_limit = special_params.get('query_limit', None)
655 if query_start is not None:
656 if query_limit is None:
657 raise ValueError('Cannot pass query_start without query_limit')
658 # query_limit is passed as a page size
showard7074b742009-10-12 20:30:04 +0000659 query_limit += query_start
660 return query[query_start:query_limit]
showard8bfb5cb2009-10-07 20:49:15 +0000661
662
663 @classmethod
664 def query_objects(cls, filter_data, valid_only=True, initial_query=None,
665 apply_presentation=True):
jadmanski0afbb632008-06-06 21:10:57 +0000666 """\
667 Returns a QuerySet object for querying the given model_class
668 with the given filter_data. Optional special arguments in
669 filter_data include:
670 -query_start: index of first return to return
671 -query_limit: maximum number of results to return
672 -sort_by: list of fields to sort on. prefixing a '-' onto a
673 field name changes the sort to descending order.
674 -extra_args: keyword args to pass to query.extra() (see Django
675 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000676 -extra_where: extra WHERE clause to append
showard8bfb5cb2009-10-07 20:49:15 +0000677 -no_distinct: if True, a DISTINCT will not be added to the SELECT
jadmanski0afbb632008-06-06 21:10:57 +0000678 """
showard8bfb5cb2009-10-07 20:49:15 +0000679 special_params, regular_filters = cls._extract_special_params(
680 filter_data)
showard7c785282008-05-29 19:45:12 +0000681
showard7ac7b7a2008-07-21 20:24:29 +0000682 if initial_query is None:
683 if valid_only:
684 initial_query = cls.get_valid_manager()
685 else:
686 initial_query = cls.objects
showard8bfb5cb2009-10-07 20:49:15 +0000687
688 query = initial_query.filter(**regular_filters)
689
690 use_distinct = not special_params.get('no_distinct', False)
showard7ac7b7a2008-07-21 20:24:29 +0000691 if use_distinct:
692 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000693
showard8bfb5cb2009-10-07 20:49:15 +0000694 extra_args = special_params.get('extra_args', {})
695 extra_where = special_params.get('extra_where', None)
696 if extra_where:
697 # escape %'s
698 extra_where = cls.objects.escape_user_sql(extra_where)
699 extra_args.setdefault('where', []).append(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000700 if extra_args:
701 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000702 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000703
showard8bfb5cb2009-10-07 20:49:15 +0000704 if apply_presentation:
705 query = cls.apply_presentation(query, filter_data)
706
707 return query
showard7c785282008-05-29 19:45:12 +0000708
709
jadmanski0afbb632008-06-06 21:10:57 +0000710 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000711 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000712 """\
713 Like query_objects, but retreive only the count of results.
714 """
715 filter_data.pop('query_start', None)
716 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000717 query = cls.query_objects(filter_data, initial_query=initial_query)
718 return query.count()
showard7c785282008-05-29 19:45:12 +0000719
720
jadmanski0afbb632008-06-06 21:10:57 +0000721 @classmethod
722 def clean_object_dicts(cls, field_dicts):
723 """\
724 Take a list of dicts corresponding to object (as returned by
725 query.values()) and clean the data to be more suitable for
726 returning to the user.
727 """
showarde732ee72008-09-23 19:15:43 +0000728 for field_dict in field_dicts:
729 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000730 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000731 cls.convert_human_readable_values(field_dict,
732 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000733
734
jadmanski0afbb632008-06-06 21:10:57 +0000735 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000736 def list_objects(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000737 """\
738 Like query_objects, but return a list of dictionaries.
739 """
showard7ac7b7a2008-07-21 20:24:29 +0000740 query = cls.query_objects(filter_data, initial_query=initial_query)
showard8bfb5cb2009-10-07 20:49:15 +0000741 extra_fields = query.query.extra_select.keys()
742 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
showarde732ee72008-09-23 19:15:43 +0000743 for model_object in query]
jadmanski0afbb632008-06-06 21:10:57 +0000744 return field_dicts
showard7c785282008-05-29 19:45:12 +0000745
746
jadmanski0afbb632008-06-06 21:10:57 +0000747 @classmethod
showarda4ea5742009-02-17 20:56:23 +0000748 def smart_get(cls, id_or_name, valid_only=True):
jadmanski0afbb632008-06-06 21:10:57 +0000749 """\
750 smart_get(integer) -> get object by ID
751 smart_get(string) -> get object by name_field
jadmanski0afbb632008-06-06 21:10:57 +0000752 """
showarda4ea5742009-02-17 20:56:23 +0000753 if valid_only:
754 manager = cls.get_valid_manager()
755 else:
756 manager = cls.objects
757
758 if isinstance(id_or_name, (int, long)):
759 return manager.get(pk=id_or_name)
760 if isinstance(id_or_name, basestring):
761 return manager.get(**{cls.name_field : id_or_name})
762 raise ValueError(
763 'Invalid positional argument: %s (%s)' % (id_or_name,
764 type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000765
766
showardbe3ec042008-11-12 18:16:07 +0000767 @classmethod
768 def smart_get_bulk(cls, id_or_name_list):
769 invalid_inputs = []
770 result_objects = []
771 for id_or_name in id_or_name_list:
772 try:
773 result_objects.append(cls.smart_get(id_or_name))
774 except cls.DoesNotExist:
775 invalid_inputs.append(id_or_name)
776 if invalid_inputs:
mbligh7a3ebe32008-12-01 17:10:33 +0000777 raise cls.DoesNotExist('The following %ss do not exist: %s'
778 % (cls.__name__.lower(),
779 ', '.join(invalid_inputs)))
showardbe3ec042008-11-12 18:16:07 +0000780 return result_objects
781
782
showard8bfb5cb2009-10-07 20:49:15 +0000783 def get_object_dict(self, extra_fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000784 """\
showard8bfb5cb2009-10-07 20:49:15 +0000785 Return a dictionary mapping fields to this object's values. @param
786 extra_fields: list of extra attribute names to include, in addition to
787 the fields defined on this object.
jadmanski0afbb632008-06-06 21:10:57 +0000788 """
showard8bfb5cb2009-10-07 20:49:15 +0000789 fields = self.get_field_dict().keys()
790 if extra_fields:
791 fields += extra_fields
jadmanski0afbb632008-06-06 21:10:57 +0000792 object_dict = dict((field_name, getattr(self, field_name))
showarde732ee72008-09-23 19:15:43 +0000793 for field_name in fields)
jadmanski0afbb632008-06-06 21:10:57 +0000794 self.clean_object_dicts([object_dict])
showardd3dc1992009-04-22 21:01:40 +0000795 self._postprocess_object_dict(object_dict)
jadmanski0afbb632008-06-06 21:10:57 +0000796 return object_dict
showard7c785282008-05-29 19:45:12 +0000797
798
showardd3dc1992009-04-22 21:01:40 +0000799 def _postprocess_object_dict(self, object_dict):
800 """For subclasses to override."""
801 pass
802
803
jadmanski0afbb632008-06-06 21:10:57 +0000804 @classmethod
805 def get_valid_manager(cls):
806 return cls.objects
showard7c785282008-05-29 19:45:12 +0000807
808
showard2bab8f42008-11-12 18:15:22 +0000809 def _record_attributes(self, attributes):
810 """
811 See on_attribute_changed.
812 """
813 assert not isinstance(attributes, basestring)
814 self._recorded_attributes = dict((attribute, getattr(self, attribute))
815 for attribute in attributes)
816
817
818 def _check_for_updated_attributes(self):
819 """
820 See on_attribute_changed.
821 """
822 for attribute, original_value in self._recorded_attributes.iteritems():
823 new_value = getattr(self, attribute)
824 if original_value != new_value:
825 self.on_attribute_changed(attribute, original_value)
826 self._record_attributes(self._recorded_attributes.keys())
827
828
829 def on_attribute_changed(self, attribute, old_value):
830 """
831 Called whenever an attribute is updated. To be overridden.
832
833 To use this method, you must:
834 * call _record_attributes() from __init__() (after making the super
835 call) with a list of attributes for which you want to be notified upon
836 change.
837 * call _check_for_updated_attributes() from save().
838 """
839 pass
840
841
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion. Subclasses must have a boolean "invalid"
    field and a "valid_objects" manager (used by get_valid_manager() and by
    the Manipulator below).
    """

    def save(self, *args, **kwargs):
        """
        Save the object, first reviving a matching invalidated row if any.

        If the object has never been saved (self.id is None), look for an
        existing row with the same name that is marked invalid; if there is
        one, resurrect_object() copies its id onto self before saving.
        """
        if self.id is None:
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                # resurrect_object() is deliberately inside the try, so a
                # DoesNotExist raised from an override is tolerated too.
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                pass  # nothing was previously invalidated under this name

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object. Can be overridden
        by subclasses to copy data as desired from the deleted entry (but
        this superclass implementation must normally be called, since it is
        what makes self take over the old row's id).
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Mark the object invalid and save it instead of deleting the row,
        then let clean_object() tear down its relationships.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that only sees objects not marked invalid."""
        valid_manager = cls.valid_objects
        return valid_manager


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.

        NOTE(review): object has no _prepare(), so this presumably relies on
        being mixed into an old-style Django manipulator that provides it;
        confirm it is still exercised.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000906
907
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError if a subclass has not overridden it.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # "raise NotImplemented" produced an unrelated TypeError.
        raise NotImplementedError(
                '%s must override _get_attribute_model_and_args()'
                % self.__class__.__name__)


    def set_attribute(self, attribute, value):
        """
        Create the attribute row if needed and store value in it.

        @param attribute: attribute name, passed to
                _get_attribute_model_and_args().
        @param value: value assigned to the row's "value" field.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the attribute row; deleting a missing attribute is a no-op.

        @param attribute: attribute name, passed to
                _get_attribute_model_and_args().
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute if value is None, otherwise set it."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)
showard26b7ec72009-12-21 22:43:57 +0000946
947
class ModelWithHashManager(dbmodels.Manager):
    """
    Manager for use with the ModelWithHash abstract model class.

    Objects must be obtained through get_or_create(), which fills in
    the_hash from the other lookup arguments; plain create() is refused.
    """

    def create(self, **kwargs):
        """Always raises; use get_or_create() instead."""
        message = ('ModelWithHash manager should use get_or_create() '
                   'instead of create()')
        raise Exception(message)


    def get_or_create(self, **kwargs):
        """
        Add the_hash (computed by the model's _compute_hash() from the other
        kwargs) to the lookup and delegate to Manager.get_or_create().
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        parent = super(ModelWithHashManager, self)
        return parent.get_or_create(**kwargs)
959
960
class ModelWithHash(dbmodels.Model):
    """
    Abstract superclass for immutable models identified by a unique hash
    column.

    Instances should come from objects.get_or_create(); the manager fills in
    the_hash using _compute_hash(), which subclasses must implement.
    """

    # Set by ModelWithHashManager.get_or_create() from _compute_hash().
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """
        Return the hash string for an object with the given field values.
        Must be overridden by subclasses.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases.

        We want these models to be immutable, so the generic save() operation
        will not work. These models should be instantiated through the
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) is allowed, since that
        creates a new row. However, the preferred way to make instances of
        these models is through the get_or_create() method.
        """
        if force_insert:
            # A forced insert creates a new row; if it is a duplicate, the
            # unique constraint on the_hash will reject it.
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')