blob: 204799f618366ac9cf5bbce73cadf1d7579426b1 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +00009import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000010from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000011from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000012
class ValidationError(Exception):
    """\
    Raised when adding or updating an object fails data validation.

    The value associated with the exception is a dictionary mapping field
    names to error strings.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Build a replacement for method that runs under the read-only connection.

    The returned callable calls set_django_connection() on the read-only
    connection, invokes method with the arguments it received, and always
    calls unset_django_connection() afterwards, even if method raises.

    @param method callable to wrap
    @returns the wrapping callable, carrying the same __name__ as method
    """
    def wrapper_method(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    wrapper_method.__name__ = method.__name__
    return wrapper_method
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Return name quoted by connection.ops.quote_name() (a shorthand)."""
    quoted = connection.ops.quote_name(name)
    return quoted
34
35
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so its first item is fetched under the
    read-only connection.

    Generators have to be wrapped specially: we assume the query is performed
    when the first item is requested, so only that first fetch is surrounded
    by set_django_connection() / unset_django_connection().

    @param generator generator function to wrap
    @returns a generator function with the same __name__ that yields exactly
            the items the wrapped generator yields
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = iter(generator(*args, **kwargs))
        readonly_connection.connection().set_django_connection()
        try:
            # Pull exactly one item with a for loop rather than calling
            # .next(): this works on every Python version and never lets a
            # StopIteration escape from this generator's body (which newer
            # Pythons turn into a RuntimeError).
            for first_value in generator_obj:
                break
            else:
                # the wrapped generator produced nothing at all
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        # the remaining items are passed through unchanged
        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Patch queryset in place so that every method that performs a database
    query runs with the read-only connection.

    @param queryset the queryset instance to patch
    """
    for method_name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                        'delete'):
        original_method = getattr(queryset, method_name)
        setattr(queryset, method_name, _wrap_with_readonly(original_method))

    # iterator() is a generator, so it needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries all go through the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        # wrap the query-performing methods of this particular instance
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but clones to a ReadonlyValuesQuerySet."""
        clone_options = dict(klass=ReadonlyValuesQuerySet,
                             setup=True,
                             _fields=fields)
        return self._clone(**clone_options)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose database queries all go through the read-only
    connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        # wrap the query-performing methods of this particular instance
        _make_queryset_readonly(self)
90
91
showard7c785282008-05-29 19:45:12 +000092class ExtendedManager(dbmodels.Manager):
jadmanski0afbb632008-06-06 21:10:57 +000093 """\
94 Extended manager supporting subquery filtering.
95 """
showard7c785282008-05-29 19:45:12 +000096
showardf828c772010-01-25 21:49:42 +000097 class CustomQuery(query.Query):
showard7e67b432010-01-20 01:13:04 +000098 def __init__(self, *args, **kwargs):
showardf828c772010-01-25 21:49:42 +000099 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
showard7e67b432010-01-20 01:13:04 +0000100 self._custom_joins = []
101
102
showarda5288b42009-07-28 20:06:08 +0000103 def clone(self, klass=None, **kwargs):
showardf828c772010-01-25 21:49:42 +0000104 obj = super(ExtendedManager.CustomQuery, self).clone(klass)
showard7e67b432010-01-20 01:13:04 +0000105 obj._custom_joins = list(self._custom_joins)
showarda5288b42009-07-28 20:06:08 +0000106 return obj
showard08f981b2008-06-24 21:59:03 +0000107
showard7e67b432010-01-20 01:13:04 +0000108
109 def combine(self, rhs, connector):
showardf828c772010-01-25 21:49:42 +0000110 super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
showard7e67b432010-01-20 01:13:04 +0000111 if hasattr(rhs, '_custom_joins'):
112 self._custom_joins.extend(rhs._custom_joins)
113
114
115 def add_custom_join(self, table, condition, join_type,
116 condition_values=(), alias=None):
117 if alias is None:
118 alias = table
119 join_dict = dict(table=table,
120 condition=condition,
121 condition_values=condition_values,
122 join_type=join_type,
123 alias=alias)
124 self._custom_joins.append(join_dict)
125
126
showard7e67b432010-01-20 01:13:04 +0000127 @classmethod
128 def convert_query(self, query_set):
129 """
showardf828c772010-01-25 21:49:42 +0000130 Convert the query set's "query" attribute to a CustomQuery.
showard7e67b432010-01-20 01:13:04 +0000131 """
132 # Make a copy of the query set
133 query_set = query_set.all()
134 query_set.query = query_set.query.clone(
showardf828c772010-01-25 21:49:42 +0000135 klass=ExtendedManager.CustomQuery,
showard7e67b432010-01-20 01:13:04 +0000136 _custom_joins=[])
137 return query_set
showard43a3d262008-11-12 18:17:05 +0000138
139
showard7e67b432010-01-20 01:13:04 +0000140 class _WhereClause(object):
141 """Object allowing us to inject arbitrary SQL into Django queries.
showard43a3d262008-11-12 18:17:05 +0000142
showard7e67b432010-01-20 01:13:04 +0000143 By using this instead of extra(where=...), we can still freely combine
144 queries with & and |.
showarda5288b42009-07-28 20:06:08 +0000145 """
showard7e67b432010-01-20 01:13:04 +0000146 def __init__(self, clause, values=()):
147 self._clause = clause
148 self._values = values
showarda5288b42009-07-28 20:06:08 +0000149
showard7e67b432010-01-20 01:13:04 +0000150
Dale Curtis74a314b2011-06-23 14:55:46 -0700151 def as_sql(self, qn=None, connection=None):
showard7e67b432010-01-20 01:13:04 +0000152 return self._clause, self._values
153
154
155 def relabel_aliases(self, change_map):
156 return
showard43a3d262008-11-12 18:17:05 +0000157
158
showard8b0ea222009-12-23 19:23:03 +0000159 def add_join(self, query_set, join_table, join_key, join_condition='',
showard7e67b432010-01-20 01:13:04 +0000160 join_condition_values=(), join_from_key=None, alias=None,
161 suffix='', exclude=False, force_left_join=False):
162 """Add a join to query_set.
163
164 Join looks like this:
165 (INNER|LEFT) JOIN <join_table> AS <alias>
166 ON (<this table>.<join_from_key> = <join_table>.<join_key>
167 and <join_condition>)
168
showard0957a842009-05-11 19:25:08 +0000169 @param join_table table to join to
170 @param join_key field referencing back to this model to use for the join
171 @param join_condition extra condition for the ON clause of the join
showard7e67b432010-01-20 01:13:04 +0000172 @param join_condition_values values to substitute into join_condition
173 @param join_from_key column on this model to join from.
showard8b0ea222009-12-23 19:23:03 +0000174 @param alias alias to use for for join
175 @param suffix suffix to add to join_table for the join alias, if no
176 alias is provided
showard0957a842009-05-11 19:25:08 +0000177 @param exclude if true, exclude rows that match this join (will use a
showarda5288b42009-07-28 20:06:08 +0000178 LEFT OUTER JOIN and an appropriate WHERE condition)
showardc4780402009-08-31 18:31:34 +0000179 @param force_left_join - if true, a LEFT OUTER JOIN will be used
180 instead of an INNER JOIN regardless of other options
showard0957a842009-05-11 19:25:08 +0000181 """
showard7e67b432010-01-20 01:13:04 +0000182 join_from_table = query_set.model._meta.db_table
183 if join_from_key is None:
184 join_from_key = self.model._meta.pk.name
185 if alias is None:
186 alias = join_table + suffix
187 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
188 full_join_condition = '%s = %s.%s' % (full_join_key,
189 _quote_name(join_from_table),
190 _quote_name(join_from_key))
showard43a3d262008-11-12 18:17:05 +0000191 if join_condition:
192 full_join_condition += ' AND (' + join_condition + ')'
193 if exclude or force_left_join:
showarda5288b42009-07-28 20:06:08 +0000194 join_type = query_set.query.LOUTER
showard43a3d262008-11-12 18:17:05 +0000195 else:
showarda5288b42009-07-28 20:06:08 +0000196 join_type = query_set.query.INNER
showard43a3d262008-11-12 18:17:05 +0000197
showardf828c772010-01-25 21:49:42 +0000198 query_set = self.CustomQuery.convert_query(query_set)
showard7e67b432010-01-20 01:13:04 +0000199 query_set.query.add_custom_join(join_table,
200 full_join_condition,
201 join_type,
202 condition_values=join_condition_values,
203 alias=alias)
showard43a3d262008-11-12 18:17:05 +0000204
showard7e67b432010-01-20 01:13:04 +0000205 if exclude:
206 query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
207
208 return query_set
209
210
211 def _info_for_many_to_one_join(self, field, join_to_query, alias):
212 """
213 @param field: the ForeignKey field on the related model
214 @param join_to_query: the query over the related model that we're
215 joining to
216 @param alias: alias of joined table
217 """
218 info = {}
219 rhs_table = join_to_query.model._meta.db_table
220 info['rhs_table'] = rhs_table
221 info['rhs_column'] = field.column
222 info['lhs_column'] = field.rel.get_related_field().column
223 rhs_where = join_to_query.query.where
224 rhs_where.relabel_aliases({rhs_table: alias})
Dale Curtis74a314b2011-06-23 14:55:46 -0700225 compiler = join_to_query.query.get_compiler(using=join_to_query.db)
226 initial_clause, values = compiler.as_sql()
227 all_clauses = (initial_clause,)
228 if hasattr(join_to_query.query, 'extra_where'):
229 all_clauses += join_to_query.query.extra_where
230 info['where_clause'] = (
231 ' AND '.join('(%s)' % clause for clause in all_clauses))
showard7e67b432010-01-20 01:13:04 +0000232 info['values'] = values
233 return info
234
235
236 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
237 m2m_is_on_this_model):
238 """
239 @param m2m_field: a Django field representing the M2M relationship.
240 It uses a pivot table with the following structure:
241 this model table <---> M2M pivot table <---> joined model table
242 @param join_to_query: the query over the related model that we're
243 joining to.
244 @param alias: alias of joined table
245 """
246 if m2m_is_on_this_model:
247 # referenced field on this model
248 lhs_id_field = self.model._meta.pk
249 # foreign key on the pivot table referencing lhs_id_field
250 m2m_lhs_column = m2m_field.m2m_column_name()
251 # foreign key on the pivot table referencing rhd_id_field
252 m2m_rhs_column = m2m_field.m2m_reverse_name()
253 # referenced field on related model
254 rhs_id_field = m2m_field.rel.get_related_field()
255 else:
256 lhs_id_field = m2m_field.rel.get_related_field()
257 m2m_lhs_column = m2m_field.m2m_reverse_name()
258 m2m_rhs_column = m2m_field.m2m_column_name()
259 rhs_id_field = join_to_query.model._meta.pk
260
261 info = {}
262 info['rhs_table'] = m2m_field.m2m_db_table()
263 info['rhs_column'] = m2m_lhs_column
264 info['lhs_column'] = lhs_id_field.column
265
266 # select the ID of related models relevant to this join. we can only do
267 # a single join, so we need to gather this information up front and
268 # include it in the join condition.
269 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
270 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
271 'match a single related object.')
272 rhs_id = rhs_ids[0]
273
274 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
275 _quote_name(m2m_rhs_column),
276 rhs_id)
277 info['values'] = ()
278 return info
279
280
281 def join_custom_field(self, query_set, join_to_query, alias,
282 left_join=True):
283 """Join to a related model to create a custom field in the given query.
284
285 This method is used to construct a custom field on the given query based
286 on a many-valued relationsip. join_to_query should be a simple query
287 (no joins) on the related model which returns at most one related row
288 per instance of this model.
289
290 For many-to-one relationships, the joined table contains the matching
291 row from the related model it one is related, NULL otherwise.
292
293 For many-to-many relationships, the joined table contains the matching
294 row if it's related, NULL otherwise.
295 """
296 relationship_type, field = self.determine_relationship(
297 join_to_query.model)
298
299 if relationship_type == self.MANY_TO_ONE:
300 info = self._info_for_many_to_one_join(field, join_to_query, alias)
301 elif relationship_type == self.M2M_ON_RELATED_MODEL:
302 info = self._info_for_many_to_many_join(
303 m2m_field=field, join_to_query=join_to_query, alias=alias,
304 m2m_is_on_this_model=False)
305 elif relationship_type ==self.M2M_ON_THIS_MODEL:
306 info = self._info_for_many_to_many_join(
307 m2m_field=field, join_to_query=join_to_query, alias=alias,
308 m2m_is_on_this_model=True)
309
310 return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
311 join_from_key=info['lhs_column'],
312 join_condition=info['where_clause'],
313 join_condition_values=info['values'],
314 alias=alias,
315 force_left_join=left_join)
316
317
showardf828c772010-01-25 21:49:42 +0000318 def key_on_joined_table(self, join_to_query):
319 """Get a non-null column on the table joined for the given query.
320
321 This analyzes the join that would be produced if join_to_query were
322 passed to join_custom_field.
323 """
324 relationship_type, field = self.determine_relationship(
325 join_to_query.model)
326 if relationship_type == self.MANY_TO_ONE:
327 return join_to_query.model._meta.pk.column
328 return field.m2m_column_name() # any column on the M2M table will do
329
330
showard7e67b432010-01-20 01:13:04 +0000331 def add_where(self, query_set, where, values=()):
332 query_set = query_set.all()
333 query_set.query.where.add(self._WhereClause(where, values),
334 django.db.models.sql.where.AND)
showardc4780402009-08-31 18:31:34 +0000335 return query_set
showard7c785282008-05-29 19:45:12 +0000336
337
showardeaccf8f2009-04-16 03:11:33 +0000338 def _get_quoted_field(self, table, field):
showarda5288b42009-07-28 20:06:08 +0000339 return _quote_name(table) + '.' + _quote_name(field)
showard5ef36e92008-07-02 16:37:09 +0000340
341
showard7c199df2008-10-03 10:17:15 +0000342 def get_key_on_this_table(self, key_field=None):
showard5ef36e92008-07-02 16:37:09 +0000343 if key_field is None:
344 # default to primary key
345 key_field = self.model._meta.pk.column
346 return self._get_quoted_field(self.model._meta.db_table, key_field)
347
348
showardeaccf8f2009-04-16 03:11:33 +0000349 def escape_user_sql(self, sql):
350 return sql.replace('%', '%%')
351
showard5ef36e92008-07-02 16:37:09 +0000352
showard0957a842009-05-11 19:25:08 +0000353 def _custom_select_query(self, query_set, selects):
Dale Curtis74a314b2011-06-23 14:55:46 -0700354 compiler = query_set.query.get_compiler(using=query_set.db)
355 sql, params = compiler.as_sql()
showarda5288b42009-07-28 20:06:08 +0000356 from_ = sql[sql.find(' FROM'):]
357
358 if query_set.query.distinct:
showard0957a842009-05-11 19:25:08 +0000359 distinct = 'DISTINCT '
360 else:
361 distinct = ''
showarda5288b42009-07-28 20:06:08 +0000362
363 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
showard0957a842009-05-11 19:25:08 +0000364 cursor = readonly_connection.connection().cursor()
365 cursor.execute(sql_query, params)
366 return cursor.fetchall()
367
368
showard68693f72009-05-20 00:31:53 +0000369 def _is_relation_to(self, field, model_class):
370 return field.rel and field.rel.to is model_class
showard0957a842009-05-11 19:25:08 +0000371
372
showard7e67b432010-01-20 01:13:04 +0000373 MANY_TO_ONE = object()
374 M2M_ON_RELATED_MODEL = object()
375 M2M_ON_THIS_MODEL = object()
376
377 def determine_relationship(self, related_model):
378 """
379 Determine the relationship between this model and related_model.
380
381 related_model must have some sort of many-valued relationship to this
382 manager's model.
383 @returns (relationship_type, field), where relationship_type is one of
384 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
385 is the Django field object for the relationship.
386 """
387 # look for a foreign key field on related_model relating to this model
388 for field in related_model._meta.fields:
389 if self._is_relation_to(field, self.model):
390 return self.MANY_TO_ONE, field
391
392 # look for an M2M field on related_model relating to this model
393 for field in related_model._meta.many_to_many:
394 if self._is_relation_to(field, self.model):
395 return self.M2M_ON_RELATED_MODEL, field
396
397 # maybe this model has the many-to-many field
398 for field in self.model._meta.many_to_many:
399 if self._is_relation_to(field, related_model):
400 return self.M2M_ON_THIS_MODEL, field
401
402 raise ValueError('%s has no relation to %s' %
403 (related_model, self.model))
404
405
showard68693f72009-05-20 00:31:53 +0000406 def _get_pivot_iterator(self, base_objects_by_id, related_model):
showard0957a842009-05-11 19:25:08 +0000407 """
showard68693f72009-05-20 00:31:53 +0000408 Determine the relationship between this model and related_model, and
409 return a pivot iterator.
410 @param base_objects_by_id: dict of instances of this model indexed by
411 their IDs
412 @returns a pivot iterator, which yields a tuple (base_object,
413 related_object) for each relationship between a base object and a
414 related object. all base_object instances come from base_objects_by_id.
showard7e67b432010-01-20 01:13:04 +0000415 Note -- this depends on Django model internals.
showard0957a842009-05-11 19:25:08 +0000416 """
showard7e67b432010-01-20 01:13:04 +0000417 relationship_type, field = self.determine_relationship(related_model)
418 if relationship_type == self.MANY_TO_ONE:
419 return self._many_to_one_pivot(base_objects_by_id,
420 related_model, field)
421 elif relationship_type == self.M2M_ON_RELATED_MODEL:
422 return self._many_to_many_pivot(
showard68693f72009-05-20 00:31:53 +0000423 base_objects_by_id, related_model, field.m2m_db_table(),
424 field.m2m_reverse_name(), field.m2m_column_name())
showard7e67b432010-01-20 01:13:04 +0000425 else:
426 assert relationship_type == self.M2M_ON_THIS_MODEL
427 return self._many_to_many_pivot(
showard68693f72009-05-20 00:31:53 +0000428 base_objects_by_id, related_model, field.m2m_db_table(),
429 field.m2m_column_name(), field.m2m_reverse_name())
showard0957a842009-05-11 19:25:08 +0000430
showard0957a842009-05-11 19:25:08 +0000431
showard68693f72009-05-20 00:31:53 +0000432 def _many_to_one_pivot(self, base_objects_by_id, related_model,
433 foreign_key_field):
434 """
435 @returns a pivot iterator - see _get_pivot_iterator()
436 """
437 filter_data = {foreign_key_field.name + '__pk__in':
438 base_objects_by_id.keys()}
439 for related_object in related_model.objects.filter(**filter_data):
showarda5a72c92009-08-20 23:35:21 +0000440 # lookup base object in the dict, rather than grabbing it from the
441 # related object. we need to return instances from the dict, not
442 # fresh instances of the same models (and grabbing model instances
443 # from the related models incurs a DB query each time).
444 base_object_id = getattr(related_object, foreign_key_field.attname)
445 base_object = base_objects_by_id[base_object_id]
showard68693f72009-05-20 00:31:53 +0000446 yield base_object, related_object
447
448
449 def _query_pivot_table(self, base_objects_by_id, pivot_table,
450 pivot_from_field, pivot_to_field):
showard0957a842009-05-11 19:25:08 +0000451 """
452 @param id_list list of IDs of self.model objects to include
453 @param pivot_table the name of the pivot table
454 @param pivot_from_field a field name on pivot_table referencing
455 self.model
456 @param pivot_to_field a field name on pivot_table referencing the
457 related model.
showard68693f72009-05-20 00:31:53 +0000458 @returns pivot list of IDs (base_id, related_id)
showard0957a842009-05-11 19:25:08 +0000459 """
460 query = """
461 SELECT %(from_field)s, %(to_field)s
462 FROM %(table)s
463 WHERE %(from_field)s IN (%(id_list)s)
464 """ % dict(from_field=pivot_from_field,
465 to_field=pivot_to_field,
466 table=pivot_table,
showard68693f72009-05-20 00:31:53 +0000467 id_list=','.join(str(id_) for id_
468 in base_objects_by_id.iterkeys()))
showard0957a842009-05-11 19:25:08 +0000469 cursor = readonly_connection.connection().cursor()
470 cursor.execute(query)
showard68693f72009-05-20 00:31:53 +0000471 return cursor.fetchall()
showard0957a842009-05-11 19:25:08 +0000472
473
showard68693f72009-05-20 00:31:53 +0000474 def _many_to_many_pivot(self, base_objects_by_id, related_model,
475 pivot_table, pivot_from_field, pivot_to_field):
476 """
477 @param pivot_table: see _query_pivot_table
478 @param pivot_from_field: see _query_pivot_table
479 @param pivot_to_field: see _query_pivot_table
480 @returns a pivot iterator - see _get_pivot_iterator()
481 """
482 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
483 pivot_from_field, pivot_to_field)
484
485 all_related_ids = list(set(related_id for base_id, related_id
486 in id_pivot))
487 related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
488
489 for base_id, related_id in id_pivot:
490 yield base_objects_by_id[base_id], related_objects_by_id[related_id]
491
492
493 def populate_relationships(self, base_objects, related_model,
showard0957a842009-05-11 19:25:08 +0000494 related_list_name):
495 """
showard68693f72009-05-20 00:31:53 +0000496 For each instance of this model in base_objects, add a field named
497 related_list_name listing all the related objects of type related_model.
498 related_model must be in a many-to-one or many-to-many relationship with
499 this model.
500 @param base_objects - list of instances of this model
501 @param related_model - model class related to this model
502 @param related_list_name - attribute name in which to store the related
503 object list.
showard0957a842009-05-11 19:25:08 +0000504 """
showard68693f72009-05-20 00:31:53 +0000505 if not base_objects:
showard0957a842009-05-11 19:25:08 +0000506 # if we don't bail early, we'll get a SQL error later
507 return
showard0957a842009-05-11 19:25:08 +0000508
showard68693f72009-05-20 00:31:53 +0000509 base_objects_by_id = dict((base_object._get_pk_val(), base_object)
510 for base_object in base_objects)
511 pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
512 related_model)
showard0957a842009-05-11 19:25:08 +0000513
showard68693f72009-05-20 00:31:53 +0000514 for base_object in base_objects:
515 setattr(base_object, related_list_name, [])
516
517 for base_object, related_object in pivot_iterator:
518 getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000519
520
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet that handles delete() properly for models with an "invalid" bit
    """
    def delete(self):
        """Delete by calling delete() on each object of the query set."""
        for instance in self:
            instance.delete()
528
529
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model."""
        queryset = ModelWithInvalidQuerySet(self.model)
        return queryset
536
537
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager returning only objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent manager's query set filtered to invalid=False."""
        all_objects = super(ValidObjectsManager, self).get_query_set()
        valid_objects = all_objects.filter(invalid=False)
        return valid_objects
showard7c785282008-05-29 19:45:12 +0000545
546
547class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000548 """\
549 Mixin with convenience functions for models, built on top of the
550 default Django model functions.
551 """
552 # TODO: at least some of these functions really belong in a custom
553 # Manager class
showard7c785282008-05-29 19:45:12 +0000554
jadmanski0afbb632008-06-06 21:10:57 +0000555 field_dict = None
556 # subclasses should override if they want to support smart_get() by name
557 name_field = None
showard7c785282008-05-29 19:45:12 +0000558
559
jadmanski0afbb632008-06-06 21:10:57 +0000560 @classmethod
561 def get_field_dict(cls):
562 if cls.field_dict is None:
563 cls.field_dict = {}
564 for field in cls._meta.fields:
565 cls.field_dict[field.name] = field
566 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000567
568
jadmanski0afbb632008-06-06 21:10:57 +0000569 @classmethod
570 def clean_foreign_keys(cls, data):
571 """\
572 -Convert foreign key fields in data from <field>_id to just
573 <field>.
574 -replace foreign key objects with their IDs
575 This method modifies data in-place.
576 """
577 for field in cls._meta.fields:
578 if not field.rel:
579 continue
580 if (field.attname != field.name and
581 field.attname in data):
582 data[field.name] = data[field.attname]
583 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000584 if field.name not in data:
585 continue
jadmanski0afbb632008-06-06 21:10:57 +0000586 value = data[field.name]
587 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000588 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000589
590
showard21baa452008-10-21 00:08:39 +0000591 @classmethod
592 def _convert_booleans(cls, data):
593 """
594 Ensure BooleanFields actually get bool values. The Django MySQL
595 backend returns ints for BooleanFields, which is almost always not
596 a problem, but it can be annoying in certain situations.
597 """
598 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000599 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000600 data[field.name] = bool(data[field.name])
601
602
jadmanski0afbb632008-06-06 21:10:57 +0000603 # TODO(showard) - is there a way to not have to do this?
604 @classmethod
605 def provide_default_values(cls, data):
606 """\
607 Provide default values for fields with default values which have
608 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000609
jadmanski0afbb632008-06-06 21:10:57 +0000610 For CharField and TextField fields with "blank=True", if nothing
611 is passed, we fill in an empty string value, even if there's no
612 default set.
613 """
614 new_data = dict(data)
615 field_dict = cls.get_field_dict()
616 for name, obj in field_dict.iteritems():
617 if data.get(name) is not None:
618 continue
619 if obj.default is not dbmodels.fields.NOT_PROVIDED:
620 new_data[name] = obj.default
621 elif (isinstance(obj, dbmodels.CharField) or
622 isinstance(obj, dbmodels.TextField)):
623 new_data[name] = ''
624 return new_data
showard7c785282008-05-29 19:45:12 +0000625
626
jadmanski0afbb632008-06-06 21:10:57 +0000627 @classmethod
628 def convert_human_readable_values(cls, data, to_human_readable=False):
629 """\
630 Performs conversions on user-supplied field data, to make it
631 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000632
jadmanski0afbb632008-06-06 21:10:57 +0000633 For all fields that have choice sets, convert their values
634 from human-readable strings to enum values, if necessary. This
635 allows users to pass strings instead of the corresponding
636 integer values.
showard7c785282008-05-29 19:45:12 +0000637
jadmanski0afbb632008-06-06 21:10:57 +0000638 For all foreign key fields, call smart_get with the supplied
639 data. This allows the user to pass either an ID value or
640 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000641
jadmanski0afbb632008-06-06 21:10:57 +0000642 If to_human_readable=True, perform the inverse - i.e. convert
643 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000644
jadmanski0afbb632008-06-06 21:10:57 +0000645 This method modifies data in-place.
646 """
647 field_dict = cls.get_field_dict()
648 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000649 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000650 continue
651 field_obj = field_dict[field_name]
652 # convert enum values
653 if field_obj.choices:
654 for choice_data in field_obj.choices:
655 # choice_data is (value, name)
656 if to_human_readable:
657 from_val, to_val = choice_data
658 else:
659 to_val, from_val = choice_data
660 if from_val == data[field_name]:
661 data[field_name] = to_val
662 break
663 # convert foreign key values
664 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000665 dest_obj = field_obj.rel.to.smart_get(data[field_name],
666 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000667 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800668 # parameterized_jobs do not have a name_field
669 if (field_name != 'parameterized_job' and
670 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000671 data[field_name] = getattr(dest_obj,
672 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000673 else:
showardb0a73032009-03-27 18:35:41 +0000674 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000675
676
jadmanski0afbb632008-06-06 21:10:57 +0000677 @classmethod
678 def validate_field_names(cls, data):
679 'Checks for extraneous fields in data.'
680 errors = {}
681 field_dict = cls.get_field_dict()
682 for field_name in data:
683 if field_name not in field_dict:
684 errors[field_name] = 'No field of this name'
685 return errors
showard7c785282008-05-29 19:45:12 +0000686
687
jadmanski0afbb632008-06-06 21:10:57 +0000688 @classmethod
689 def prepare_data_args(cls, data, kwargs):
690 'Common preparation for add_object and update_object'
691 data = dict(data) # don't modify the default keyword arg
692 data.update(kwargs)
693 # must check for extraneous field names here, while we have the
694 # data in a dict
695 errors = cls.validate_field_names(data)
696 if errors:
697 raise ValidationError(errors)
698 cls.convert_human_readable_values(data)
699 return data
showard7c785282008-05-29 19:45:12 +0000700
701
Dale Curtis74a314b2011-06-23 14:55:46 -0700702 def _validate_unique(self):
jadmanski0afbb632008-06-06 21:10:57 +0000703 """\
704 Validate that unique fields are unique. Django manipulators do
705 this too, but they're a huge pain to use manually. Trust me.
706 """
707 errors = {}
708 cls = type(self)
709 field_dict = self.get_field_dict()
710 manager = cls.get_valid_manager()
711 for field_name, field_obj in field_dict.iteritems():
712 if not field_obj.unique:
713 continue
showard7c785282008-05-29 19:45:12 +0000714
jadmanski0afbb632008-06-06 21:10:57 +0000715 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000716 if value is None and field_obj.auto_created:
717 # don't bother checking autoincrement fields about to be
718 # generated
719 continue
720
jadmanski0afbb632008-06-06 21:10:57 +0000721 existing_objs = manager.filter(**{field_name : value})
722 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000723
jadmanski0afbb632008-06-06 21:10:57 +0000724 if num_existing == 0:
725 continue
726 if num_existing == 1 and existing_objs[0].id == self.id:
727 continue
728 errors[field_name] = (
729 'This value must be unique (%s)' % (value))
730 return errors
showard7c785282008-05-29 19:45:12 +0000731
732
showarda5288b42009-07-28 20:06:08 +0000733 def _validate(self):
734 """
735 First coerces all fields on this instance to their proper Python types.
736 Then runs validation on every field. Returns a dictionary of
737 field_name -> error_list.
738
739 Based on validate() from django.db.models.Model in Django 0.96, which
740 was removed in Django 1.0. It should reappear in a later version. See:
741 http://code.djangoproject.com/ticket/6845
742 """
743 error_dict = {}
744 for f in self._meta.fields:
745 try:
746 python_value = f.to_python(
747 getattr(self, f.attname, f.get_default()))
748 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000749 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000750 continue
751
752 if not f.blank and not python_value:
753 error_dict[f.name] = 'This field is required.'
754 continue
755
756 setattr(self, f.attname, python_value)
757
758 return error_dict
759
760
def do_validate(self):
    """
    Run field validation followed by uniqueness validation on this object.

    Where both checks report the same field, the message from _validate()
    is the one kept.

    @raises ValidationError carrying the dict of field name -> error string
            if either check reports a problem.
    """
    combined = self._validate()
    for name, message in self._validate_unique().iteritems():
        if name not in combined:
            combined[name] = message
    if combined:
        raise ValidationError(combined)
showard7c785282008-05-29 19:45:12 +0000768
769
jadmanski0afbb632008-06-06 21:10:57 +0000770 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000771
@classmethod
def add_object(cls, data={}, **kwargs):
    """\
    Create, validate and save a new object.

    @param data: dict mapping field names to values.  It is handed to
            prepare_data_args(), which works on a copy.
    @param kwargs: extra field values merged into data.
    @returns the newly saved object.
    """
    field_values = cls.provide_default_values(
            cls.prepare_data_args(data, kwargs))
    new_object = cls(**field_values)
    new_object.do_validate()
    new_object.save()
    return new_object
showard7c785282008-05-29 19:45:12 +0000785
786
def update_object(self, data={}, **kwargs):
    """\
    Set the given field values on this object, then validate and save it.

    @param data: dict mapping field names to new values.  It is handed to
            prepare_data_args(), which works on a copy.
    @param kwargs: extra field values merged into data.
    """
    changes = self.prepare_data_args(data, kwargs)
    for attribute_name, new_value in changes.iteritems():
        setattr(self, attribute_name, new_value)
    self.do_validate()
    self.save()
showard7c785282008-05-29 19:45:12 +0000798
799
showard8bfb5cb2009-10-07 20:49:15 +0000800 # see query_objects()
801 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
802 'extra_args', 'extra_where', 'no_distinct')
803
804
jadmanski0afbb632008-06-06 21:10:57 +0000805 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000806 def _extract_special_params(cls, filter_data):
807 """
808 @returns a tuple of dicts (special_params, regular_filters), where
809 special_params contains the parameters we handle specially and
810 regular_filters is the remaining data to be handled by Django.
811 """
812 regular_filters = dict(filter_data)
813 special_params = {}
814 for key in cls._SPECIAL_FILTER_KEYS:
815 if key in regular_filters:
816 special_params[key] = regular_filters.pop(key)
817 return special_params, regular_filters
818
819
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply presentation parameters -- sorting and paging -- to the given
    query.

    @param query: query to sort and slice.
    @param filter_data: dict from which sort_by, query_start and query_limit
            are read; every other key is ignored here.
    @returns new query with presentation applied
    @raises ValueError if query_start is given without query_limit.
    """
    presentation = cls._extract_special_params(filter_data)[0]

    ordering = presentation.get('sort_by')
    if ordering:
        assert isinstance(ordering, (list, tuple))
        query = query.extra(order_by=ordering)

    first = presentation.get('query_start')
    last = presentation.get('query_limit')
    if first is not None:
        if last is None:
            raise ValueError('Cannot pass query_start without query_limit')
        # query_limit is passed as a page size, so the slice has to end that
        # many results past query_start
        last = last + first
    return query[first:last]
showard8bfb5cb2009-10-07 20:49:15 +0000841
842
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """\
    Returns a QuerySet object for querying the given model_class
    with the given filter_data.  Optional special arguments in
    filter_data include:
    -query_start: index of first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on.  prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    filter_data and the objects stored in it are not modified.

    @param valid_only: when no initial_query is given, start from
            get_valid_manager() if True and from cls.objects otherwise.
    @param initial_query: query or manager to filter; overrides valid_only.
    @param apply_presentation: if True, also apply sorting and paging via
            cls.apply_presentation().
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # Work on copies: extra_args (and its 'where' list) belong to the caller,
    # and appending to them in place would pile up duplicate WHERE clauses
    # each time the same filter_data is reused.
    extra_args = dict(special_params.get('extra_args') or {})
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = (list(extra_args.get('where') or [])
                               + [extra_where])
    if extra_args:
        query = query.extra(**extra_args)
        # queries that carry extra arguments are switched to the
        # ReadonlyQuerySet class
        query = query._clone(klass=ReadonlyQuerySet)

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000888
889
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but retrieve only the count of results.

    @param filter_data: filters as accepted by query_objects().  Any
            query_start/query_limit entries are ignored so that paging does
            not affect the count; the caller's dict is left unmodified.
    @param initial_query: passed through to query_objects().
    @returns the number of matching objects.
    """
    # strip the paging keys from a copy instead of popping them out of the
    # dict the caller handed us
    unpaged_filters = dict(filter_data)
    unpaged_filters.pop('query_start', None)
    unpaged_filters.pop('query_limit', None)
    query = cls.query_objects(unpaged_filters, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000899
900
@classmethod
def clean_object_dicts(cls, field_dicts):
    """\
    Clean up dicts describing objects (as returned by query.values()) so
    that they are more suitable for returning to the user.

    Each dict is passed in turn to clean_foreign_keys(),
    _convert_booleans() and convert_human_readable_values(); nothing is
    returned.

    @param field_dicts: list of dicts to clean.
    """
    for object_dict in field_dicts:
        cls.clean_foreign_keys(object_dict)
        cls._convert_booleans(object_dict)
        cls.convert_human_readable_values(
                object_dict, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000913
914
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but return a list of dictionaries, one per object
    (see get_object_dict()).  Columns selected through extra() are included
    in each dictionary.
    """
    matches = cls.query_objects(filter_data, initial_query=initial_query)
    extra_names = matches.query.extra_select.keys()
    object_dicts = []
    for model_object in matches:
        object_dicts.append(
                model_object.get_object_dict(extra_fields=extra_names))
    return object_dicts
showard7c785282008-05-29 19:45:12 +0000925
926
jadmanski0afbb632008-06-06 21:10:57 +0000927 @classmethod
showarda4ea5742009-02-17 20:56:23 +0000928 def smart_get(cls, id_or_name, valid_only=True):
jadmanski0afbb632008-06-06 21:10:57 +0000929 """\
930 smart_get(integer) -> get object by ID
931 smart_get(string) -> get object by name_field
jadmanski0afbb632008-06-06 21:10:57 +0000932 """
showarda4ea5742009-02-17 20:56:23 +0000933 if valid_only:
934 manager = cls.get_valid_manager()
935 else:
936 manager = cls.objects
937
938 if isinstance(id_or_name, (int, long)):
939 return manager.get(pk=id_or_name)
jamesren3e9f6092010-03-11 21:32:10 +0000940 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
showarda4ea5742009-02-17 20:56:23 +0000941 return manager.get(**{cls.name_field : id_or_name})
942 raise ValueError(
943 'Invalid positional argument: %s (%s)' % (id_or_name,
944 type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000945
946
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """
    Look up several objects at once with smart_get().

    @param id_or_name_list: iterable of IDs and/or names.
    @returns a list of the objects found, in input order.
    @raises cls.DoesNotExist naming every input that could not be found;
            all inputs are tried before the error is raised.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # inputs may be integer IDs, which join() would reject, so render
        # each one as text first
        missing = ', '.join(['%s' % (item,) for item in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
961
962
def get_object_dict(self, extra_fields=None):
    """\
    Return a dictionary mapping field names to this object's values.

    The dictionary is passed through clean_object_dicts() and then
    _postprocess_object_dict() before being returned.

    @param extra_fields: list of extra attribute names to include, in
            addition to the fields defined on this object.
    """
    names = list(self.get_field_dict().keys())
    if extra_fields:
        names.extend(extra_fields)
    result = {}
    for name in names:
        result[name] = getattr(self, name)
    self.clean_object_dicts([result])
    self._postprocess_object_dict(result)
    return result
showard7c785282008-05-29 19:45:12 +0000977
978
showardd3dc1992009-04-22 21:01:40 +0000979 def _postprocess_object_dict(self, object_dict):
980 """For subclasses to override."""
981 pass
982
983
@classmethod
def get_valid_manager(cls):
    """
    @returns the manager to use when only valid objects are wanted; in this
            implementation that is simply cls.objects.
    """
    manager = cls.objects
    return manager
showard7c785282008-05-29 19:45:12 +0000987
988
showard2bab8f42008-11-12 18:15:22 +0000989 def _record_attributes(self, attributes):
990 """
991 See on_attribute_changed.
992 """
993 assert not isinstance(attributes, basestring)
994 self._recorded_attributes = dict((attribute, getattr(self, attribute))
995 for attribute in attributes)
996
997
998 def _check_for_updated_attributes(self):
999 """
1000 See on_attribute_changed.
1001 """
1002 for attribute, original_value in self._recorded_attributes.iteritems():
1003 new_value = getattr(self, attribute)
1004 if original_value != new_value:
1005 self.on_attribute_changed(attribute, original_value)
1006 self._record_attributes(self._recorded_attributes.keys())
1007
1008
def on_attribute_changed(self, attribute, old_value):
    """
    Hook called whenever a recorded attribute is updated.  To be
    overridden; does nothing here.

    To use this method, you must:
    * call _record_attributes() from __init__() (after making the super
      call) with a list of the attributes for which you want to be notified
      upon change.
    * call _check_for_updated_attributes() from save().

    @param attribute: name of the attribute that changed.
    @param old_value: the value recorded before the change.
    """
    pass
1020
1021
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field.
    """

    def save(self, *args, **kwargs):
        """
        Save this object.  On a first save (no id yet), an invalidated
        object with the same name_field value is looked up and, if found,
        passed to resurrect_object() before saving.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            lookup = {self.name_field: getattr(self, self.name_field)}
            lookup['invalid'] = True
            try:
                previous = self.__class__.objects.get(**lookup)
                self.resurrect_object(previous)
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object.  Can be overridden
        by subclasses to copy data as desired from the deleted entry (but
        this superclass implementation must normally be called).
        """
        # taking over the old id makes the coming save reuse that entry
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Mark this object invalid instead of deleting it, save it and then
        call clean_object().  The object must not already be invalid.
        """
        # NOTE(review): this self-assignment leaves the value unchanged; its
        # purpose is not evident from this code -- confirm before removing.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """@returns cls.valid_objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001087
1088
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """
        Set the given attribute to value, creating the attribute object if it
        does not exist yet.
        """
        model, lookup = self._get_attribute_model_and_args(attribute)
        record, _created = model.objects.get_or_create(**lookup)
        record.value = value
        record.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute; a missing attribute is silently ignored.
        """
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the attribute if value is None, otherwise set it to value.
        """
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
showard26b7ec72009-12-21 22:43:57 +00001127
1128
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always refuses: callers must use get_or_create() instead."""
        message = ('ModelWithHash manager should use get_or_create() '
                   'instead of create()')
        raise Exception(message)


    def get_or_create(self, **kwargs):
        """
        Like the inherited get_or_create(), with the_hash set to the value
        the model's _compute_hash() returns for the given keyword args.
        """
        hashed_kwargs = dict(
                kwargs, the_hash=self.model._compute_hash(**kwargs))
        return super(ModelWithHashManager, self).get_or_create(
                **hashed_kwargs)
1140
1141
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # unique hash column; ModelWithHashManager.get_or_create() fills it in
    # from _compute_hash()
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given field values; must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work.  These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row.  However, the preferred way to make instances
        of these models is through the get_or_create() method.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the unique
            # constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
        else:
            raise Exception('ModelWithHash is immutable')