blob: c683699f6b64b91c5dfde6caf574a369e465a1a6 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
showard8b0ea222009-12-23 19:23:03 +000012_quote_name = connection.ops.quote_name
showarda5288b42009-07-28 20:06:08 +000013
class ValidationError(Exception):
    """\
    Raised when adding or updating an object fails data validation.

    The value associated with the exception is a dictionary mapping field
    names to error strings.
    """
showard7c785282008-05-29 19:45:12 +000019
20
def _wrap_with_readonly(method):
    """
    Return a wrapper that brackets a call to method with
    set_django_connection() / unset_django_connection() on the read-only
    connection.

    unset_django_connection() is always called, even when method raises.
    The wrapper keeps method's __name__.
    """
    def readonly_call(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_call.__name__ = method.__name__
    return readonly_call
showard09096d82008-07-07 23:20:49 +000030
31
def _quote_name(name):
    """Return name quoted by connection.ops.quote_name()."""
    quote = connection.ops.quote_name
    return quote(name)
35
36
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first item is fetched while the
    read-only connection is installed.

    We have to wrap generators specially.  Assume the generator performs
    the query on the first item it is asked for; only that first fetch is
    bracketed by set_django_connection() / unset_django_connection().

    @param generator generator function to wrap
    @returns generator function with the same __name__ yielding the same
            items as generator
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = generator(*args, **kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            # Pull exactly one item.  A for/else is used instead of calling
            # .next() so that an empty generator ends this one with a plain
            # return rather than a StopIteration escaping the generator body
            # (a RuntimeError under PEP 479).
            for first_value in generator_obj:
                break
            else:
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        # The remaining items are produced outside the read-only bracket.
        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
56
57
def _make_queryset_readonly(queryset):
    """
    Replace, on this queryset instance, every method that performs a
    database query with a version wrapped to use the read-only connection.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator, so it needs the generator-aware wrapper
    readonly_iterator = _wrap_generator_with_readonly(queryset.iterator)
    queryset.iterator = readonly_iterator
70
71
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods are wrapped by
    _make_queryset_readonly() so they run on the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but clones to a ReadonlyValuesQuerySet."""
        clone_options = dict(klass=ReadonlyValuesQuerySet, setup=True,
                             _fields=fields)
        return self._clone(**clone_options)
showard09096d82008-07-07 23:20:49 +000085
86
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose database-querying methods are wrapped by
    _make_queryset_readonly() so they run on the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        parent = super(ReadonlyValuesQuerySet, self)
        parent.__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
91
92
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting custom SQL joins, custom SELECT queries and
    bulk population of related-object lists.
    """

    class _CustomQuery(query.Query):
        """Query that appends the joins of its _customSqlQ to the FROM."""

        def clone(self, klass=None, **kwargs):
            """
            Clone this query, passing self._customSqlQ on to the clone.

            If a '_customSqlQ' keyword argument is given, its joins, where
            clauses and params are merged into the clone's _customSqlQ.
            """
            # NOTE(review): the clone appears to receive the very same
            # _customSqlQ object as self, so the merge below would also
            # alter self -- confirm that sharing is intended.  Keyword
            # arguments other than '_customSqlQ' are not forwarded.
            obj = super(ExtendedManager._CustomQuery, self).clone(
                    klass, _customSqlQ=self._customSqlQ)

            customQ = kwargs.get('_customSqlQ', None)
            if customQ is not None:
                obj._customSqlQ._joins.update(customQ._joins)
                obj._customSqlQ._where.extend(customQ._where)
                obj._customSqlQ._params.extend(customQ._params)

            return obj


        def get_from_clause(self):
            """
            @returns (from_, params) from the parent class, with one extra
                    from_ entry holding all custom joins, if there are any.
            """
            from_, params = super(
                    ExtendedManager._CustomQuery, self).get_from_clause()

            custom_joins = [
                    ' %s %s AS %s ON (%s)' % (
                            join_type, _quote_name(join_table),
                            _quote_name(join_alias), condition)
                    for join_alias, (join_table, join_type, condition)
                    in self._customSqlQ._joins.items()]

            if custom_joins:
                from_.append(''.join(custom_joins))

            return from_, params


    class _CustomSqlQ(dbmodels.Q):
        """Filter object carrying custom joins and WHERE clauses."""

        def __init__(self):
            # alias -> (table, join_type, condition), in insertion order
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """Record a join to table; alias defaults to the table name."""
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=()):
            """Record a WHERE clause together with its query params."""
            self._where.append(where)
            self._params.extend(params)


        def add_to_query(self, query, aliases):
            """Add the recorded WHERE clauses, ANDed together, to query."""
            if self._where:
                where = ' AND '.join(self._where)
                query.add_extra(None, None, (where,), self._params, None, None)


    def _add_customSqlQ(self, query_set, filter_object):
        """\
        Add a _CustomSqlQ to the query set.
        @returns a new query set; the one passed in is not modified here
        """
        # Make a copy of the query set
        query_set = query_set.all()

        query_set.query = query_set.query.clone(
                ExtendedManager._CustomQuery, _customSqlQ=filter_object)
        return query_set.filter(filter_object)


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 alias=None, suffix='', exclude=False, force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns the query set with the join added
        """
        join_from_table = _quote_name(self.model._meta.db_table)
        # NOTE(review): this uses pk.name while get_key_on_this_table() uses
        # pk.column -- confirm the two agree for every model using this.
        join_from_key = _quote_name(self.model._meta.pk.name)
        if alias:
            join_alias = alias
        else:
            join_alias = join_table + suffix
        full_join_key = _quote_name(join_alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # with the outer join, unmatched rows have a NULL join key
            filter_object.add_where(full_join_key + ' IS NULL')

        return self._add_customSqlQ(query_set, filter_object)


    def _get_quoted_field(self, table, field):
        """@returns the quoted, table-qualified name '<table>.<field>'"""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        @param key_field column name on this model's table; defaults to the
                primary key column
        @returns the quoted, table-qualified column name
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' in sql so it is not taken as a placeholder."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run query_set's query with its SELECT list replaced by selects,
        using a cursor from the read-only connection.
        @param selects list of SQL select expressions
        @returns all result rows, as returned by cursor.fetchall()
        """
        sql, params = query_set.query.as_sql()
        # keep everything from the first ' FROM' onwards
        # NOTE(review): assumes ' FROM' occurs in the generated SQL and not
        # earlier inside the select list -- confirm.
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """@returns a true value if field is a relation to model_class"""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        @raises ValueError if no relation between the two models is found
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @param foreign_key_field field on related_model referencing this model
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       list(base_objects_by_id)}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict whose keys are the IDs of the
                self.model objects to include
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_ in base_objects_by_id))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        # fetch each related object once, however often it is referenced
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
        object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        # every base object gets a list, even if nothing relates to it
        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000364
365
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose query set is filtered to objects with invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        return all_objects.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000373
374
375class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000376 """\
377 Mixin with convenience functions for models, built on top of the
378 default Django model functions.
379 """
380 # TODO: at least some of these functions really belong in a custom
381 # Manager class
showard7c785282008-05-29 19:45:12 +0000382
jadmanski0afbb632008-06-06 21:10:57 +0000383 field_dict = None
384 # subclasses should override if they want to support smart_get() by name
385 name_field = None
showard7c785282008-05-29 19:45:12 +0000386
387
jadmanski0afbb632008-06-06 21:10:57 +0000388 @classmethod
389 def get_field_dict(cls):
390 if cls.field_dict is None:
391 cls.field_dict = {}
392 for field in cls._meta.fields:
393 cls.field_dict[field.name] = field
394 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000395
396
jadmanski0afbb632008-06-06 21:10:57 +0000397 @classmethod
398 def clean_foreign_keys(cls, data):
399 """\
400 -Convert foreign key fields in data from <field>_id to just
401 <field>.
402 -replace foreign key objects with their IDs
403 This method modifies data in-place.
404 """
405 for field in cls._meta.fields:
406 if not field.rel:
407 continue
408 if (field.attname != field.name and
409 field.attname in data):
410 data[field.name] = data[field.attname]
411 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000412 if field.name not in data:
413 continue
jadmanski0afbb632008-06-06 21:10:57 +0000414 value = data[field.name]
415 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000416 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000417
418
showard21baa452008-10-21 00:08:39 +0000419 @classmethod
420 def _convert_booleans(cls, data):
421 """
422 Ensure BooleanFields actually get bool values. The Django MySQL
423 backend returns ints for BooleanFields, which is almost always not
424 a problem, but it can be annoying in certain situations.
425 """
426 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000427 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000428 data[field.name] = bool(data[field.name])
429
430
jadmanski0afbb632008-06-06 21:10:57 +0000431 # TODO(showard) - is there a way to not have to do this?
432 @classmethod
433 def provide_default_values(cls, data):
434 """\
435 Provide default values for fields with default values which have
436 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000437
jadmanski0afbb632008-06-06 21:10:57 +0000438 For CharField and TextField fields with "blank=True", if nothing
439 is passed, we fill in an empty string value, even if there's no
440 default set.
441 """
442 new_data = dict(data)
443 field_dict = cls.get_field_dict()
444 for name, obj in field_dict.iteritems():
445 if data.get(name) is not None:
446 continue
447 if obj.default is not dbmodels.fields.NOT_PROVIDED:
448 new_data[name] = obj.default
449 elif (isinstance(obj, dbmodels.CharField) or
450 isinstance(obj, dbmodels.TextField)):
451 new_data[name] = ''
452 return new_data
showard7c785282008-05-29 19:45:12 +0000453
454
jadmanski0afbb632008-06-06 21:10:57 +0000455 @classmethod
456 def convert_human_readable_values(cls, data, to_human_readable=False):
457 """\
458 Performs conversions on user-supplied field data, to make it
459 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000460
jadmanski0afbb632008-06-06 21:10:57 +0000461 For all fields that have choice sets, convert their values
462 from human-readable strings to enum values, if necessary. This
463 allows users to pass strings instead of the corresponding
464 integer values.
showard7c785282008-05-29 19:45:12 +0000465
jadmanski0afbb632008-06-06 21:10:57 +0000466 For all foreign key fields, call smart_get with the supplied
467 data. This allows the user to pass either an ID value or
468 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000469
jadmanski0afbb632008-06-06 21:10:57 +0000470 If to_human_readable=True, perform the inverse - i.e. convert
471 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000472
jadmanski0afbb632008-06-06 21:10:57 +0000473 This method modifies data in-place.
474 """
475 field_dict = cls.get_field_dict()
476 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000477 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000478 continue
479 field_obj = field_dict[field_name]
480 # convert enum values
481 if field_obj.choices:
482 for choice_data in field_obj.choices:
483 # choice_data is (value, name)
484 if to_human_readable:
485 from_val, to_val = choice_data
486 else:
487 to_val, from_val = choice_data
488 if from_val == data[field_name]:
489 data[field_name] = to_val
490 break
491 # convert foreign key values
492 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000493 dest_obj = field_obj.rel.to.smart_get(data[field_name],
494 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000495 if to_human_readable:
496 if dest_obj.name_field is not None:
497 data[field_name] = getattr(dest_obj,
498 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000499 else:
showardb0a73032009-03-27 18:35:41 +0000500 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000501
502
jadmanski0afbb632008-06-06 21:10:57 +0000503 @classmethod
504 def validate_field_names(cls, data):
505 'Checks for extraneous fields in data.'
506 errors = {}
507 field_dict = cls.get_field_dict()
508 for field_name in data:
509 if field_name not in field_dict:
510 errors[field_name] = 'No field of this name'
511 return errors
showard7c785282008-05-29 19:45:12 +0000512
513
jadmanski0afbb632008-06-06 21:10:57 +0000514 @classmethod
515 def prepare_data_args(cls, data, kwargs):
516 'Common preparation for add_object and update_object'
517 data = dict(data) # don't modify the default keyword arg
518 data.update(kwargs)
519 # must check for extraneous field names here, while we have the
520 # data in a dict
521 errors = cls.validate_field_names(data)
522 if errors:
523 raise ValidationError(errors)
524 cls.convert_human_readable_values(data)
525 return data
showard7c785282008-05-29 19:45:12 +0000526
527
jadmanski0afbb632008-06-06 21:10:57 +0000528 def validate_unique(self):
529 """\
530 Validate that unique fields are unique. Django manipulators do
531 this too, but they're a huge pain to use manually. Trust me.
532 """
533 errors = {}
534 cls = type(self)
535 field_dict = self.get_field_dict()
536 manager = cls.get_valid_manager()
537 for field_name, field_obj in field_dict.iteritems():
538 if not field_obj.unique:
539 continue
showard7c785282008-05-29 19:45:12 +0000540
jadmanski0afbb632008-06-06 21:10:57 +0000541 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000542 if value is None and field_obj.auto_created:
543 # don't bother checking autoincrement fields about to be
544 # generated
545 continue
546
jadmanski0afbb632008-06-06 21:10:57 +0000547 existing_objs = manager.filter(**{field_name : value})
548 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000549
jadmanski0afbb632008-06-06 21:10:57 +0000550 if num_existing == 0:
551 continue
552 if num_existing == 1 and existing_objs[0].id == self.id:
553 continue
554 errors[field_name] = (
555 'This value must be unique (%s)' % (value))
556 return errors
showard7c785282008-05-29 19:45:12 +0000557
558
showarda5288b42009-07-28 20:06:08 +0000559 def _validate(self):
560 """
561 First coerces all fields on this instance to their proper Python types.
562 Then runs validation on every field. Returns a dictionary of
563 field_name -> error_list.
564
565 Based on validate() from django.db.models.Model in Django 0.96, which
566 was removed in Django 1.0. It should reappear in a later version. See:
567 http://code.djangoproject.com/ticket/6845
568 """
569 error_dict = {}
570 for f in self._meta.fields:
571 try:
572 python_value = f.to_python(
573 getattr(self, f.attname, f.get_default()))
574 except django.core.exceptions.ValidationError, e:
575 error_dict[f.name] = str(e.message)
576 continue
577
578 if not f.blank and not python_value:
579 error_dict[f.name] = 'This field is required.'
580 continue
581
582 setattr(self, f.attname, python_value)
583
584 return error_dict
585
586
jadmanski0afbb632008-06-06 21:10:57 +0000587 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000588 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000589 unique_errors = self.validate_unique()
590 for field_name, error in unique_errors.iteritems():
591 errors.setdefault(field_name, error)
592 if errors:
593 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000594
595
jadmanski0afbb632008-06-06 21:10:57 +0000596 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000597
jadmanski0afbb632008-06-06 21:10:57 +0000598 @classmethod
599 def add_object(cls, data={}, **kwargs):
600 """\
601 Returns a new object created with the given data (a dictionary
602 mapping field names to values). Merges any extra keyword args
603 into data.
604 """
605 data = cls.prepare_data_args(data, kwargs)
606 data = cls.provide_default_values(data)
607 obj = cls(**data)
608 obj.do_validate()
609 obj.save()
610 return obj
showard7c785282008-05-29 19:45:12 +0000611
612
jadmanski0afbb632008-06-06 21:10:57 +0000613 def update_object(self, data={}, **kwargs):
614 """\
615 Updates the object with the given data (a dictionary mapping
616 field names to values). Merges any extra keyword args into
617 data.
618 """
619 data = self.prepare_data_args(data, kwargs)
620 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000621 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000622 self.do_validate()
623 self.save()
showard7c785282008-05-29 19:45:12 +0000624
625
showard8bfb5cb2009-10-07 20:49:15 +0000626 # see query_objects()
627 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
628 'extra_args', 'extra_where', 'no_distinct')
629
630
jadmanski0afbb632008-06-06 21:10:57 +0000631 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000632 def _extract_special_params(cls, filter_data):
633 """
634 @returns a tuple of dicts (special_params, regular_filters), where
635 special_params contains the parameters we handle specially and
636 regular_filters is the remaining data to be handled by Django.
637 """
638 regular_filters = dict(filter_data)
639 special_params = {}
640 for key in cls._SPECIAL_FILTER_KEYS:
641 if key in regular_filters:
642 special_params[key] = regular_filters.pop(key)
643 return special_params, regular_filters
644
645
646 @classmethod
647 def apply_presentation(cls, query, filter_data):
648 """
649 Apply presentation parameters -- sorting and paging -- to the given
650 query.
651 @returns new query with presentation applied
652 """
653 special_params, _ = cls._extract_special_params(filter_data)
654 sort_by = special_params.get('sort_by', None)
655 if sort_by:
656 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
showard8b0ea222009-12-23 19:23:03 +0000657 query = query.extra(order_by=sort_by)
showard8bfb5cb2009-10-07 20:49:15 +0000658
659 query_start = special_params.get('query_start', None)
660 query_limit = special_params.get('query_limit', None)
661 if query_start is not None:
662 if query_limit is None:
663 raise ValueError('Cannot pass query_start without query_limit')
664 # query_limit is passed as a page size
showard7074b742009-10-12 20:30:04 +0000665 query_limit += query_start
666 return query[query_start:query_limit]
showard8bfb5cb2009-10-07 20:49:15 +0000667
668
669 @classmethod
670 def query_objects(cls, filter_data, valid_only=True, initial_query=None,
671 apply_presentation=True):
jadmanski0afbb632008-06-06 21:10:57 +0000672 """\
673 Returns a QuerySet object for querying the given model_class
674 with the given filter_data. Optional special arguments in
675 filter_data include:
676 -query_start: index of first return to return
677 -query_limit: maximum number of results to return
678 -sort_by: list of fields to sort on. prefixing a '-' onto a
679 field name changes the sort to descending order.
680 -extra_args: keyword args to pass to query.extra() (see Django
681 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000682 -extra_where: extra WHERE clause to append
showard8bfb5cb2009-10-07 20:49:15 +0000683 -no_distinct: if True, a DISTINCT will not be added to the SELECT
jadmanski0afbb632008-06-06 21:10:57 +0000684 """
showard8bfb5cb2009-10-07 20:49:15 +0000685 special_params, regular_filters = cls._extract_special_params(
686 filter_data)
showard7c785282008-05-29 19:45:12 +0000687
showard7ac7b7a2008-07-21 20:24:29 +0000688 if initial_query is None:
689 if valid_only:
690 initial_query = cls.get_valid_manager()
691 else:
692 initial_query = cls.objects
showard8bfb5cb2009-10-07 20:49:15 +0000693
694 query = initial_query.filter(**regular_filters)
695
696 use_distinct = not special_params.get('no_distinct', False)
showard7ac7b7a2008-07-21 20:24:29 +0000697 if use_distinct:
698 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000699
showard8bfb5cb2009-10-07 20:49:15 +0000700 extra_args = special_params.get('extra_args', {})
701 extra_where = special_params.get('extra_where', None)
702 if extra_where:
703 # escape %'s
704 extra_where = cls.objects.escape_user_sql(extra_where)
705 extra_args.setdefault('where', []).append(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000706 if extra_args:
707 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000708 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000709
showard8bfb5cb2009-10-07 20:49:15 +0000710 if apply_presentation:
711 query = cls.apply_presentation(query, filter_data)
712
713 return query
showard7c785282008-05-29 19:45:12 +0000714
715
jadmanski0afbb632008-06-06 21:10:57 +0000716 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000717 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000718 """\
719 Like query_objects, but retreive only the count of results.
720 """
721 filter_data.pop('query_start', None)
722 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000723 query = cls.query_objects(filter_data, initial_query=initial_query)
724 return query.count()
showard7c785282008-05-29 19:45:12 +0000725
726
@classmethod
def clean_object_dicts(cls, field_dicts):
    """
    Clean a list of object dicts (as returned by query.values()) so the
    data is more suitable for returning to the user.

    Each dict is handed, in this order, to clean_foreign_keys(),
    _convert_booleans() and convert_human_readable_values() (with
    to_human_readable=True).  Nothing is returned.

    @param field_dicts: iterable of dicts, one per object.
    """
    for record in field_dicts:
        cls.clean_foreign_keys(record)
        cls._convert_booleans(record)
        cls.convert_human_readable_values(record, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000739
740
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """
    Like query_objects, but return a list of dictionaries.

    Every object produced by the query is converted with its own
    get_object_dict(); the names of the query's extra_select columns are
    passed along as extra_fields.

    @param filter_data: dict of filters, passed on to query_objects().
    @param initial_query: optional starting query, passed on to
            query_objects().
    @return list of dicts, one per object yielded by the query.
    """
    matching = cls.query_objects(filter_data, initial_query=initial_query)
    extra_names = matching.query.extra_select.keys()
    object_dicts = []
    for model_object in matching:
        object_dicts.append(
                model_object.get_object_dict(extra_fields=extra_names))
    return object_dicts
showard7c785282008-05-29 19:45:12 +0000751
752
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """
    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field

    @param id_or_name: integer primary key, or string matched against the
            field named by cls.name_field.
    @param valid_only: if true, look up through get_valid_manager();
            otherwise use cls.objects.
    @return the result of the manager's get().
    @raise ValueError: if id_or_name is neither an integer nor a string.
    """
    manager = cls.get_valid_manager() if valid_only else cls.objects

    if isinstance(id_or_name, (int, long)):
        lookup = {'pk': id_or_name}
    elif isinstance(id_or_name, basestring):
        lookup = {cls.name_field: id_or_name}
    else:
        raise ValueError(
                'Invalid positional argument: %s (%s)' % (id_or_name,
                                                          type(id_or_name)))
    return manager.get(**lookup)
showard7c785282008-05-29 19:45:12 +0000771
772
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """
    Look up several objects with smart_get() in one call.

    @param id_or_name_list: iterable of IDs and/or names, each passed to
            smart_get().
    @return list of the objects found, in input order.
    @raise cls.DoesNotExist: if any lookup fails; the message lists every
            input that could not be found.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs as well as names, so format each one
        # before joining; joining raw integers raises TypeError and would
        # hide the DoesNotExist error we actually want to report.
        missing = ', '.join('%s' % (value,) for value in invalid_inputs)
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
787
788
def get_object_dict(self, extra_fields=None):
    """
    Return a dictionary mapping fields to this object's values.

    The dict is run through clean_object_dicts() and then
    _postprocess_object_dict() before being returned.

    @param extra_fields: list of extra attribute names to include, in
            addition to the fields defined on this object.
    @return dict mapping each field/attribute name to its value.
    """
    field_names = list(self.get_field_dict().keys())
    if extra_fields:
        field_names.extend(extra_fields)

    object_dict = {}
    for field_name in field_names:
        object_dict[field_name] = getattr(self, field_name)

    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000803
804
showardd3dc1992009-04-22 21:01:40 +0000805 def _postprocess_object_dict(self, object_dict):
806 """For subclasses to override."""
807 pass
808
809
@classmethod
def get_valid_manager(cls):
    """
    Return the manager to use when only valid objects are wanted.  This
    implementation simply returns cls.objects.
    """
    return cls.objects
showard7c785282008-05-29 19:45:12 +0000813
814
showard2bab8f42008-11-12 18:15:22 +0000815 def _record_attributes(self, attributes):
816 """
817 See on_attribute_changed.
818 """
819 assert not isinstance(attributes, basestring)
820 self._recorded_attributes = dict((attribute, getattr(self, attribute))
821 for attribute in attributes)
822
823
824 def _check_for_updated_attributes(self):
825 """
826 See on_attribute_changed.
827 """
828 for attribute, original_value in self._recorded_attributes.iteritems():
829 new_value = getattr(self, attribute)
830 if original_value != new_value:
831 self.on_attribute_changed(attribute, original_value)
832 self._record_attributes(self._recorded_attributes.keys())
833
834
def on_attribute_changed(self, attribute, old_value):
    """
    Called whenever an attribute is updated.  To be overridden; this
    implementation does nothing.

    To use this method, you must:
    * call _record_attributes() from __init__() (after making the super
      call) with a list of attributes for which you want to be notified
      upon change.
    * call _check_for_updated_attributes() from save().

    @param attribute: name of the attribute that changed.
    @param old_value: the value recorded before the change.
    """
846
847
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field, and a "valid_objects" manager (used by get_valid_manager()).
    """

    def save(self, *args, **kwargs):
        """
        Save the object.  When it has no id yet, first look for an
        invalidated object with the same name and, if one exists, hand it to
        resurrect_object() so the old entry is reused.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object.  Can be
        overridden by subclasses to copy data as desired from the deleted
        entry (but this superclass implementation must normally be called).
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """


    def delete(self):
        """
        Mark the object invalid and save it instead of deleting it, then
        call clean_object().  The object must not already be invalid.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager stored as cls.valid_objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000912
913
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raise NotImplementedError: always, in this base implementation.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it produces a confusing TypeError instead of the intended error.
        raise NotImplementedError(
                'Subclasses must override _get_attribute_model_and_args()')


    def set_attribute(self, attribute, value):
        """
        Set the given attribute to value, creating the attribute object if
        it does not exist yet.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute; do nothing if it does not exist.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # nothing to delete
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the attribute when value is None, otherwise set it to value.
        """
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)
showard26b7ec72009-12-21 22:43:57 +0000952
953
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always raises: objects must be made with get_or_create()."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """
        Compute the hash for the given keyword arguments with the model's
        _compute_hash(), add it as 'the_hash', and defer to the standard
        get_or_create().
        """
        digest = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = digest
        parent = super(ModelWithHashManager, self)
        return parent.get_or_create(**kwargs)
965
966
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # Unique hash column, limited to 40 characters.
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given field values; must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work. These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row. However, the preferred way to make instances of
        these models is through the get_or_create() method.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the unique
            # constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
        else:
            raise Exception('ModelWithHash is immutable')