blob: c0c623b4bd57d9e50b81adde65295b9846ccf150 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
showarda5288b42009-07-28 20:06:08 +000012
class ValidationError(Exception):
    """\
    Raised when the data for adding or updating an object fails validation.

    The single constructor argument is a dict mapping each offending field
    name to an error message string.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Return a callable that runs method with Django switched over to the
    read-only connection, switching back afterwards even if method raises.
    The wrapper reports method's __name__.
    """
    def _readonly_call(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            # always restore the normal connection
            readonly_connection.connection().unset_django_connection()
        return result

    _readonly_call.__name__ = method.__name__
    return _readonly_call
showard09096d82008-07-07 23:20:49 +000029
30
showarda5288b42009-07-28 20:06:08 +000031def _quote_name(name):
32 """Shorthand for connection.ops.quote_name()."""
33 return connection.ops.quote_name(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    We have to wrap generators specially. Assume it performs
    the query on the first call to next().

    Returns a generator function that pulls the first item of
    generator(*args, **kwargs) while the read-only connection is active,
    then yields that item and passes the remaining items through unchanged.
    An empty generator simply produces an empty result.
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = iter(generator(*args, **kwargs))
        readonly_connection.connection().set_django_connection()
        try:
            # Pull at most one item via for/break instead of calling .next()
            # directly: an exhausted generator then ends this generator
            # normally rather than leaking StopIteration out of a generator
            # frame (which PEP 479 turns into a RuntimeError).
            for first_value in generator_obj:
                break
            else:
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Replace, on this queryset instance, every method that issues a database
    query with a version that runs on the read-only connection.

    NOTE(review): 'delete' is routed through the read-only connection too;
    confirm that is intended.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator and needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose query-issuing methods all run on the read-only
    connection (installed per instance by _make_queryset_readonly).
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but yields a ReadonlyValuesQuerySet."""
        values_qs = self._clone(klass=ReadonlyValuesQuerySet, setup=True,
                                _fields=fields)
        return values_qs
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its query-issuing
    methods run on the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
90
91
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Besides the normal Manager API it can attach hand-written SQL joins to a
    query set (add_join), run a custom SELECT over a query set's FROM/WHERE
    clauses (_custom_select_query) and bulk-load related objects onto a list
    of instances (populate_relationships).
    """

    class _CustomQuery(query.Query):
        """
        Query that appends the joins recorded in its _customSqlQ attribute
        (an ExtendedManager._CustomSqlQ) to the FROM clause.
        """
        def clone(self, klass=None, **kwargs):
            """
            Return a copy of this query.

            The copy gets its own, newly built _CustomSqlQ holding this
            query's joins/where/params followed by those of the optional
            _customSqlQ keyword argument. Because the copy never shares
            self._customSqlQ, joins added to a query set derived from this
            one cannot leak back into this one. All other keyword arguments
            are passed through to Query.clone() unchanged.
            """
            extra_custom_q = kwargs.pop('_customSqlQ', None)
            obj = super(ExtendedManager._CustomQuery, self).clone(
                klass, **kwargs)

            custom_q = ExtendedManager._CustomSqlQ()
            custom_q._merge_from(self._customSqlQ)
            if extra_custom_q is not None:
                custom_q._merge_from(extra_custom_q)
            obj._customSqlQ = custom_q

            return obj


        def get_from_clause(self):
            """
            Return the (from_list, params) pair built by Query, with one
            ' <join_type> <table> AS <alias> ON (<condition>)' fragment
            appended for all recorded joins.
            """
            from_, params = super(
                ExtendedManager._CustomQuery, self).get_from_clause()

            join_clause = ''.join(
                ' %s %s AS %s ON (%s)' % (join_type, join_table, join_alias,
                                          condition)
                for join_alias, (join_table, join_type, condition)
                in self._customSqlQ._joins.items())

            if join_clause:
                from_.append(join_clause)

            return from_, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Filter object carrying raw SQL joins and WHERE fragments.

        Its joins are emitted by _CustomQuery.get_from_clause(); its WHERE
        fragments are installed by add_to_query(), which the query set's
        filter() is expected to invoke (Django's custom-filter hook).
        """
        def __init__(self):
            # alias -> (table, join_type, condition), kept in insertion order
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """Record a join to table, aliased as alias (default: table)."""
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=None):
            """Record a WHERE fragment plus its query parameters, if any."""
            self._where.append(where)
            if params is not None:
                self._params.extend(params)


        def _merge_from(self, other):
            """Append other's joins, WHERE fragments and params to ours."""
            self._joins.update(other._joins)
            self._where.extend(other._where)
            self._params.extend(other._params)


        def add_to_query(self, query, aliases):
            """AND the recorded WHERE fragments into query as an extra()."""
            if self._where:
                where = ' AND '.join(self._where)
                query.add_extra(None, None, (where,), self._params, None, None)


    def _add_customSqlQ(self, query_set, filter_object):
        """\
        Add a _CustomSqlQ to the query set.

        Returns a new query set whose query is a _CustomQuery carrying
        filter_object's joins, filtered by filter_object. query_set itself
        is not modified.
        """
        # Make a copy of the query set
        query_set = query_set.all()

        query_set.query = query_set.query.clone(
            ExtendedManager._CustomQuery, _customSqlQ=filter_object)
        return query_set.filter(filter_object)


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param suffix suffix to add to join_table for the join alias
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new query set; query_set is not modified
        """
        join_from_table = self.model._meta.db_table
        join_from_key = self.model._meta.pk.name
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # with a LEFT OUTER JOIN, unmatched rows have NULL join columns
            filter_object.add_where(full_join_key + ' IS NULL')

        query_set = self._add_customSqlQ(query_set, filter_object)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return 'table.field' with both parts quoted for the backend."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted '<this table>.<key_field>' reference; key_field
        defaults to this model's primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' so sql survives DB-API parameter substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run 'SELECT [DISTINCT] <selects> <FROM ... of query_set>' on the
        read-only connection and return all result rows.
        @param selects list of SQL select expressions
        """
        sql, params = query_set.query.as_sql()
        # keep everything from the FROM clause onwards
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Truthy iff field is a relation whose target is model_class."""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        @raises ValueError if the two models are not related.
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model instances keyed by ID;
                only its keys are used, to select the pivot rows
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000358
359
class ValidObjectsManager(ExtendedManager):
    """
    ExtendedManager whose base query set contains only objects with
    invalid=False.
    """
    def get_query_set(self):
        base_query_set = super(ValidObjectsManager, self).get_query_set()
        return base_query_set.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000367
368
369class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000370 """\
371 Mixin with convenience functions for models, built on top of the
372 default Django model functions.
373 """
374 # TODO: at least some of these functions really belong in a custom
375 # Manager class
showard7c785282008-05-29 19:45:12 +0000376
jadmanski0afbb632008-06-06 21:10:57 +0000377 field_dict = None
378 # subclasses should override if they want to support smart_get() by name
379 name_field = None
showard7c785282008-05-29 19:45:12 +0000380
381
jadmanski0afbb632008-06-06 21:10:57 +0000382 @classmethod
383 def get_field_dict(cls):
384 if cls.field_dict is None:
385 cls.field_dict = {}
386 for field in cls._meta.fields:
387 cls.field_dict[field.name] = field
388 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000389
390
jadmanski0afbb632008-06-06 21:10:57 +0000391 @classmethod
392 def clean_foreign_keys(cls, data):
393 """\
394 -Convert foreign key fields in data from <field>_id to just
395 <field>.
396 -replace foreign key objects with their IDs
397 This method modifies data in-place.
398 """
399 for field in cls._meta.fields:
400 if not field.rel:
401 continue
402 if (field.attname != field.name and
403 field.attname in data):
404 data[field.name] = data[field.attname]
405 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000406 if field.name not in data:
407 continue
jadmanski0afbb632008-06-06 21:10:57 +0000408 value = data[field.name]
409 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000410 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000411
412
showard21baa452008-10-21 00:08:39 +0000413 @classmethod
414 def _convert_booleans(cls, data):
415 """
416 Ensure BooleanFields actually get bool values. The Django MySQL
417 backend returns ints for BooleanFields, which is almost always not
418 a problem, but it can be annoying in certain situations.
419 """
420 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000421 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000422 data[field.name] = bool(data[field.name])
423
424
jadmanski0afbb632008-06-06 21:10:57 +0000425 # TODO(showard) - is there a way to not have to do this?
426 @classmethod
427 def provide_default_values(cls, data):
428 """\
429 Provide default values for fields with default values which have
430 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000431
jadmanski0afbb632008-06-06 21:10:57 +0000432 For CharField and TextField fields with "blank=True", if nothing
433 is passed, we fill in an empty string value, even if there's no
434 default set.
435 """
436 new_data = dict(data)
437 field_dict = cls.get_field_dict()
438 for name, obj in field_dict.iteritems():
439 if data.get(name) is not None:
440 continue
441 if obj.default is not dbmodels.fields.NOT_PROVIDED:
442 new_data[name] = obj.default
443 elif (isinstance(obj, dbmodels.CharField) or
444 isinstance(obj, dbmodels.TextField)):
445 new_data[name] = ''
446 return new_data
showard7c785282008-05-29 19:45:12 +0000447
448
jadmanski0afbb632008-06-06 21:10:57 +0000449 @classmethod
450 def convert_human_readable_values(cls, data, to_human_readable=False):
451 """\
452 Performs conversions on user-supplied field data, to make it
453 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000454
jadmanski0afbb632008-06-06 21:10:57 +0000455 For all fields that have choice sets, convert their values
456 from human-readable strings to enum values, if necessary. This
457 allows users to pass strings instead of the corresponding
458 integer values.
showard7c785282008-05-29 19:45:12 +0000459
jadmanski0afbb632008-06-06 21:10:57 +0000460 For all foreign key fields, call smart_get with the supplied
461 data. This allows the user to pass either an ID value or
462 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000463
jadmanski0afbb632008-06-06 21:10:57 +0000464 If to_human_readable=True, perform the inverse - i.e. convert
465 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000466
jadmanski0afbb632008-06-06 21:10:57 +0000467 This method modifies data in-place.
468 """
469 field_dict = cls.get_field_dict()
470 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000471 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000472 continue
473 field_obj = field_dict[field_name]
474 # convert enum values
475 if field_obj.choices:
476 for choice_data in field_obj.choices:
477 # choice_data is (value, name)
478 if to_human_readable:
479 from_val, to_val = choice_data
480 else:
481 to_val, from_val = choice_data
482 if from_val == data[field_name]:
483 data[field_name] = to_val
484 break
485 # convert foreign key values
486 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000487 dest_obj = field_obj.rel.to.smart_get(data[field_name],
488 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000489 if to_human_readable:
490 if dest_obj.name_field is not None:
491 data[field_name] = getattr(dest_obj,
492 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000493 else:
showardb0a73032009-03-27 18:35:41 +0000494 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000495
496
jadmanski0afbb632008-06-06 21:10:57 +0000497 @classmethod
498 def validate_field_names(cls, data):
499 'Checks for extraneous fields in data.'
500 errors = {}
501 field_dict = cls.get_field_dict()
502 for field_name in data:
503 if field_name not in field_dict:
504 errors[field_name] = 'No field of this name'
505 return errors
showard7c785282008-05-29 19:45:12 +0000506
507
jadmanski0afbb632008-06-06 21:10:57 +0000508 @classmethod
509 def prepare_data_args(cls, data, kwargs):
510 'Common preparation for add_object and update_object'
511 data = dict(data) # don't modify the default keyword arg
512 data.update(kwargs)
513 # must check for extraneous field names here, while we have the
514 # data in a dict
515 errors = cls.validate_field_names(data)
516 if errors:
517 raise ValidationError(errors)
518 cls.convert_human_readable_values(data)
519 return data
showard7c785282008-05-29 19:45:12 +0000520
521
jadmanski0afbb632008-06-06 21:10:57 +0000522 def validate_unique(self):
523 """\
524 Validate that unique fields are unique. Django manipulators do
525 this too, but they're a huge pain to use manually. Trust me.
526 """
527 errors = {}
528 cls = type(self)
529 field_dict = self.get_field_dict()
530 manager = cls.get_valid_manager()
531 for field_name, field_obj in field_dict.iteritems():
532 if not field_obj.unique:
533 continue
showard7c785282008-05-29 19:45:12 +0000534
jadmanski0afbb632008-06-06 21:10:57 +0000535 value = getattr(self, field_name)
536 existing_objs = manager.filter(**{field_name : value})
537 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000538
jadmanski0afbb632008-06-06 21:10:57 +0000539 if num_existing == 0:
540 continue
541 if num_existing == 1 and existing_objs[0].id == self.id:
542 continue
543 errors[field_name] = (
544 'This value must be unique (%s)' % (value))
545 return errors
showard7c785282008-05-29 19:45:12 +0000546
547
showarda5288b42009-07-28 20:06:08 +0000548 def _validate(self):
549 """
550 First coerces all fields on this instance to their proper Python types.
551 Then runs validation on every field. Returns a dictionary of
552 field_name -> error_list.
553
554 Based on validate() from django.db.models.Model in Django 0.96, which
555 was removed in Django 1.0. It should reappear in a later version. See:
556 http://code.djangoproject.com/ticket/6845
557 """
558 error_dict = {}
559 for f in self._meta.fields:
560 try:
561 python_value = f.to_python(
562 getattr(self, f.attname, f.get_default()))
563 except django.core.exceptions.ValidationError, e:
564 error_dict[f.name] = str(e.message)
565 continue
566
567 if not f.blank and not python_value:
568 error_dict[f.name] = 'This field is required.'
569 continue
570
571 setattr(self, f.attname, python_value)
572
573 return error_dict
574
575
jadmanski0afbb632008-06-06 21:10:57 +0000576 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000577 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000578 unique_errors = self.validate_unique()
579 for field_name, error in unique_errors.iteritems():
580 errors.setdefault(field_name, error)
581 if errors:
582 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000583
584
jadmanski0afbb632008-06-06 21:10:57 +0000585 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000586
jadmanski0afbb632008-06-06 21:10:57 +0000587 @classmethod
588 def add_object(cls, data={}, **kwargs):
589 """\
590 Returns a new object created with the given data (a dictionary
591 mapping field names to values). Merges any extra keyword args
592 into data.
593 """
594 data = cls.prepare_data_args(data, kwargs)
595 data = cls.provide_default_values(data)
596 obj = cls(**data)
597 obj.do_validate()
598 obj.save()
599 return obj
showard7c785282008-05-29 19:45:12 +0000600
601
jadmanski0afbb632008-06-06 21:10:57 +0000602 def update_object(self, data={}, **kwargs):
603 """\
604 Updates the object with the given data (a dictionary mapping
605 field names to values). Merges any extra keyword args into
606 data.
607 """
608 data = self.prepare_data_args(data, kwargs)
609 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000610 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000611 self.do_validate()
612 self.save()
showard7c785282008-05-29 19:45:12 +0000613
614
jadmanski0afbb632008-06-06 21:10:57 +0000615 @classmethod
showard7ac7b7a2008-07-21 20:24:29 +0000616 def query_objects(cls, filter_data, valid_only=True, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000617 """\
618 Returns a QuerySet object for querying the given model_class
619 with the given filter_data. Optional special arguments in
620 filter_data include:
621 -query_start: index of first return to return
622 -query_limit: maximum number of results to return
623 -sort_by: list of fields to sort on. prefixing a '-' onto a
624 field name changes the sort to descending order.
625 -extra_args: keyword args to pass to query.extra() (see Django
626 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000627 -extra_where: extra WHERE clause to append
jadmanski0afbb632008-06-06 21:10:57 +0000628 """
showardc0ac3a72009-07-08 21:14:45 +0000629 filter_data = dict(filter_data) # copy so we don't mutate the original
jadmanski0afbb632008-06-06 21:10:57 +0000630 query_start = filter_data.pop('query_start', None)
631 query_limit = filter_data.pop('query_limit', None)
632 if query_start and not query_limit:
633 raise ValueError('Cannot pass query_start without '
634 'query_limit')
showardc4780402009-08-31 18:31:34 +0000635 sort_by = filter_data.pop('sort_by', None)
jadmanski0afbb632008-06-06 21:10:57 +0000636 extra_args = filter_data.pop('extra_args', {})
637 extra_where = filter_data.pop('extra_where', None)
638 if extra_where:
showard1e935f12008-07-11 00:11:36 +0000639 # escape %'s
showardeaccf8f2009-04-16 03:11:33 +0000640 extra_where = cls.objects.escape_user_sql(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000641 extra_args.setdefault('where', []).append(extra_where)
showard7ac7b7a2008-07-21 20:24:29 +0000642 use_distinct = not filter_data.pop('no_distinct', False)
showard7c785282008-05-29 19:45:12 +0000643
showard7ac7b7a2008-07-21 20:24:29 +0000644 if initial_query is None:
645 if valid_only:
646 initial_query = cls.get_valid_manager()
647 else:
648 initial_query = cls.objects
649 query = initial_query.filter(**filter_data)
650 if use_distinct:
651 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000652
jadmanski0afbb632008-06-06 21:10:57 +0000653 # other arguments
654 if extra_args:
655 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000656 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000657
jadmanski0afbb632008-06-06 21:10:57 +0000658 # sorting + paging
showardc4780402009-08-31 18:31:34 +0000659 if sort_by:
660 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
661 query = query.order_by(*sort_by)
jadmanski0afbb632008-06-06 21:10:57 +0000662 if query_start is not None and query_limit is not None:
663 query_limit += query_start
664 return query[query_start:query_limit]
showard7c785282008-05-29 19:45:12 +0000665
666
jadmanski0afbb632008-06-06 21:10:57 +0000667 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000668 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000669 """\
670 Like query_objects, but retreive only the count of results.
671 """
672 filter_data.pop('query_start', None)
673 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000674 query = cls.query_objects(filter_data, initial_query=initial_query)
675 return query.count()
showard7c785282008-05-29 19:45:12 +0000676
677
jadmanski0afbb632008-06-06 21:10:57 +0000678 @classmethod
679 def clean_object_dicts(cls, field_dicts):
680 """\
681 Take a list of dicts corresponding to object (as returned by
682 query.values()) and clean the data to be more suitable for
683 returning to the user.
684 """
showarde732ee72008-09-23 19:15:43 +0000685 for field_dict in field_dicts:
686 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000687 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000688 cls.convert_human_readable_values(field_dict,
689 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000690
691
jadmanski0afbb632008-06-06 21:10:57 +0000692 @classmethod
showarde732ee72008-09-23 19:15:43 +0000693 def list_objects(cls, filter_data, initial_query=None, fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000694 """\
695 Like query_objects, but return a list of dictionaries.
696 """
showard7ac7b7a2008-07-21 20:24:29 +0000697 query = cls.query_objects(filter_data, initial_query=initial_query)
showarde732ee72008-09-23 19:15:43 +0000698 field_dicts = [model_object.get_object_dict(fields)
699 for model_object in query]
jadmanski0afbb632008-06-06 21:10:57 +0000700 return field_dicts
showard7c785282008-05-29 19:45:12 +0000701
702
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    Fetch a single object by ID or by name.

    smart_get(integer) -> get object by ID (primary key)
    smart_get(string)  -> get object by name_field

    valid_only -- when true (the default) look the object up through
                  get_valid_manager(); otherwise through cls.objects.
    Raises ValueError for an argument that is neither an integer nor a
    string. Lookup errors from manager.get() propagate unchanged.
    """
    manager = cls.get_valid_manager() if valid_only else cls.objects

    if isinstance(id_or_name, (int, long)):
        return manager.get(pk=id_or_name)
    if isinstance(id_or_name, basestring):
        lookup = {cls.name_field: id_or_name}
        return manager.get(**lookup)

    bad_type = type(id_or_name)
    raise ValueError(
        'Invalid positional argument: %s (%s)' % (id_or_name, bad_type))
showard7c785282008-05-29 19:45:12 +0000721
722
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """\
    Look up several objects at once using smart_get().

    id_or_name_list -- iterable of IDs (integers) and/or names (strings).
    Returns the matching objects as a list, in input order.
    Raises cls.DoesNotExist naming every input that matched no object, so
    the caller sees all bad inputs at once rather than only the first.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs; str.join() only accepts strings, so
        # stringify each one (otherwise a TypeError would replace the
        # intended DoesNotExist).
        missing = ', '.join(str(item) for item in invalid_inputs)
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
737
738
def get_object_dict(self, fields=None):
    """\
    Build and return a dict mapping field names to this object's values.

    fields -- iterable of field names to include; when None, every key of
              get_field_dict() is used.
    The dict is passed through clean_object_dicts() and then
    _postprocess_object_dict() before being returned.
    """
    names = self.get_field_dict().iterkeys() if fields is None else fields
    result = {}
    for name in names:
        result[name] = getattr(self, name)
    self.clean_object_dicts([result])
    self._postprocess_object_dict(result)
    return result
showard7c785282008-05-29 19:45:12 +0000750
751
def _postprocess_object_dict(self, object_dict):
    """\
    Hook for subclasses to adjust object_dict after it has been built and
    cleaned. Any return value is ignored by get_object_dict(), so changes
    must be made to object_dict in place. The default does nothing.
    """
755
756
@classmethod
def get_valid_manager(cls):
    """Return the manager used for lookups of valid objects (cls.objects)."""
    manager = cls.objects
    return manager
showard7c785282008-05-29 19:45:12 +0000760
761
def _record_attributes(self, attributes):
    """
    Snapshot the current values of the named attributes into
    self._recorded_attributes. See on_attribute_changed.
    """
    # A bare string would otherwise be iterated character by character.
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for name in attributes:
        snapshot[name] = getattr(self, name)
    self._recorded_attributes = snapshot
769
770
def _check_for_updated_attributes(self):
    """
    Call on_attribute_changed(name, old_value) for each recorded attribute
    whose current value differs from its snapshot, then re-snapshot the
    same set of attributes. See on_attribute_changed.
    """
    for name, old_value in self._recorded_attributes.iteritems():
        if old_value != getattr(self, name):
            self.on_attribute_changed(name, old_value)
    watched = self._recorded_attributes.keys()
    self._record_attributes(watched)
780
781
def on_attribute_changed(self, attribute, old_value):
    """
    Hook called when a watched attribute has changed value. The default
    does nothing; subclasses override it.

    attribute -- name of the attribute that changed.
    old_value -- the value it had when last recorded.

    To receive these notifications a subclass must:
      * call _record_attributes() from __init__() (after the super call),
        passing the names of the attributes to watch, and
      * call _check_for_updated_attributes() from save().
    """
793
794
class ModelWithInvalid(ModelExtensions):
    """
    Overrides save() and delete() so that deleting an object marks it invalid
    instead of removing it. Subclasses must have a boolean "invalid" field;
    the methods below also use the subclass's name_field and valid_objects.
    """

    def save(self, *args, **kwargs):
        if self.id is None:
            # A new object may share its name with one that was added earlier
            # and later invalidated; if so, adopt that object's id so the
            # save below targets the old row instead of a new one.
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.id = self.__class__.objects.get(**lookup).id
            except self.DoesNotExist:
                pass  # nothing was previously invalidated under this name

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def clean_object(self):
        """
        Hook run by delete() right after the object is marked invalid.
        Subclasses override it to remove relationships that should no longer
        exist once the object is treated as deleted. The default does nothing.
        """


    def delete(self):
        """Mark this object invalid (instead of deleting it), then clean up."""
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        Return cls.valid_objects (presumably a manager that excludes
        invalidated objects).
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Make the default manipulators consult only valid objects; otherwise
        their uniqueness checks would also match invalidated objects.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000848
849
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    Subclasses must implement _get_attribute_model_and_args().
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        Raises NotImplementedError when not overridden.
        """
        # NotImplementedError is the exception to raise here; the
        # NotImplemented constant is not an exception, and raising it only
        # produces an unrelated TypeError.
        raise NotImplementedError(
            '%s must override _get_attribute_model_and_args()'
            % type(self).__name__)


    def set_attribute(self, attribute, value):
        """
        Store value for the named attribute, creating the attribute row via
        get_or_create() if it does not exist yet, then save it.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """Delete the named attribute; silently does nothing if it is absent."""
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            pass  # already absent: deletion is best-effort by design


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute when value is None, otherwise set it."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)