blob: 3b5c70e47aa083a370e19a183deead69d8b1c7c2 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
class ValidationError(Exception):
    """
    Raised when data for adding or updating an object fails validation.

    The exception's single argument is a dict mapping each offending field
    name to an error string.
    """
showard7c785282008-05-29 19:45:12 +000017
18
def _wrap_with_readonly(method):
    """
    Return a callable that invokes method with Django switched onto the
    read-only connection, restoring the normal connection afterwards even if
    method raises.

    @param method: callable to wrap (typically a bound QuerySet method)
    @returns wrapper with the same __name__ as method
    """
    def readonly_call(*positional, **keyword):
        readonly_connection.connection().set_django_connection()
        try:
            outcome = method(*positional, **keyword)
        finally:
            readonly_connection.connection().unset_django_connection()
        return outcome

    readonly_call.__name__ = method.__name__
    return readonly_call
showard09096d82008-07-07 23:20:49 +000028
29
def _quote_name(name):
    """
    Quote a table or column identifier for the active database backend.

    Thin shorthand for connection.ops.quote_name().
    """
    backend_ops = connection.ops
    return backend_ops.quote_name(name)
33
34
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that it runs against the read-only
    connection.

    We have to wrap generators specially: the read-only connection is only
    switched on while the first value is produced, on the assumption that
    the generator performs its query on the first call to next(). Later
    values are passed straight through.

    @param generator: generator function (or bound method) to wrap
    @returns a generator function with the same __name__ as generator
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = generator(*args, **kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            try:
                first_value = generator_obj.next()
            except StopIteration:
                # empty result: finish cleanly instead of letting
                # StopIteration escape from this generator's frame
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        # re-yield the remaining values; a for loop ends on exhaustion
        # without relying on StopIteration propagating out of this frame
        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
54
55
def _make_queryset_readonly(queryset):
    """
    Reroute the database-querying methods of queryset (in place) through the
    read-only connection.

    The plain query methods are wrapped with _wrap_with_readonly(), and the
    iterator() generator with _wrap_generator_with_readonly().
    """
    query_method_names = ('count', 'get', 'get_or_create', 'latest',
                          'in_bulk', 'delete')
    for name in query_method_names:
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
68
69
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods (see _make_queryset_readonly)
    run against the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """
        Like QuerySet.values(), but the returned ValuesQuerySet also uses
        the read-only connection.
        """
        clone_options = dict(setup=True, _fields=fields)
        return self._clone(klass=ReadonlyValuesQuerySet, **clone_options)
showard09096d82008-07-07 23:20:49 +000083
84
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its database-querying
    methods are rerouted to the read-only connection when it is created.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
89
90
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering: custom joins
    (add_join) and bulk population of related-object lists
    (populate_relationships).
    """

    class _CustomQuery(query.Query):
        """
        Query subclass that appends the custom joins held in its _customSqlQ
        attribute (an ExtendedManager._CustomSqlQ) to the FROM clause.
        """
        def clone(self, klass=None, **kwargs):
            """
            Clone this query. The clone gets its own copy of the custom
            join/WHERE data.

            @param klass: optional query class for the clone, as in the base
                    implementation
            @param _customSqlQ: optional _CustomSqlQ whose joins, WHERE
                    fragments and params are merged into the clone's copy
            @param kwargs: all other keyword args are forwarded to the base
                    clone(), which is relied on to set them on the new
                    object (as _add_customSqlQ relies on for a plain Query)
            @returns the cloned query
            """
            extra_custom_q = kwargs.pop('_customSqlQ', None)
            obj = super(ExtendedManager._CustomQuery, self).clone(klass,
                                                                  **kwargs)
            # copy instead of sharing, so merging below can never leak joins
            # back into this query or into sibling clones
            obj._customSqlQ = self._customSqlQ._copy()
            if extra_custom_q is not None:
                obj._customSqlQ._merge(extra_custom_q)
            return obj


        def get_from_clause(self):
            """
            Extend the base FROM clause with one
            ' <join_type> <table> AS <alias> ON (<condition>)' fragment per
            custom join, in insertion order, appended as a single element.
            @returns (from_, params) as from the base implementation
            """
            from_, params = super(
                ExtendedManager._CustomQuery, self).get_from_clause()

            join_clause = ''
            for join_alias, join in self._customSqlQ._joins.iteritems():
                join_table, join_type, condition = join
                join_clause += ' %s %s AS %s ON (%s)' % (
                    join_type, _quote_name(join_table),
                    _quote_name(join_alias), condition)

            if join_clause:
                from_.append(join_clause)

            return from_, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Q-like container for extra joins, WHERE fragments and their params.
        Passed to QuerySet.filter(), it contributes its WHERE fragments via
        add_to_query(); its joins are emitted by
        _CustomQuery.get_from_clause().
        """
        def __init__(self):
            # alias -> (table, join_type, condition), kept in insertion order
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """Register a join; the alias defaults to the table name."""
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=None):
            """Add a WHERE fragment and (optionally) its query params."""
            self._where.append(where)
            if params is not None:
                self._params.extend(params)


        def _merge(self, other):
            """Append other's joins, WHERE fragments and params to ours."""
            self._joins.update(other._joins)
            self._where.extend(other._where)
            self._params.extend(other._params)


        def _copy(self):
            """Return an independent copy of this object."""
            duplicate = self.__class__()
            duplicate._merge(self)
            return duplicate


        def add_to_query(self, query, aliases):
            """AND together the WHERE fragments and add them as an extra."""
            if self._where:
                where = ' AND '.join(self._where)
                query.add_extra(None, None, (where,), self._params, None, None)


    def _add_customSqlQ(self, query_set, filter_object):
        """\
        Add a _CustomSqlQ to the query set.
        @returns a new query set whose query is a _CustomQuery carrying
                filter_object's joins and filtered by its WHERE fragments
        """
        # Make a copy of the query set
        query_set = query_set.all()

        query_set.query = query_set.query.clone(
            ExtendedManager._CustomQuery, _customSqlQ=filter_object)
        return query_set.filter(filter_object)


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 alias=None, suffix='', exclude=False, force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new query set with the join applied
        """
        join_from_table = _quote_name(self.model._meta.db_table)
        # the SQL needs the pk's database column (as get_key_on_this_table
        # uses), which can differ from the field name
        join_from_key = _quote_name(self.model._meta.pk.column)
        if alias:
            join_alias = alias
        else:
            join_alias = join_table + suffix
        full_join_key = _quote_name(join_alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # with a LEFT OUTER JOIN, a NULL join key means "no matching row"
            filter_object.add_where(full_join_key + ' IS NULL')

        query_set = self._add_customSqlQ(query_set, filter_object)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return the quoted '<table>.<field>' reference."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        @param key_field: column on this model's table; defaults to the
                primary key column
        @returns the quoted '<table>.<column>' reference
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """
        Double every '%' in sql so it survives later %-style parameter
        substitution unchanged.
        """
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run query_set's SQL on the read-only connection with its SELECT list
        replaced by selects.
        @param query_set: query set supplying FROM/WHERE/etc.
        @param selects: list of SQL select expressions
        @returns all result rows (cursor.fetchall())
        """
        sql, params = query_set.query.as_sql()
        # reuse everything from the first ' FROM' on; assumes no selected
        # column of the original query contains ' FROM'
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """True (truthy) if field is a relation pointing at model_class."""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object. all base_object instances come from base_objects_by_id.
        @raises ValueError: if the models are not related
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object. we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model instances keyed by ID;
                only the keys are used. Must be non-empty (an empty IN ()
                list is invalid SQL).
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        base_ids = list(base_objects_by_id.iterkeys())
        # the IDs are passed as query parameters rather than pasted into the
        # SQL text, so non-integer primary keys are quoted correctly
        query = """
            SELECT %(from_field)s, %(to_field)s
            FROM %(table)s
            WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(['%s'] * len(base_ids)))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query, base_ids)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000362
363
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose default query set only contains objects with
    invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        valid_objects = all_objects.filter(invalid=False)
        return valid_objects
showard7c785282008-05-29 19:45:12 +0000371
372
373class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000374 """\
375 Mixin with convenience functions for models, built on top of the
376 default Django model functions.
377 """
378 # TODO: at least some of these functions really belong in a custom
379 # Manager class
showard7c785282008-05-29 19:45:12 +0000380
jadmanski0afbb632008-06-06 21:10:57 +0000381 field_dict = None
382 # subclasses should override if they want to support smart_get() by name
383 name_field = None
showard7c785282008-05-29 19:45:12 +0000384
385
jadmanski0afbb632008-06-06 21:10:57 +0000386 @classmethod
387 def get_field_dict(cls):
388 if cls.field_dict is None:
389 cls.field_dict = {}
390 for field in cls._meta.fields:
391 cls.field_dict[field.name] = field
392 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000393
394
jadmanski0afbb632008-06-06 21:10:57 +0000395 @classmethod
396 def clean_foreign_keys(cls, data):
397 """\
398 -Convert foreign key fields in data from <field>_id to just
399 <field>.
400 -replace foreign key objects with their IDs
401 This method modifies data in-place.
402 """
403 for field in cls._meta.fields:
404 if not field.rel:
405 continue
406 if (field.attname != field.name and
407 field.attname in data):
408 data[field.name] = data[field.attname]
409 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000410 if field.name not in data:
411 continue
jadmanski0afbb632008-06-06 21:10:57 +0000412 value = data[field.name]
413 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000414 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000415
416
showard21baa452008-10-21 00:08:39 +0000417 @classmethod
418 def _convert_booleans(cls, data):
419 """
420 Ensure BooleanFields actually get bool values. The Django MySQL
421 backend returns ints for BooleanFields, which is almost always not
422 a problem, but it can be annoying in certain situations.
423 """
424 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000425 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000426 data[field.name] = bool(data[field.name])
427
428
jadmanski0afbb632008-06-06 21:10:57 +0000429 # TODO(showard) - is there a way to not have to do this?
430 @classmethod
431 def provide_default_values(cls, data):
432 """\
433 Provide default values for fields with default values which have
434 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000435
jadmanski0afbb632008-06-06 21:10:57 +0000436 For CharField and TextField fields with "blank=True", if nothing
437 is passed, we fill in an empty string value, even if there's no
438 default set.
439 """
440 new_data = dict(data)
441 field_dict = cls.get_field_dict()
442 for name, obj in field_dict.iteritems():
443 if data.get(name) is not None:
444 continue
445 if obj.default is not dbmodels.fields.NOT_PROVIDED:
446 new_data[name] = obj.default
447 elif (isinstance(obj, dbmodels.CharField) or
448 isinstance(obj, dbmodels.TextField)):
449 new_data[name] = ''
450 return new_data
showard7c785282008-05-29 19:45:12 +0000451
452
jadmanski0afbb632008-06-06 21:10:57 +0000453 @classmethod
454 def convert_human_readable_values(cls, data, to_human_readable=False):
455 """\
456 Performs conversions on user-supplied field data, to make it
457 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000458
jadmanski0afbb632008-06-06 21:10:57 +0000459 For all fields that have choice sets, convert their values
460 from human-readable strings to enum values, if necessary. This
461 allows users to pass strings instead of the corresponding
462 integer values.
showard7c785282008-05-29 19:45:12 +0000463
jadmanski0afbb632008-06-06 21:10:57 +0000464 For all foreign key fields, call smart_get with the supplied
465 data. This allows the user to pass either an ID value or
466 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000467
jadmanski0afbb632008-06-06 21:10:57 +0000468 If to_human_readable=True, perform the inverse - i.e. convert
469 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000470
jadmanski0afbb632008-06-06 21:10:57 +0000471 This method modifies data in-place.
472 """
473 field_dict = cls.get_field_dict()
474 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000475 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000476 continue
477 field_obj = field_dict[field_name]
478 # convert enum values
479 if field_obj.choices:
480 for choice_data in field_obj.choices:
481 # choice_data is (value, name)
482 if to_human_readable:
483 from_val, to_val = choice_data
484 else:
485 to_val, from_val = choice_data
486 if from_val == data[field_name]:
487 data[field_name] = to_val
488 break
489 # convert foreign key values
490 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000491 dest_obj = field_obj.rel.to.smart_get(data[field_name],
492 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000493 if to_human_readable:
494 if dest_obj.name_field is not None:
495 data[field_name] = getattr(dest_obj,
496 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000497 else:
showardb0a73032009-03-27 18:35:41 +0000498 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000499
500
jadmanski0afbb632008-06-06 21:10:57 +0000501 @classmethod
502 def validate_field_names(cls, data):
503 'Checks for extraneous fields in data.'
504 errors = {}
505 field_dict = cls.get_field_dict()
506 for field_name in data:
507 if field_name not in field_dict:
508 errors[field_name] = 'No field of this name'
509 return errors
showard7c785282008-05-29 19:45:12 +0000510
511
jadmanski0afbb632008-06-06 21:10:57 +0000512 @classmethod
513 def prepare_data_args(cls, data, kwargs):
514 'Common preparation for add_object and update_object'
515 data = dict(data) # don't modify the default keyword arg
516 data.update(kwargs)
517 # must check for extraneous field names here, while we have the
518 # data in a dict
519 errors = cls.validate_field_names(data)
520 if errors:
521 raise ValidationError(errors)
522 cls.convert_human_readable_values(data)
523 return data
showard7c785282008-05-29 19:45:12 +0000524
525
jadmanski0afbb632008-06-06 21:10:57 +0000526 def validate_unique(self):
527 """\
528 Validate that unique fields are unique. Django manipulators do
529 this too, but they're a huge pain to use manually. Trust me.
530 """
531 errors = {}
532 cls = type(self)
533 field_dict = self.get_field_dict()
534 manager = cls.get_valid_manager()
535 for field_name, field_obj in field_dict.iteritems():
536 if not field_obj.unique:
537 continue
showard7c785282008-05-29 19:45:12 +0000538
jadmanski0afbb632008-06-06 21:10:57 +0000539 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000540 if value is None and field_obj.auto_created:
541 # don't bother checking autoincrement fields about to be
542 # generated
543 continue
544
jadmanski0afbb632008-06-06 21:10:57 +0000545 existing_objs = manager.filter(**{field_name : value})
546 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000547
jadmanski0afbb632008-06-06 21:10:57 +0000548 if num_existing == 0:
549 continue
550 if num_existing == 1 and existing_objs[0].id == self.id:
551 continue
552 errors[field_name] = (
553 'This value must be unique (%s)' % (value))
554 return errors
showard7c785282008-05-29 19:45:12 +0000555
556
showarda5288b42009-07-28 20:06:08 +0000557 def _validate(self):
558 """
559 First coerces all fields on this instance to their proper Python types.
560 Then runs validation on every field. Returns a dictionary of
561 field_name -> error_list.
562
563 Based on validate() from django.db.models.Model in Django 0.96, which
564 was removed in Django 1.0. It should reappear in a later version. See:
565 http://code.djangoproject.com/ticket/6845
566 """
567 error_dict = {}
568 for f in self._meta.fields:
569 try:
570 python_value = f.to_python(
571 getattr(self, f.attname, f.get_default()))
572 except django.core.exceptions.ValidationError, e:
573 error_dict[f.name] = str(e.message)
574 continue
575
576 if not f.blank and not python_value:
577 error_dict[f.name] = 'This field is required.'
578 continue
579
580 setattr(self, f.attname, python_value)
581
582 return error_dict
583
584
jadmanski0afbb632008-06-06 21:10:57 +0000585 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000586 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000587 unique_errors = self.validate_unique()
588 for field_name, error in unique_errors.iteritems():
589 errors.setdefault(field_name, error)
590 if errors:
591 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000592
593
jadmanski0afbb632008-06-06 21:10:57 +0000594 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000595
jadmanski0afbb632008-06-06 21:10:57 +0000596 @classmethod
597 def add_object(cls, data={}, **kwargs):
598 """\
599 Returns a new object created with the given data (a dictionary
600 mapping field names to values). Merges any extra keyword args
601 into data.
602 """
603 data = cls.prepare_data_args(data, kwargs)
604 data = cls.provide_default_values(data)
605 obj = cls(**data)
606 obj.do_validate()
607 obj.save()
608 return obj
showard7c785282008-05-29 19:45:12 +0000609
610
jadmanski0afbb632008-06-06 21:10:57 +0000611 def update_object(self, data={}, **kwargs):
612 """\
613 Updates the object with the given data (a dictionary mapping
614 field names to values). Merges any extra keyword args into
615 data.
616 """
617 data = self.prepare_data_args(data, kwargs)
618 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000619 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000620 self.do_validate()
621 self.save()
showard7c785282008-05-29 19:45:12 +0000622
623
showard8bfb5cb2009-10-07 20:49:15 +0000624 # see query_objects()
625 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
626 'extra_args', 'extra_where', 'no_distinct')
627
628
jadmanski0afbb632008-06-06 21:10:57 +0000629 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000630 def _extract_special_params(cls, filter_data):
631 """
632 @returns a tuple of dicts (special_params, regular_filters), where
633 special_params contains the parameters we handle specially and
634 regular_filters is the remaining data to be handled by Django.
635 """
636 regular_filters = dict(filter_data)
637 special_params = {}
638 for key in cls._SPECIAL_FILTER_KEYS:
639 if key in regular_filters:
640 special_params[key] = regular_filters.pop(key)
641 return special_params, regular_filters
642
643
644 @classmethod
645 def apply_presentation(cls, query, filter_data):
646 """
647 Apply presentation parameters -- sorting and paging -- to the given
648 query.
649 @returns new query with presentation applied
650 """
651 special_params, _ = cls._extract_special_params(filter_data)
652 sort_by = special_params.get('sort_by', None)
653 if sort_by:
654 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
showard8b0ea222009-12-23 19:23:03 +0000655 query = query.extra(order_by=sort_by)
showard8bfb5cb2009-10-07 20:49:15 +0000656
657 query_start = special_params.get('query_start', None)
658 query_limit = special_params.get('query_limit', None)
659 if query_start is not None:
660 if query_limit is None:
661 raise ValueError('Cannot pass query_start without query_limit')
662 # query_limit is passed as a page size
showard7074b742009-10-12 20:30:04 +0000663 query_limit += query_start
664 return query[query_start:query_limit]
showard8bfb5cb2009-10-07 20:49:15 +0000665
666
667 @classmethod
668 def query_objects(cls, filter_data, valid_only=True, initial_query=None,
669 apply_presentation=True):
jadmanski0afbb632008-06-06 21:10:57 +0000670 """\
671 Returns a QuerySet object for querying the given model_class
672 with the given filter_data. Optional special arguments in
673 filter_data include:
674 -query_start: index of first return to return
675 -query_limit: maximum number of results to return
676 -sort_by: list of fields to sort on. prefixing a '-' onto a
677 field name changes the sort to descending order.
678 -extra_args: keyword args to pass to query.extra() (see Django
679 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000680 -extra_where: extra WHERE clause to append
showard8bfb5cb2009-10-07 20:49:15 +0000681 -no_distinct: if True, a DISTINCT will not be added to the SELECT
jadmanski0afbb632008-06-06 21:10:57 +0000682 """
showard8bfb5cb2009-10-07 20:49:15 +0000683 special_params, regular_filters = cls._extract_special_params(
684 filter_data)
showard7c785282008-05-29 19:45:12 +0000685
showard7ac7b7a2008-07-21 20:24:29 +0000686 if initial_query is None:
687 if valid_only:
688 initial_query = cls.get_valid_manager()
689 else:
690 initial_query = cls.objects
showard8bfb5cb2009-10-07 20:49:15 +0000691
692 query = initial_query.filter(**regular_filters)
693
694 use_distinct = not special_params.get('no_distinct', False)
showard7ac7b7a2008-07-21 20:24:29 +0000695 if use_distinct:
696 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000697
showard8bfb5cb2009-10-07 20:49:15 +0000698 extra_args = special_params.get('extra_args', {})
699 extra_where = special_params.get('extra_where', None)
700 if extra_where:
701 # escape %'s
702 extra_where = cls.objects.escape_user_sql(extra_where)
703 extra_args.setdefault('where', []).append(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000704 if extra_args:
705 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000706 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000707
showard8bfb5cb2009-10-07 20:49:15 +0000708 if apply_presentation:
709 query = cls.apply_presentation(query, filter_data)
710
711 return query
showard7c785282008-05-29 19:45:12 +0000712
713
jadmanski0afbb632008-06-06 21:10:57 +0000714 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000715 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000716 """\
717 Like query_objects, but retreive only the count of results.
718 """
719 filter_data.pop('query_start', None)
720 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000721 query = cls.query_objects(filter_data, initial_query=initial_query)
722 return query.count()
showard7c785282008-05-29 19:45:12 +0000723
724
@classmethod
def clean_object_dicts(cls, field_dicts):
    """\
    Prepare a list of per-object dicts (as produced by query.values()) for
    returning to the user.

    Each dict is handed, in order, to clean_foreign_keys(),
    _convert_booleans() and convert_human_readable_values() with
    to_human_readable=True. Nothing is returned, so those helpers are
    expected to modify the dicts in place.
    """
    for object_fields in field_dicts:
        cls.clean_foreign_keys(object_fields)
        cls._convert_booleans(object_fields)
        cls.convert_human_readable_values(object_fields, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000737
738
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Run query_objects() and return its results as a list of dicts.

    Each dict comes from the object's get_object_dict(), which is also told
    the names of any extra-select columns present on the query.
    """
    query = cls.query_objects(filter_data, initial_query=initial_query)
    extra_select_names = query.query.extra_select.keys()
    return [obj.get_object_dict(extra_fields=extra_select_names)
            for obj in query]
showard7c785282008-05-29 19:45:12 +0000749
750
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    Fetch a single object by primary key or by name.

    An int/long argument is looked up as the primary key. A string argument
    is matched against the field named by cls.name_field. Any other type
    raises ValueError. When valid_only is true (the default), the lookup
    goes through get_valid_manager(); otherwise through cls.objects.
    Whatever the manager's get() raises propagates to the caller.
    """
    if not valid_only:
        manager = cls.objects
    else:
        manager = cls.get_valid_manager()

    if isinstance(id_or_name, (int, long)):
        lookup = {'pk': id_or_name}
    elif isinstance(id_or_name, basestring):
        lookup = {cls.name_field: id_or_name}
    else:
        raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                    type(id_or_name)))
    return manager.get(**lookup)
showard7c785282008-05-29 19:45:12 +0000769
770
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """\
    Look up several objects at once via smart_get().

    @param id_or_name_list: iterable of IDs and/or names (see smart_get).
    @return a list of the matching objects, in input order.
    @raises cls.DoesNotExist if any input has no matching object; the
            message lists every failing input, not just the first.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs, which str.join rejects with TypeError
        # (hiding the real DoesNotExist). '%s' formatting handles ints and
        # keeps unicode names intact under Python 2.
        invalid_text = ', '.join(['%s' % item for item in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), invalid_text))
    return result_objects
785
786
def get_object_dict(self, extra_fields=None):
    """\
    Build a dict mapping field names to this object's current values.

    The names come from get_field_dict(), plus any extra_fields. The dict is
    passed through clean_object_dicts() and then _postprocess_object_dict()
    before being returned.

    @param extra_fields: optional list of extra attribute names to include
            in addition to the fields defined on this object.
    """
    names = list(self.get_field_dict().keys())
    if extra_fields:
        names.extend(extra_fields)
    object_dict = {}
    for name in names:
        object_dict[name] = getattr(self, name)
    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000801
802
def _postprocess_object_dict(self, object_dict):
    """\
    Hook for subclasses. get_object_dict() calls it with the already-cleaned
    dict just before returning that dict. The base implementation does
    nothing.
    """
    return None
806
807
@classmethod
def get_valid_manager(cls):
    """\
    Return the manager used for "valid objects only" lookups, such as the
    default path of smart_get(). The base implementation returns
    cls.objects; subclasses may override it (ModelWithInvalid does).
    """
    default_manager = cls.objects
    return default_manager
showard7c785282008-05-29 19:45:12 +0000811
812
def _record_attributes(self, attributes):
    """
    Snapshot the current values of the named attributes into
    self._recorded_attributes, for later comparison by
    _check_for_updated_attributes(). See on_attribute_changed.

    @param attributes: iterable of attribute names. A bare string is
            rejected by an assert, presumably to catch callers who passed
            one name without wrapping it in a list.
    """
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for name in attributes:
        snapshot[name] = getattr(self, name)
    self._recorded_attributes = snapshot
820
821
def _check_for_updated_attributes(self):
    """
    Compare each attribute recorded by _record_attributes() with its current
    value. Call on_attribute_changed(name, old_value) for every attribute
    whose value differs, then re-record the same attribute names with their
    current values. See on_attribute_changed.
    """
    recorded = self._recorded_attributes
    for name, old_value in recorded.iteritems():
        if old_value != getattr(self, name):
            self.on_attribute_changed(name, old_value)
    # Re-read the attribute: a callback may have replaced the snapshot.
    self._record_attributes(self._recorded_attributes.keys())
831
832
def on_attribute_changed(self, attribute, old_value):
    """
    Hook called when a watched attribute's value has changed; subclasses
    override it. The base implementation does nothing.

    @param attribute: name of the attribute that changed.
    @param old_value: the value recorded before the change.

    To get these notifications, a model must:
      * call _record_attributes() in __init__() (after the super call) with
        the list of attribute names to watch, and
      * call _check_for_updated_attributes() from save().
    """
844
845
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion. Subclasses must have a boolean "invalid"
    field.

    Subclasses must also define name_field (save() uses it to find a
    previously invalidated row with the same name) and a valid_objects
    manager (returned by get_valid_manager()).
    """

    def save(self, *args, **kwargs):
        """
        Save this object, reusing an invalidated row with the same name.

        On the first save (self.id is None), look for an existing row with
        the same name_field value and invalid=True. If one exists,
        resurrect_object() copies its id onto self, so the save below
        targets that row's primary key instead of creating a new one.
        """
        first_time = (self.id is None)
        if first_time:
            # see if this object was previously added and invalidated
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
                # Note: a DoesNotExist raised inside resurrect_object() is
                # also swallowed by the except below.
                self.resurrect_object(old_object)
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object.  Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).

        The base implementation only adopts old_object's id.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        "Delete" by marking the object invalid and saving it; the row is kept.
        clean_object() is then called to tidy up relationships.

        The precondition check below is an assert, so it is skipped when
        Python runs with -O.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the valid_objects manager, which subclasses must define."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.

        NOTE(review): this class only derives from object, so the
        super()._prepare() call below works only if Django's legacy
        manipulator machinery mixes this class into a manipulator class
        that defines _prepare. Confirm this hook is still used by the
        Django version in use.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000910
911
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError always, in this base implementation.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # "raise NotImplemented" fails with a TypeError instead.
        raise NotImplementedError(
            'Subclasses must override _get_attribute_model_and_args()')


    def set_attribute(self, attribute, value):
        """
        Set the given attribute on this object to value, creating the
        attribute row (via get_or_create) if it does not exist yet.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute from this object. Does nothing if the
        attribute does not exist.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # Deliberate: deleting an absent attribute is a no-op.
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Call delete_attribute() if value is None, else set_attribute()."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)
showard26b7ec72009-12-21 22:43:57 +0000950
951
class ModelWithHashManager(dbmodels.Manager):
    """
    Manager for ModelWithHash subclasses.

    create() is disabled. Rows are obtained through get_or_create(), which
    adds the model-computed the_hash value to the lookup arguments.
    """

    def create(self, **kwargs):
        """Always raises; use get_or_create() instead."""
        message = ('ModelWithHash manager should use get_or_create() '
                   'instead of create()')
        raise Exception(message)


    def get_or_create(self, **kwargs):
        """
        Like Manager.get_or_create(), but also passes the_hash. The hash is
        computed by the model's _compute_hash() from the keyword arguments
        as given, and overrides any the_hash the caller passed.
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
963
964
class ModelWithHash(dbmodels.Model):
    """
    Abstract base for immutable models identified by a unique hash column.

    Subclasses implement _compute_hash(). Instances should come from
    objects.get_or_create(), whose manager fills in the_hash from the other
    lookup arguments.
    """

    # Unique, so inserting a duplicate hash is rejected by the database.
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the hash for the given field values; subclasses must override."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """
        Refuse ordinary saves, because these models are meant to be
        immutable. Create instances with model.objects.get_or_create()
        instead.

        save(force_insert=True) is still allowed, since it only inserts a new
        row. get_or_create() remains the preferred way to make instances.
        """
        if force_insert:
            # A duplicate insert is caught later by the unique constraint
            # on the_hash.
            super(ModelWithHash, self).save(force_insert=force_insert, **kwargs)
        else:
            raise Exception('ModelWithHash is immutable')