blob: d7b553d2766d545dae068c989718ce1d5c766f3c [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
import copy
import re
import sys

import django.core.exceptions
from django.db import models as dbmodels, backend, connection
from django.db.models.sql import query
import django.db.models.sql.where
from django.utils import datastructures

from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000012
class ValidationError(Exception):
    """
    Raised when the data for adding or updating an object is invalid.

    The exception carries a single argument: a dict that maps each offending
    field name to a human-readable error string.
    """
showard7c785282008-05-29 19:45:12 +000018
19
showard09096d82008-07-07 23:20:49 +000020def _wrap_with_readonly(method):
mbligh1ef218d2009-08-03 16:57:56 +000021 def wrapper_method(*args, **kwargs):
22 readonly_connection.connection().set_django_connection()
23 try:
24 return method(*args, **kwargs)
25 finally:
26 readonly_connection.connection().unset_django_connection()
27 wrapper_method.__name__ = method.__name__
28 return wrapper_method
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Quote an SQL identifier via the Django connection's ops.quote_name()."""
    ops = connection.ops
    return ops.quote_name(name)
34
35
showard09096d82008-07-07 23:20:49 +000036def _wrap_generator_with_readonly(generator):
37 """
38 We have to wrap generators specially. Assume it performs
39 the query on the first call to next().
40 """
41 def wrapper_generator(*args, **kwargs):
42 generator_obj = generator(*args, **kwargs)
showard56e93772008-10-06 10:06:22 +000043 readonly_connection.connection().set_django_connection()
showard09096d82008-07-07 23:20:49 +000044 try:
45 first_value = generator_obj.next()
46 finally:
showard56e93772008-10-06 10:06:22 +000047 readonly_connection.connection().unset_django_connection()
showard09096d82008-07-07 23:20:49 +000048 yield first_value
49
50 while True:
51 yield generator_obj.next()
52
53 wrapper_generator.__name__ = generator.__name__
54 return wrapper_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Replace, on this queryset instance, each method that hits the database
    with a version that runs on the read-only connection. iterator() is a
    generator, so it gets the generator-specific wrapper.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries are all performed on the read-only
    connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """
        Like QuerySet.values(), but the clone is a ReadonlyValuesQuerySet so
        it keeps using the read-only connection.
        """
        clone_options = {'setup': True, '_fields': fields}
        return self._clone(klass=ReadonlyValuesQuerySet, **clone_options)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """ValuesQuerySet counterpart of ReadonlyQuerySet."""
    def __init__(self, model=None, *args, **kwargs):
        parent = super(ReadonlyValuesQuerySet, self)
        parent.__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
90
91
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    On top of the standard Manager API it provides helpers for hand-written
    joins (add_join, join_custom_field), raw WHERE clauses (add_where),
    custom SELECTs on the read-only connection, and bulk population of
    related-object lists (populate_relationships).
    """

    class CustomQuery(query.Query):
        """
        Django Query that additionally renders a list of custom joins into
        its FROM clause.
        """
        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            # list of join dicts, as built by add_custom_join()
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """
            Clone this query, keeping its custom joins.

            kwargs are passed on to the parent clone() instead of being
            dropped, so callers that supply extra attributes for the copy
            still get them.
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                 **kwargs)
            # assigned after the parent clone so this query's joins always
            # carry over, even when kwargs contained _custom_joins
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combine with rhs (for & and |), also merging its custom joins."""
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """
            Record a join to be emitted by get_from_clause().

            @param table: name of the table to join to
            @param condition: SQL for the ON clause; may contain parameter
                    placeholders filled from condition_values
            @param join_type: join keyword(s), e.g. the INNER or LOUTER
                    constant of a Query, as add_join() passes
            @param condition_values: parameters for placeholders in condition
            @param alias: alias for the joined table; defaults to table
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        def get_from_clause(self):
            """
            Return Django's (from_, params) with one
            "<join_type> <table> AS <alias> ON (<condition>)" entry appended
            per custom join, and that join's condition values appended to
            params.
            """
            from_, params = (super(ExtendedManager.CustomQuery, self)
                             .get_from_clause())

            for join_dict in self._custom_joins:
                from_.append('%s %s AS %s ON (%s)'
                             % (join_dict['join_type'],
                                _quote_name(join_dict['table']),
                                _quote_name(join_dict['alias']),
                                join_dict['condition']))
                params.extend(join_dict['condition_values'])

            return from_, params


        @classmethod
        def convert_query(cls, query_set):
            """
            Return a copy of query_set whose "query" attribute is converted to
            this CustomQuery class. query_set itself is not modified. If its
            query already is a CustomQuery, the existing custom joins are
            kept (see clone()).
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(klass=cls,
                                                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            # qn/connection are accepted so this node can sit in a Django
            # where tree; they are unused because the clause is raw SQL
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            # raw SQL can't be relabelled; aliases inside it are left as-is
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set and return the resulting query set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from; defaults to
                the name of this model's primary key
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new query set; query_set itself is not modified
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            join_from_key = self.model._meta.pk.name
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # with a LEFT JOIN, unmatched rows have NULL in the joined key
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        Build the add_join() arguments for joining to a related model that
        has a foreign key to this model.

        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to. It is not modified.
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        # relabel a copy: relabel_aliases() works in place and would
        # otherwise rewrite the caller's query (breaking any reuse of it)
        rhs_where = copy.deepcopy(join_to_query.query.where)
        rhs_where.relabel_aliases({rhs_table: alias})
        initial_clause, values = rhs_where.as_sql()
        all_clauses = ((initial_clause,) +
                       tuple(join_to_query.query.extra_where))
        # an unfiltered query produces an empty clause; wrapping it as "()"
        # would be invalid SQL, so empty clauses are skipped
        info['where_clause'] = ' AND '.join('(%s)' % clause
                                            for clause in all_clauses
                                            if clause)
        values = list(values)
        values.extend(join_to_query.query.extra_params)
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        Build the add_join() arguments for joining to the pivot table of a
        many-to-many relationship.

        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to. It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                model, False if it is declared on the related model
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join. we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        # pass the ID as a query parameter, like the many-to-one join does
        # with its values, rather than pasting it into the SQL text
        info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
                                                 _quote_name(m2m_rhs_column))
        info['values'] = (rhs_id,)
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship. join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @returns a new query set with the join added (see add_join())
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def key_on_joined_table(self, join_to_query):
        """Get a non-null column on the table joined for the given query.

        This analyzes the join that would be produced if join_to_query were
        passed to join_custom_field.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)
        if relationship_type == self.MANY_TO_ONE:
            return join_to_query.model._meta.pk.column
        return field.m2m_column_name() # any column on the M2M table will do


    def add_where(self, query_set, where, values=()):
        """
        Return a copy of query_set with the raw SQL condition where (using
        parameters values) ANDed into its WHERE clause. Unlike
        extra(where=...), the result can still be combined with & and |.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return "<table>.<field>" with both names quoted."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted "<table>.<column>" for key_field on this model's
        table; key_field defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' in sql so it survives parameter substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run "SELECT [DISTINCT] <selects>" followed by everything from FROM
        onward in query_set's SQL, on the read-only connection.

        @param selects: list of SQL column expressions
        @returns all result rows, as returned by cursor.fetchall()
        """
        sql, params = query_set.query.as_sql()
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """True (truthy) if field is a relation pointing at model_class."""
        return field.rel and field.rel.to is model_class


    # relationship types returned by determine_relationship()
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no such relationship exists
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object. all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object. we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model objects indexed by ID;
                only the IDs (the keys) are used
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        sql = """
            SELECT %(from_field)s, %(to_field)s
            FROM %(table)s
            WHERE %(from_field)s IN (%(id_list)s)
            """ % dict(from_field=pivot_from_field,
                       to_field=pivot_to_field,
                       table=pivot_table,
                       id_list=','.join(str(id_) for id_
                                        in base_objects_by_id))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000531
532
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet for models with an "invalid" bit: delete() calls each model
    instance's own delete() instead of deleting the rows in bulk.
    """
    def delete(self):
        for instance in self:
            instance.delete()
540
541
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for models with an "invalid" bit; its query sets are
    ModelWithInvalidQuerySet instances.
    """
    def get_query_set(self):
        query_set_class = ModelWithInvalidQuerySet
        return query_set_class(self.model)
548
549
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager whose query sets only contain objects with invalid=False.
    """
    def get_query_set(self):
        base = super(ValidObjectsManager, self).get_query_set()
        return base.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000557
558
559class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000560 """\
561 Mixin with convenience functions for models, built on top of the
562 default Django model functions.
563 """
564 # TODO: at least some of these functions really belong in a custom
565 # Manager class
showard7c785282008-05-29 19:45:12 +0000566
jadmanski0afbb632008-06-06 21:10:57 +0000567 field_dict = None
568 # subclasses should override if they want to support smart_get() by name
569 name_field = None
showard7c785282008-05-29 19:45:12 +0000570
571
jadmanski0afbb632008-06-06 21:10:57 +0000572 @classmethod
573 def get_field_dict(cls):
574 if cls.field_dict is None:
575 cls.field_dict = {}
576 for field in cls._meta.fields:
577 cls.field_dict[field.name] = field
578 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000579
580
jadmanski0afbb632008-06-06 21:10:57 +0000581 @classmethod
582 def clean_foreign_keys(cls, data):
583 """\
584 -Convert foreign key fields in data from <field>_id to just
585 <field>.
586 -replace foreign key objects with their IDs
587 This method modifies data in-place.
588 """
589 for field in cls._meta.fields:
590 if not field.rel:
591 continue
592 if (field.attname != field.name and
593 field.attname in data):
594 data[field.name] = data[field.attname]
595 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000596 if field.name not in data:
597 continue
jadmanski0afbb632008-06-06 21:10:57 +0000598 value = data[field.name]
599 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000600 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000601
602
showard21baa452008-10-21 00:08:39 +0000603 @classmethod
604 def _convert_booleans(cls, data):
605 """
606 Ensure BooleanFields actually get bool values. The Django MySQL
607 backend returns ints for BooleanFields, which is almost always not
608 a problem, but it can be annoying in certain situations.
609 """
610 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000611 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000612 data[field.name] = bool(data[field.name])
613
614
jadmanski0afbb632008-06-06 21:10:57 +0000615 # TODO(showard) - is there a way to not have to do this?
616 @classmethod
617 def provide_default_values(cls, data):
618 """\
619 Provide default values for fields with default values which have
620 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000621
jadmanski0afbb632008-06-06 21:10:57 +0000622 For CharField and TextField fields with "blank=True", if nothing
623 is passed, we fill in an empty string value, even if there's no
624 default set.
625 """
626 new_data = dict(data)
627 field_dict = cls.get_field_dict()
628 for name, obj in field_dict.iteritems():
629 if data.get(name) is not None:
630 continue
631 if obj.default is not dbmodels.fields.NOT_PROVIDED:
632 new_data[name] = obj.default
633 elif (isinstance(obj, dbmodels.CharField) or
634 isinstance(obj, dbmodels.TextField)):
635 new_data[name] = ''
636 return new_data
showard7c785282008-05-29 19:45:12 +0000637
638
jadmanski0afbb632008-06-06 21:10:57 +0000639 @classmethod
640 def convert_human_readable_values(cls, data, to_human_readable=False):
641 """\
642 Performs conversions on user-supplied field data, to make it
643 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000644
jadmanski0afbb632008-06-06 21:10:57 +0000645 For all fields that have choice sets, convert their values
646 from human-readable strings to enum values, if necessary. This
647 allows users to pass strings instead of the corresponding
648 integer values.
showard7c785282008-05-29 19:45:12 +0000649
jadmanski0afbb632008-06-06 21:10:57 +0000650 For all foreign key fields, call smart_get with the supplied
651 data. This allows the user to pass either an ID value or
652 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000653
jadmanski0afbb632008-06-06 21:10:57 +0000654 If to_human_readable=True, perform the inverse - i.e. convert
655 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000656
jadmanski0afbb632008-06-06 21:10:57 +0000657 This method modifies data in-place.
658 """
659 field_dict = cls.get_field_dict()
660 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000661 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000662 continue
663 field_obj = field_dict[field_name]
664 # convert enum values
665 if field_obj.choices:
666 for choice_data in field_obj.choices:
667 # choice_data is (value, name)
668 if to_human_readable:
669 from_val, to_val = choice_data
670 else:
671 to_val, from_val = choice_data
672 if from_val == data[field_name]:
673 data[field_name] = to_val
674 break
675 # convert foreign key values
676 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000677 dest_obj = field_obj.rel.to.smart_get(data[field_name],
678 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000679 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800680 # parameterized_jobs do not have a name_field
681 if (field_name != 'parameterized_job' and
682 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000683 data[field_name] = getattr(dest_obj,
684 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000685 else:
showardb0a73032009-03-27 18:35:41 +0000686 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000687
688
jadmanski0afbb632008-06-06 21:10:57 +0000689 @classmethod
690 def validate_field_names(cls, data):
691 'Checks for extraneous fields in data.'
692 errors = {}
693 field_dict = cls.get_field_dict()
694 for field_name in data:
695 if field_name not in field_dict:
696 errors[field_name] = 'No field of this name'
697 return errors
showard7c785282008-05-29 19:45:12 +0000698
699
jadmanski0afbb632008-06-06 21:10:57 +0000700 @classmethod
701 def prepare_data_args(cls, data, kwargs):
702 'Common preparation for add_object and update_object'
703 data = dict(data) # don't modify the default keyword arg
704 data.update(kwargs)
705 # must check for extraneous field names here, while we have the
706 # data in a dict
707 errors = cls.validate_field_names(data)
708 if errors:
709 raise ValidationError(errors)
710 cls.convert_human_readable_values(data)
711 return data
showard7c785282008-05-29 19:45:12 +0000712
713
jadmanski0afbb632008-06-06 21:10:57 +0000714 def validate_unique(self):
715 """\
716 Validate that unique fields are unique. Django manipulators do
717 this too, but they're a huge pain to use manually. Trust me.
718 """
719 errors = {}
720 cls = type(self)
721 field_dict = self.get_field_dict()
722 manager = cls.get_valid_manager()
723 for field_name, field_obj in field_dict.iteritems():
724 if not field_obj.unique:
725 continue
showard7c785282008-05-29 19:45:12 +0000726
jadmanski0afbb632008-06-06 21:10:57 +0000727 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000728 if value is None and field_obj.auto_created:
729 # don't bother checking autoincrement fields about to be
730 # generated
731 continue
732
jadmanski0afbb632008-06-06 21:10:57 +0000733 existing_objs = manager.filter(**{field_name : value})
734 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000735
jadmanski0afbb632008-06-06 21:10:57 +0000736 if num_existing == 0:
737 continue
738 if num_existing == 1 and existing_objs[0].id == self.id:
739 continue
740 errors[field_name] = (
741 'This value must be unique (%s)' % (value))
742 return errors
showard7c785282008-05-29 19:45:12 +0000743
744
showarda5288b42009-07-28 20:06:08 +0000745 def _validate(self):
746 """
747 First coerces all fields on this instance to their proper Python types.
748 Then runs validation on every field. Returns a dictionary of
749 field_name -> error_list.
750
751 Based on validate() from django.db.models.Model in Django 0.96, which
752 was removed in Django 1.0. It should reappear in a later version. See:
753 http://code.djangoproject.com/ticket/6845
754 """
755 error_dict = {}
756 for f in self._meta.fields:
757 try:
758 python_value = f.to_python(
759 getattr(self, f.attname, f.get_default()))
760 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000761 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000762 continue
763
764 if not f.blank and not python_value:
765 error_dict[f.name] = 'This field is required.'
766 continue
767
768 setattr(self, f.attname, python_value)
769
770 return error_dict
771
772
jadmanski0afbb632008-06-06 21:10:57 +0000773 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000774 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000775 unique_errors = self.validate_unique()
776 for field_name, error in unique_errors.iteritems():
777 errors.setdefault(field_name, error)
778 if errors:
779 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000780
781
jadmanski0afbb632008-06-06 21:10:57 +0000782 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000783
jadmanski0afbb632008-06-06 21:10:57 +0000784 @classmethod
785 def add_object(cls, data={}, **kwargs):
786 """\
787 Returns a new object created with the given data (a dictionary
788 mapping field names to values). Merges any extra keyword args
789 into data.
790 """
791 data = cls.prepare_data_args(data, kwargs)
792 data = cls.provide_default_values(data)
793 obj = cls(**data)
794 obj.do_validate()
795 obj.save()
796 return obj
showard7c785282008-05-29 19:45:12 +0000797
798
jadmanski0afbb632008-06-06 21:10:57 +0000799 def update_object(self, data={}, **kwargs):
800 """\
801 Updates the object with the given data (a dictionary mapping
802 field names to values). Merges any extra keyword args into
803 data.
804 """
805 data = self.prepare_data_args(data, kwargs)
806 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000807 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000808 self.do_validate()
809 self.save()
showard7c785282008-05-29 19:45:12 +0000810
811
showard8bfb5cb2009-10-07 20:49:15 +0000812 # see query_objects()
813 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
814 'extra_args', 'extra_where', 'no_distinct')
815
816
jadmanski0afbb632008-06-06 21:10:57 +0000817 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000818 def _extract_special_params(cls, filter_data):
819 """
820 @returns a tuple of dicts (special_params, regular_filters), where
821 special_params contains the parameters we handle specially and
822 regular_filters is the remaining data to be handled by Django.
823 """
824 regular_filters = dict(filter_data)
825 special_params = {}
826 for key in cls._SPECIAL_FILTER_KEYS:
827 if key in regular_filters:
828 special_params[key] = regular_filters.pop(key)
829 return special_params, regular_filters
830
831
832 @classmethod
833 def apply_presentation(cls, query, filter_data):
834 """
835 Apply presentation parameters -- sorting and paging -- to the given
836 query.
837 @returns new query with presentation applied
838 """
839 special_params, _ = cls._extract_special_params(filter_data)
840 sort_by = special_params.get('sort_by', None)
841 if sort_by:
842 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
showard8b0ea222009-12-23 19:23:03 +0000843 query = query.extra(order_by=sort_by)
showard8bfb5cb2009-10-07 20:49:15 +0000844
845 query_start = special_params.get('query_start', None)
846 query_limit = special_params.get('query_limit', None)
847 if query_start is not None:
848 if query_limit is None:
849 raise ValueError('Cannot pass query_start without query_limit')
850 # query_limit is passed as a page size
showard7074b742009-10-12 20:30:04 +0000851 query_limit += query_start
852 return query[query_start:query_limit]
showard8bfb5cb2009-10-07 20:49:15 +0000853
854
855 @classmethod
856 def query_objects(cls, filter_data, valid_only=True, initial_query=None,
857 apply_presentation=True):
jadmanski0afbb632008-06-06 21:10:57 +0000858 """\
859 Returns a QuerySet object for querying the given model_class
860 with the given filter_data. Optional special arguments in
861 filter_data include:
862 -query_start: index of first return to return
863 -query_limit: maximum number of results to return
864 -sort_by: list of fields to sort on. prefixing a '-' onto a
865 field name changes the sort to descending order.
866 -extra_args: keyword args to pass to query.extra() (see Django
867 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000868 -extra_where: extra WHERE clause to append
showard8bfb5cb2009-10-07 20:49:15 +0000869 -no_distinct: if True, a DISTINCT will not be added to the SELECT
jadmanski0afbb632008-06-06 21:10:57 +0000870 """
showard8bfb5cb2009-10-07 20:49:15 +0000871 special_params, regular_filters = cls._extract_special_params(
872 filter_data)
showard7c785282008-05-29 19:45:12 +0000873
showard7ac7b7a2008-07-21 20:24:29 +0000874 if initial_query is None:
875 if valid_only:
876 initial_query = cls.get_valid_manager()
877 else:
878 initial_query = cls.objects
showard8bfb5cb2009-10-07 20:49:15 +0000879
880 query = initial_query.filter(**regular_filters)
881
882 use_distinct = not special_params.get('no_distinct', False)
showard7ac7b7a2008-07-21 20:24:29 +0000883 if use_distinct:
884 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000885
showard8bfb5cb2009-10-07 20:49:15 +0000886 extra_args = special_params.get('extra_args', {})
887 extra_where = special_params.get('extra_where', None)
888 if extra_where:
889 # escape %'s
890 extra_where = cls.objects.escape_user_sql(extra_where)
891 extra_args.setdefault('where', []).append(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000892 if extra_args:
893 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000894 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000895
showard8bfb5cb2009-10-07 20:49:15 +0000896 if apply_presentation:
897 query = cls.apply_presentation(query, filter_data)
898
899 return query
showard7c785282008-05-29 19:45:12 +0000900
901
jadmanski0afbb632008-06-06 21:10:57 +0000902 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000903 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000904 """\
905 Like query_objects, but retreive only the count of results.
906 """
907 filter_data.pop('query_start', None)
908 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000909 query = cls.query_objects(filter_data, initial_query=initial_query)
910 return query.count()
showard7c785282008-05-29 19:45:12 +0000911
912
jadmanski0afbb632008-06-06 21:10:57 +0000913 @classmethod
914 def clean_object_dicts(cls, field_dicts):
915 """\
916 Take a list of dicts corresponding to object (as returned by
917 query.values()) and clean the data to be more suitable for
918 returning to the user.
919 """
showarde732ee72008-09-23 19:15:43 +0000920 for field_dict in field_dicts:
921 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000922 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000923 cls.convert_human_readable_values(field_dict,
924 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000925
926
jadmanski0afbb632008-06-06 21:10:57 +0000927 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000928 def list_objects(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000929 """\
930 Like query_objects, but return a list of dictionaries.
931 """
showard7ac7b7a2008-07-21 20:24:29 +0000932 query = cls.query_objects(filter_data, initial_query=initial_query)
showard8bfb5cb2009-10-07 20:49:15 +0000933 extra_fields = query.query.extra_select.keys()
934 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
showarde732ee72008-09-23 19:15:43 +0000935 for model_object in query]
jadmanski0afbb632008-06-06 21:10:57 +0000936 return field_dicts
showard7c785282008-05-29 19:45:12 +0000937
938
jadmanski0afbb632008-06-06 21:10:57 +0000939 @classmethod
showarda4ea5742009-02-17 20:56:23 +0000940 def smart_get(cls, id_or_name, valid_only=True):
jadmanski0afbb632008-06-06 21:10:57 +0000941 """\
942 smart_get(integer) -> get object by ID
943 smart_get(string) -> get object by name_field
jadmanski0afbb632008-06-06 21:10:57 +0000944 """
showarda4ea5742009-02-17 20:56:23 +0000945 if valid_only:
946 manager = cls.get_valid_manager()
947 else:
948 manager = cls.objects
949
950 if isinstance(id_or_name, (int, long)):
951 return manager.get(pk=id_or_name)
jamesren3e9f6092010-03-11 21:32:10 +0000952 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
showarda4ea5742009-02-17 20:56:23 +0000953 return manager.get(**{cls.name_field : id_or_name})
954 raise ValueError(
955 'Invalid positional argument: %s (%s)' % (id_or_name,
956 type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000957
958
showardbe3ec042008-11-12 18:16:07 +0000959 @classmethod
960 def smart_get_bulk(cls, id_or_name_list):
961 invalid_inputs = []
962 result_objects = []
963 for id_or_name in id_or_name_list:
964 try:
965 result_objects.append(cls.smart_get(id_or_name))
966 except cls.DoesNotExist:
967 invalid_inputs.append(id_or_name)
968 if invalid_inputs:
mbligh7a3ebe32008-12-01 17:10:33 +0000969 raise cls.DoesNotExist('The following %ss do not exist: %s'
970 % (cls.__name__.lower(),
971 ', '.join(invalid_inputs)))
showardbe3ec042008-11-12 18:16:07 +0000972 return result_objects
973
974
showard8bfb5cb2009-10-07 20:49:15 +0000975 def get_object_dict(self, extra_fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000976 """\
showard8bfb5cb2009-10-07 20:49:15 +0000977 Return a dictionary mapping fields to this object's values. @param
978 extra_fields: list of extra attribute names to include, in addition to
979 the fields defined on this object.
jadmanski0afbb632008-06-06 21:10:57 +0000980 """
showard8bfb5cb2009-10-07 20:49:15 +0000981 fields = self.get_field_dict().keys()
982 if extra_fields:
983 fields += extra_fields
jadmanski0afbb632008-06-06 21:10:57 +0000984 object_dict = dict((field_name, getattr(self, field_name))
showarde732ee72008-09-23 19:15:43 +0000985 for field_name in fields)
jadmanski0afbb632008-06-06 21:10:57 +0000986 self.clean_object_dicts([object_dict])
showardd3dc1992009-04-22 21:01:40 +0000987 self._postprocess_object_dict(object_dict)
jadmanski0afbb632008-06-06 21:10:57 +0000988 return object_dict
showard7c785282008-05-29 19:45:12 +0000989
990
showardd3dc1992009-04-22 21:01:40 +0000991 def _postprocess_object_dict(self, object_dict):
992 """For subclasses to override."""
993 pass
994
995
jadmanski0afbb632008-06-06 21:10:57 +0000996 @classmethod
997 def get_valid_manager(cls):
998 return cls.objects
showard7c785282008-05-29 19:45:12 +0000999
1000
showard2bab8f42008-11-12 18:15:22 +00001001 def _record_attributes(self, attributes):
1002 """
1003 See on_attribute_changed.
1004 """
1005 assert not isinstance(attributes, basestring)
1006 self._recorded_attributes = dict((attribute, getattr(self, attribute))
1007 for attribute in attributes)
1008
1009
1010 def _check_for_updated_attributes(self):
1011 """
1012 See on_attribute_changed.
1013 """
1014 for attribute, original_value in self._recorded_attributes.iteritems():
1015 new_value = getattr(self, attribute)
1016 if original_value != new_value:
1017 self.on_attribute_changed(attribute, original_value)
1018 self._record_attributes(self._recorded_attributes.keys())
1019
1020
1021 def on_attribute_changed(self, attribute, old_value):
1022 """
1023 Called whenever an attribute is updated. To be overridden.
1024
1025 To use this method, you must:
1026 * call _record_attributes() from __init__() (after making the super
1027 call) with a list of attributes for which you want to be notified upon
1028 change.
1029 * call _check_for_updated_attributes() from save().
1030 """
1031 pass
1032
1033
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field; save() also uses the name_field attribute, and
    get_valid_manager() returns the valid_objects manager.
    """

    def save(self, *args, **kwargs):
        """
        Save this object.  On the first save (id is None), look for an
        invalidated row with the same name.  If one exists, hand it to
        resurrect_object() before saving, so this object takes over that row
        instead of adding a new one.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Hook called by save() when self is being saved for the first time
        but is really "undeleting" the invalidated old_object.  This
        implementation copies old_object's id onto self.  Subclasses may
        override it to copy more data from the deleted entry, but should
        normally still call this superclass implementation.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Called by delete() once the object has been marked invalid and
        saved.  Subclasses should override this to clean up relationships
        that should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Invalidate this object instead of deleting its row: set invalid,
        save, then call clean_object().  The object must not already be
        invalid.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the valid_objects manager rather than objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.

        NOTE(review): this appears to target the pre-1.0 Django manipulator
        machinery.  Its base class is plain object, which defines no
        _prepare(), so the super() call would raise AttributeError if this
        were ever invoked.  Confirm whether this is dead code.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001098
1099
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses must override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError if not overridden.
        """
        # NotImplementedError, not NotImplemented: raising the NotImplemented
        # constant is itself a TypeError and hides the real problem.
        raise NotImplementedError(
                '%s must override _get_attribute_model_and_args()'
                % type(self).__name__)


    def set_attribute(self, attribute, value):
        """Create the attribute row if needed and store value in it."""
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """Delete the attribute row; a missing row is silently ignored."""
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute when value is None, otherwise set it."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)
showard26b7ec72009-12-21 22:43:57 +00001138
1139
class ModelWithHashManager(dbmodels.Manager):
    """
    Manager for use with the ModelWithHash abstract model class.

    Plain create() is refused.  get_or_create() fills in the_hash from the
    model's _compute_hash() before delegating to the normal implementation.
    """

    def create(self, **kwargs):
        """Always raises: use get_or_create() so the_hash gets filled in."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """
        Set kwargs['the_hash'] to self.model._compute_hash(**kwargs), so the
        hash is part of the lookup, then defer to Manager.get_or_create().
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1151
1152
class ModelWithHash(dbmodels.Model):
    """
    Abstract superclass for models identified by a hash column.

    Subclasses implement _compute_hash().  Instances are meant to be
    obtained through objects.get_or_create(), whose manager stores the
    computed hash in the unique the_hash column.
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """
        Return the hash string for a row with the given field values.
        Subclasses must override this.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevent saving the model in most cases.

        These models are meant to be immutable, so a generic save() is
        refused.  Instances should be created through the
        model.objects.get_or_create() method instead.

        The exception is save(force_insert=True), which is allowed because
        it only creates a new row.  A duplicate is still caught by the unique
        the_hash constraint.  get_or_create() remains the preferred way to
        create instances.
        """
        if force_insert:
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
        else:
            raise Exception('ModelWithHash is immutable')