blob: 601b9c3af95d824cabad26b7dacf96f12a50ebe9 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
showarda5288b42009-07-28 20:06:08 +000012
class ValidationError(Exception):
    """
    Raised when adding or updating an object fails data validation.

    The single exception argument is a dict whose keys are the offending
    field names and whose values are the corresponding error strings.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Return a callable that invokes method with Django switched over to the
    read-only connection, switching back afterwards even if method raises.
    The wrapper carries method's __name__.
    """
    def readonly_method(*positional, **keywords):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*positional, **keywords)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_method.__name__ = method.__name__
    return readonly_method
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Return name quoted via connection.ops.quote_name() (Django backend)."""
    quote = connection.ops.quote_name
    return quote(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so its first step runs on the read-only
    connection.

    Generators need special handling: nothing runs until next() is called,
    so we assume the underlying query is performed on that first call and
    only switch connections around it. Every later item is passed through
    unchanged. The wrapper carries generator's __name__.
    """
    def readonly_generator(*positional, **keywords):
        source = generator(*positional, **keywords)
        readonly_connection.connection().set_django_connection()
        try:
            try:
                first_item = source.next()
            except StopIteration:
                # empty source: end the wrapper too
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_item

        for item in source:
            yield item

    readonly_generator.__name__ = generator.__name__
    return readonly_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Patch queryset in place so that each of its database-querying methods
    (count, get, get_or_create, latest, in_bulk, delete, iterator) runs on
    the read-only connection.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator, so it needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods all run on the read-only
    connection (see _make_queryset_readonly).
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        # return the read-only flavour of ValuesQuerySet so .values()
        # results stay on the read-only connection as well
        clone_options = dict(setup=True, _fields=fields)
        return self._clone(klass=ReadonlyValuesQuerySet, **clone_options)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose database-querying methods all run on the read-only
    connection; produced by ReadonlyQuerySet.values().
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
90
91
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    On top of the stock Manager it can splice raw SQL joins into a query set
    (add_join), run custom SELECTs over a query set's FROM clause, and
    bulk-load related objects onto model instances (populate_relationships).
    """

    class _CustomQuery(query.Query):
        """
        Query carrying a _CustomSqlQ in self._customSqlQ, whose joins are
        appended to the generated FROM clause.
        """
        def clone(self, klass=None, **kwargs):
            """
            Clone this query.

            The clone receives its OWN _CustomSqlQ holding this query's
            joins/where/params plus those of any _customSqlQ passed in
            kwargs. All other kwargs are forwarded to the base clone().
            """
            # Build a fresh container instead of handing the clone
            # self._customSqlQ and then merging into it: that object is
            # shared, so new joins would leak back into this query and into
            # every other clone that holds it.
            custom_sql_q = ExtendedManager._CustomSqlQ()
            custom_sql_q.merge_from(self._customSqlQ)
            customQ = kwargs.pop('_customSqlQ', None)
            if customQ is not None:
                custom_sql_q.merge_from(customQ)
            kwargs['_customSqlQ'] = custom_sql_q
            return super(ExtendedManager._CustomQuery, self).clone(
                    klass, **kwargs)


        def get_from_clause(self):
            """
            Return the base (from_, params) with this query's custom joins
            appended to from_ as one extra piece.
            """
            from_, params = super(
                    ExtendedManager._CustomQuery, self).get_from_clause()

            join_clause = ''
            for join_alias, join in self._customSqlQ._joins.iteritems():
                join_table, join_type, condition = join
                join_clause += ' %s %s AS %s ON (%s)' % (
                        join_type, join_table, join_alias, condition)

            if join_clause:
                from_.append(join_clause)

            return from_, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Accumulates raw SQL joins, WHERE fragments and WHERE params.

        When passed to QuerySet.filter(), add_to_query() ANDs the WHERE
        fragments into the query. The joins are emitted by
        _CustomQuery.get_from_clause().
        """
        def __init__(self):
            # NOTE: Q.__init__ is not called; only these attributes are set.
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """Record a join; alias defaults to the table name."""
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=None):
            """Record a WHERE fragment and (optionally) its params."""
            self._where.append(where)
            if params is not None:
                self._params.extend(params)


        def merge_from(self, other):
            """Append other's joins, WHERE fragments and params to self."""
            self._joins.update(other._joins)
            self._where.extend(other._where)
            self._params.extend(other._params)


        def add_to_query(self, query, aliases):
            if self._where:
                where = ' AND '.join(self._where)
                query.add_extra(None, None, (where,), self._params, None, None)


    def _add_customSqlQ(self, query_set, filter_object):
        """\
        Return a copy of query_set whose query is a _CustomQuery carrying
        filter_object's joins, filtered by filter_object's WHERE fragments.
        """
        # Make a copy of the query set
        query_set = query_set.all()

        query_set.query = query_set.query.clone(
                ExtendedManager._CustomQuery, _customSqlQ=filter_object)
        return query_set.filter(filter_object)


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param suffix suffix to add to join_table for the join alias
        @param exclude if true, exclude rows that match this join (will use a
        LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
        instead of an INNER JOIN regardless of other options
        @returns a new query set; query_set itself is not modified
        """
        join_from_table = self.model._meta.db_table
        join_from_key = self.model._meta.pk.name
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # rows with no match have NULLs on the LEFT OUTER JOINed side
            filter_object.add_where(full_join_key + ' IS NULL')

        query_set = self._add_customSqlQ(query_set, filter_object)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return table.field with both parts quoted for the backend."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted table.column for key_field on this model's table,
        defaulting to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' so sql survives DB-API parameter substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run a SELECT of the given expressions over query_set's FROM clause
        (and everything after it) on the read-only connection.
        @param query_set: query set supplying FROM/WHERE/etc.
        @param selects: list of SQL expressions to select
        @returns all rows, as returned by cursor.fetchall()
        @raises ValueError if the generated SQL has no ' FROM'
        """
        sql, params = query_set.query.as_sql()
        from_index = sql.find(' FROM')
        if from_index == -1:
            raise ValueError('Query has no FROM clause: %s' % sql)
        # NOTE(review): params belong to the whole original query; if its
        # SELECT list carried params (e.g. extra(select=...)) they would no
        # longer line up with the rewritten SQL.
        from_ = sql[from_index:]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = 'SELECT ' + distinct + ','.join(selects) + from_
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """True-ish if field is a relation whose target is model_class."""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object. all base_object instances come from base_objects_by_id.
        @raises ValueError if the two models are not related
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object. we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; its
                keys are the IDs to look up (must be non-empty)
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        base_ids = list(base_objects_by_id.iterkeys())
        # IDs are passed as query parameters (one placeholder each) rather
        # than str()-interpolated, so non-integer primary keys are quoted
        # correctly by the driver.
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(['%s'] * len(base_ids)))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query, base_ids)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        # fetch every distinct related object with a single in_bulk() query
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
        object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000358
359
class ValidObjectsManager(ExtendedManager):
    """
    ExtendedManager whose base query set is restricted to rows with
    invalid=False.
    """
    def get_query_set(self):
        base_query_set = super(ValidObjectsManager, self).get_query_set()
        only_valid = {'invalid': False}
        return base_query_set.filter(**only_valid)
showard7c785282008-05-29 19:45:12 +0000367
368
369class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000370 """\
371 Mixin with convenience functions for models, built on top of the
372 default Django model functions.
373 """
374 # TODO: at least some of these functions really belong in a custom
375 # Manager class
showard7c785282008-05-29 19:45:12 +0000376
jadmanski0afbb632008-06-06 21:10:57 +0000377 field_dict = None
378 # subclasses should override if they want to support smart_get() by name
379 name_field = None
showard7c785282008-05-29 19:45:12 +0000380
381
jadmanski0afbb632008-06-06 21:10:57 +0000382 @classmethod
383 def get_field_dict(cls):
384 if cls.field_dict is None:
385 cls.field_dict = {}
386 for field in cls._meta.fields:
387 cls.field_dict[field.name] = field
388 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000389
390
jadmanski0afbb632008-06-06 21:10:57 +0000391 @classmethod
392 def clean_foreign_keys(cls, data):
393 """\
394 -Convert foreign key fields in data from <field>_id to just
395 <field>.
396 -replace foreign key objects with their IDs
397 This method modifies data in-place.
398 """
399 for field in cls._meta.fields:
400 if not field.rel:
401 continue
402 if (field.attname != field.name and
403 field.attname in data):
404 data[field.name] = data[field.attname]
405 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000406 if field.name not in data:
407 continue
jadmanski0afbb632008-06-06 21:10:57 +0000408 value = data[field.name]
409 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000410 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000411
412
showard21baa452008-10-21 00:08:39 +0000413 @classmethod
414 def _convert_booleans(cls, data):
415 """
416 Ensure BooleanFields actually get bool values. The Django MySQL
417 backend returns ints for BooleanFields, which is almost always not
418 a problem, but it can be annoying in certain situations.
419 """
420 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000421 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000422 data[field.name] = bool(data[field.name])
423
424
jadmanski0afbb632008-06-06 21:10:57 +0000425 # TODO(showard) - is there a way to not have to do this?
426 @classmethod
427 def provide_default_values(cls, data):
428 """\
429 Provide default values for fields with default values which have
430 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000431
jadmanski0afbb632008-06-06 21:10:57 +0000432 For CharField and TextField fields with "blank=True", if nothing
433 is passed, we fill in an empty string value, even if there's no
434 default set.
435 """
436 new_data = dict(data)
437 field_dict = cls.get_field_dict()
438 for name, obj in field_dict.iteritems():
439 if data.get(name) is not None:
440 continue
441 if obj.default is not dbmodels.fields.NOT_PROVIDED:
442 new_data[name] = obj.default
443 elif (isinstance(obj, dbmodels.CharField) or
444 isinstance(obj, dbmodels.TextField)):
445 new_data[name] = ''
446 return new_data
showard7c785282008-05-29 19:45:12 +0000447
448
jadmanski0afbb632008-06-06 21:10:57 +0000449 @classmethod
450 def convert_human_readable_values(cls, data, to_human_readable=False):
451 """\
452 Performs conversions on user-supplied field data, to make it
453 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000454
jadmanski0afbb632008-06-06 21:10:57 +0000455 For all fields that have choice sets, convert their values
456 from human-readable strings to enum values, if necessary. This
457 allows users to pass strings instead of the corresponding
458 integer values.
showard7c785282008-05-29 19:45:12 +0000459
jadmanski0afbb632008-06-06 21:10:57 +0000460 For all foreign key fields, call smart_get with the supplied
461 data. This allows the user to pass either an ID value or
462 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000463
jadmanski0afbb632008-06-06 21:10:57 +0000464 If to_human_readable=True, perform the inverse - i.e. convert
465 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000466
jadmanski0afbb632008-06-06 21:10:57 +0000467 This method modifies data in-place.
468 """
469 field_dict = cls.get_field_dict()
470 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000471 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000472 continue
473 field_obj = field_dict[field_name]
474 # convert enum values
475 if field_obj.choices:
476 for choice_data in field_obj.choices:
477 # choice_data is (value, name)
478 if to_human_readable:
479 from_val, to_val = choice_data
480 else:
481 to_val, from_val = choice_data
482 if from_val == data[field_name]:
483 data[field_name] = to_val
484 break
485 # convert foreign key values
486 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000487 dest_obj = field_obj.rel.to.smart_get(data[field_name],
488 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000489 if to_human_readable:
490 if dest_obj.name_field is not None:
491 data[field_name] = getattr(dest_obj,
492 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000493 else:
showardb0a73032009-03-27 18:35:41 +0000494 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000495
496
jadmanski0afbb632008-06-06 21:10:57 +0000497 @classmethod
498 def validate_field_names(cls, data):
499 'Checks for extraneous fields in data.'
500 errors = {}
501 field_dict = cls.get_field_dict()
502 for field_name in data:
503 if field_name not in field_dict:
504 errors[field_name] = 'No field of this name'
505 return errors
showard7c785282008-05-29 19:45:12 +0000506
507
jadmanski0afbb632008-06-06 21:10:57 +0000508 @classmethod
509 def prepare_data_args(cls, data, kwargs):
510 'Common preparation for add_object and update_object'
511 data = dict(data) # don't modify the default keyword arg
512 data.update(kwargs)
513 # must check for extraneous field names here, while we have the
514 # data in a dict
515 errors = cls.validate_field_names(data)
516 if errors:
517 raise ValidationError(errors)
518 cls.convert_human_readable_values(data)
519 return data
showard7c785282008-05-29 19:45:12 +0000520
521
jadmanski0afbb632008-06-06 21:10:57 +0000522 def validate_unique(self):
523 """\
524 Validate that unique fields are unique. Django manipulators do
525 this too, but they're a huge pain to use manually. Trust me.
526 """
527 errors = {}
528 cls = type(self)
529 field_dict = self.get_field_dict()
530 manager = cls.get_valid_manager()
531 for field_name, field_obj in field_dict.iteritems():
532 if not field_obj.unique:
533 continue
showard7c785282008-05-29 19:45:12 +0000534
jadmanski0afbb632008-06-06 21:10:57 +0000535 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000536 if value is None and field_obj.auto_created:
537 # don't bother checking autoincrement fields about to be
538 # generated
539 continue
540
jadmanski0afbb632008-06-06 21:10:57 +0000541 existing_objs = manager.filter(**{field_name : value})
542 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000543
jadmanski0afbb632008-06-06 21:10:57 +0000544 if num_existing == 0:
545 continue
546 if num_existing == 1 and existing_objs[0].id == self.id:
547 continue
548 errors[field_name] = (
549 'This value must be unique (%s)' % (value))
550 return errors
showard7c785282008-05-29 19:45:12 +0000551
552
showarda5288b42009-07-28 20:06:08 +0000553 def _validate(self):
554 """
555 First coerces all fields on this instance to their proper Python types.
556 Then runs validation on every field. Returns a dictionary of
557 field_name -> error_list.
558
559 Based on validate() from django.db.models.Model in Django 0.96, which
560 was removed in Django 1.0. It should reappear in a later version. See:
561 http://code.djangoproject.com/ticket/6845
562 """
563 error_dict = {}
564 for f in self._meta.fields:
565 try:
566 python_value = f.to_python(
567 getattr(self, f.attname, f.get_default()))
568 except django.core.exceptions.ValidationError, e:
569 error_dict[f.name] = str(e.message)
570 continue
571
572 if not f.blank and not python_value:
573 error_dict[f.name] = 'This field is required.'
574 continue
575
576 setattr(self, f.attname, python_value)
577
578 return error_dict
579
580
jadmanski0afbb632008-06-06 21:10:57 +0000581 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000582 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000583 unique_errors = self.validate_unique()
584 for field_name, error in unique_errors.iteritems():
585 errors.setdefault(field_name, error)
586 if errors:
587 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000588
589
jadmanski0afbb632008-06-06 21:10:57 +0000590 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000591
jadmanski0afbb632008-06-06 21:10:57 +0000592 @classmethod
593 def add_object(cls, data={}, **kwargs):
594 """\
595 Returns a new object created with the given data (a dictionary
596 mapping field names to values). Merges any extra keyword args
597 into data.
598 """
599 data = cls.prepare_data_args(data, kwargs)
600 data = cls.provide_default_values(data)
601 obj = cls(**data)
602 obj.do_validate()
603 obj.save()
604 return obj
showard7c785282008-05-29 19:45:12 +0000605
606
jadmanski0afbb632008-06-06 21:10:57 +0000607 def update_object(self, data={}, **kwargs):
608 """\
609 Updates the object with the given data (a dictionary mapping
610 field names to values). Merges any extra keyword args into
611 data.
612 """
613 data = self.prepare_data_args(data, kwargs)
614 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000615 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000616 self.do_validate()
617 self.save()
showard7c785282008-05-29 19:45:12 +0000618
619
jadmanski0afbb632008-06-06 21:10:57 +0000620 @classmethod
showard7ac7b7a2008-07-21 20:24:29 +0000621 def query_objects(cls, filter_data, valid_only=True, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000622 """\
623 Returns a QuerySet object for querying the given model_class
624 with the given filter_data. Optional special arguments in
625 filter_data include:
626 -query_start: index of first return to return
627 -query_limit: maximum number of results to return
628 -sort_by: list of fields to sort on. prefixing a '-' onto a
629 field name changes the sort to descending order.
630 -extra_args: keyword args to pass to query.extra() (see Django
631 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000632 -extra_where: extra WHERE clause to append
jadmanski0afbb632008-06-06 21:10:57 +0000633 """
showardc0ac3a72009-07-08 21:14:45 +0000634 filter_data = dict(filter_data) # copy so we don't mutate the original
jadmanski0afbb632008-06-06 21:10:57 +0000635 query_start = filter_data.pop('query_start', None)
636 query_limit = filter_data.pop('query_limit', None)
637 if query_start and not query_limit:
638 raise ValueError('Cannot pass query_start without '
639 'query_limit')
showardc4780402009-08-31 18:31:34 +0000640 sort_by = filter_data.pop('sort_by', None)
jadmanski0afbb632008-06-06 21:10:57 +0000641 extra_args = filter_data.pop('extra_args', {})
642 extra_where = filter_data.pop('extra_where', None)
643 if extra_where:
showard1e935f12008-07-11 00:11:36 +0000644 # escape %'s
showardeaccf8f2009-04-16 03:11:33 +0000645 extra_where = cls.objects.escape_user_sql(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000646 extra_args.setdefault('where', []).append(extra_where)
showard7ac7b7a2008-07-21 20:24:29 +0000647 use_distinct = not filter_data.pop('no_distinct', False)
showard7c785282008-05-29 19:45:12 +0000648
showard7ac7b7a2008-07-21 20:24:29 +0000649 if initial_query is None:
650 if valid_only:
651 initial_query = cls.get_valid_manager()
652 else:
653 initial_query = cls.objects
654 query = initial_query.filter(**filter_data)
655 if use_distinct:
656 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000657
jadmanski0afbb632008-06-06 21:10:57 +0000658 # other arguments
659 if extra_args:
660 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000661 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000662
jadmanski0afbb632008-06-06 21:10:57 +0000663 # sorting + paging
showardc4780402009-08-31 18:31:34 +0000664 if sort_by:
665 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
666 query = query.order_by(*sort_by)
jadmanski0afbb632008-06-06 21:10:57 +0000667 if query_start is not None and query_limit is not None:
668 query_limit += query_start
669 return query[query_start:query_limit]
showard7c785282008-05-29 19:45:12 +0000670
671
jadmanski0afbb632008-06-06 21:10:57 +0000672 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000673 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000674 """\
675 Like query_objects, but retreive only the count of results.
676 """
677 filter_data.pop('query_start', None)
678 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000679 query = cls.query_objects(filter_data, initial_query=initial_query)
680 return query.count()
showard7c785282008-05-29 19:45:12 +0000681
682
jadmanski0afbb632008-06-06 21:10:57 +0000683 @classmethod
684 def clean_object_dicts(cls, field_dicts):
685 """\
686 Take a list of dicts corresponding to object (as returned by
687 query.values()) and clean the data to be more suitable for
688 returning to the user.
689 """
showarde732ee72008-09-23 19:15:43 +0000690 for field_dict in field_dicts:
691 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000692 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000693 cls.convert_human_readable_values(field_dict,
694 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000695
696
jadmanski0afbb632008-06-06 21:10:57 +0000697 @classmethod
showarde732ee72008-09-23 19:15:43 +0000698 def list_objects(cls, filter_data, initial_query=None, fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000699 """\
700 Like query_objects, but return a list of dictionaries.
701 """
showard7ac7b7a2008-07-21 20:24:29 +0000702 query = cls.query_objects(filter_data, initial_query=initial_query)
showarde732ee72008-09-23 19:15:43 +0000703 field_dicts = [model_object.get_object_dict(fields)
704 for model_object in query]
jadmanski0afbb632008-06-06 21:10:57 +0000705 return field_dicts
showard7c785282008-05-29 19:45:12 +0000706
707
jadmanski0afbb632008-06-06 21:10:57 +0000708 @classmethod
showarda4ea5742009-02-17 20:56:23 +0000709 def smart_get(cls, id_or_name, valid_only=True):
jadmanski0afbb632008-06-06 21:10:57 +0000710 """\
711 smart_get(integer) -> get object by ID
712 smart_get(string) -> get object by name_field
jadmanski0afbb632008-06-06 21:10:57 +0000713 """
showarda4ea5742009-02-17 20:56:23 +0000714 if valid_only:
715 manager = cls.get_valid_manager()
716 else:
717 manager = cls.objects
718
719 if isinstance(id_or_name, (int, long)):
720 return manager.get(pk=id_or_name)
721 if isinstance(id_or_name, basestring):
722 return manager.get(**{cls.name_field : id_or_name})
723 raise ValueError(
724 'Invalid positional argument: %s (%s)' % (id_or_name,
725 type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000726
727
showardbe3ec042008-11-12 18:16:07 +0000728 @classmethod
729 def smart_get_bulk(cls, id_or_name_list):
730 invalid_inputs = []
731 result_objects = []
732 for id_or_name in id_or_name_list:
733 try:
734 result_objects.append(cls.smart_get(id_or_name))
735 except cls.DoesNotExist:
736 invalid_inputs.append(id_or_name)
737 if invalid_inputs:
mbligh7a3ebe32008-12-01 17:10:33 +0000738 raise cls.DoesNotExist('The following %ss do not exist: %s'
739 % (cls.__name__.lower(),
740 ', '.join(invalid_inputs)))
showardbe3ec042008-11-12 18:16:07 +0000741 return result_objects
742
743
showarde732ee72008-09-23 19:15:43 +0000744 def get_object_dict(self, fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000745 """\
746 Return a dictionary mapping fields to this object's values.
747 """
showarde732ee72008-09-23 19:15:43 +0000748 if fields is None:
749 fields = self.get_field_dict().iterkeys()
jadmanski0afbb632008-06-06 21:10:57 +0000750 object_dict = dict((field_name, getattr(self, field_name))
showarde732ee72008-09-23 19:15:43 +0000751 for field_name in fields)
jadmanski0afbb632008-06-06 21:10:57 +0000752 self.clean_object_dicts([object_dict])
showardd3dc1992009-04-22 21:01:40 +0000753 self._postprocess_object_dict(object_dict)
jadmanski0afbb632008-06-06 21:10:57 +0000754 return object_dict
showard7c785282008-05-29 19:45:12 +0000755
756
showardd3dc1992009-04-22 21:01:40 +0000757 def _postprocess_object_dict(self, object_dict):
758 """For subclasses to override."""
759 pass
760
761
jadmanski0afbb632008-06-06 21:10:57 +0000762 @classmethod
763 def get_valid_manager(cls):
764 return cls.objects
showard7c785282008-05-29 19:45:12 +0000765
766
showard2bab8f42008-11-12 18:15:22 +0000767 def _record_attributes(self, attributes):
768 """
769 See on_attribute_changed.
770 """
771 assert not isinstance(attributes, basestring)
772 self._recorded_attributes = dict((attribute, getattr(self, attribute))
773 for attribute in attributes)
774
775
776 def _check_for_updated_attributes(self):
777 """
778 See on_attribute_changed.
779 """
780 for attribute, original_value in self._recorded_attributes.iteritems():
781 new_value = getattr(self, attribute)
782 if original_value != new_value:
783 self.on_attribute_changed(attribute, original_value)
784 self._record_attributes(self._recorded_attributes.keys())
785
786
787 def on_attribute_changed(self, attribute, old_value):
788 """
789 Called whenever an attribute is updated. To be overridden.
790
791 To use this method, you must:
792 * call _record_attributes() from __init__() (after making the super
793 call) with a list of attributes for which you want to be notified upon
794 change.
795 * call _check_for_updated_attributes() from save().
796 """
797 pass
798
799
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion. Subclasses must have a boolean "invalid"
    field.
    """

    def save(self, *args, **kwargs):
        """
        Save the object, first reusing a previously invalidated row if one
        exists.

        When the object has no id yet and an object with the same name_field
        value is marked invalid, resurrect_object() is called so that the
        save targets the old row's id. All arguments are passed through to
        the parent save().
        """
        first_time = (self.id is None)
        if first_time:
            # see if this object was previously added and invalidated
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
            except self.DoesNotExist:
                # no existing object
                pass
            else:
                # Kept outside the try so that a DoesNotExist raised inside a
                # subclass's resurrect_object() is not mistaken for "nothing
                # to resurrect" and silently swallowed.
                self.resurrect_object(old_object)

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object. Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).
        """
        # Taking over the old id makes the upcoming save reuse that row.
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Invalidate instead of deleting: set invalid=True, save, then call
        clean_object(). The row itself is kept.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        Return cls.valid_objects, which the concrete model must define
        (presumably a manager that excludes invalid rows; confirm in the
        concrete models).
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000864
865
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError if not overridden.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it produces an unrelated TypeError. NotImplementedError is the
        # proper "abstract method" signal.
        raise NotImplementedError(
                '%s must override _get_attribute_model_and_args()'
                % type(self).__name__)


    def set_attribute(self, attribute, value):
        """
        Set the given attribute to value, creating the attribute object
        (via get_or_create) if it does not exist yet, then saving it.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute; a missing attribute is silently ignored.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the attribute when value is None, otherwise set it to value.
        """
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)