blob: cefdae507e4f64bb1c84f289eb7e243fed3f4857 [file] [log] [blame]
"""
Extensions to Django's model logic.
"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +00009import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000010from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000011from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000012
class ValidationError(Exception):
    """
    Raised when adding or updating an object fails data validation.

    The value associated with the exception is a dictionary mapping field
    names to error strings.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Return a callable that runs method between set_django_connection() and
    unset_django_connection() calls on the readonly connection.

    unset_django_connection() is called even when method raises.  The
    returned callable is given the same __name__ as method.
    """
    def readonly_call(*call_args, **call_kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*call_args, **call_kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_call.__name__ = method.__name__
    return readonly_call
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Return name as quoted by connection.ops.quote_name()."""
    quote = connection.ops.quote_name
    return quote(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first item is fetched between
    set_django_connection() and unset_django_connection() calls on the
    readonly connection.

    Generators need special treatment because nothing runs until the first
    call to next(); the query is assumed to be performed on that first call.
    All later items are passed through without switching connections.
    """
    def readonly_generator(*call_args, **call_kwargs):
        source = generator(*call_args, **call_kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            head = source.next()
        finally:
            readonly_connection.connection().unset_django_connection()
        yield head

        # hand the remaining items through one at a time; the loop ends when
        # the wrapped generator is exhausted
        while True:
            yield source.next()

    readonly_generator.__name__ = generator.__name__
    return readonly_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Replace the query-running methods of queryset with readonly wrappers.

    count, get, get_or_create, latest, in_bulk and delete are wrapped with
    _wrap_with_readonly(); iterator is wrapped with
    _wrap_generator_with_readonly().  queryset is modified in place.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        bound_method = getattr(queryset, name)
        setattr(queryset, name, _wrap_with_readonly(bound_method))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods are wrapped by
    _make_queryset_readonly() so they use the read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        # wrap the bound methods of this particular instance
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Return a ReadonlyValuesQuerySet clone limited to fields."""
        clone_options = dict(klass=ReadonlyValuesQuerySet, setup=True,
                             _fields=fields)
        return self._clone(**clone_options)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose database-querying methods are wrapped by
    _make_queryset_readonly().
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        # wrap the bound methods of this particular instance
        _make_queryset_readonly(self)
90
91
class LeasedHostManager(dbmodels.Manager):
    """Query manager for unleased, unlocked hosts."""
    def get_query_set(self):
        """Return the parent query set filtered to leased=0 and locked=0."""
        all_rows = super(LeasedHostManager, self).get_query_set()
        return all_rows.filter(leased=0, locked=0)
98
99
class ExtendedManager(dbmodels.Manager):
    """
    Extended manager supporting subquery filtering.
    """

    class CustomQuery(query.Query):
        """Query subclass that records a list of custom join descriptions."""
        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clone this query, giving the clone its own copy of the joins."""
            # NOTE(review): kwargs is accepted but not forwarded to the parent
            # clone(); kept as-is -- confirm this is intentional.
            obj = super(ExtendedManager.CustomQuery, self).clone(klass)
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combine with rhs, also taking over any custom joins it has."""
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """
            Record a custom join on this query.

            @param table: name of the table to join to
            @param condition: SQL text of the join condition
            @param join_type: join type constant for the join
            @param condition_values: values to substitute into condition
            @param alias: alias for the joined table; defaults to table
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @returns a copy of query_set whose query is a CustomQuery
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(
                    klass=ExtendedManager.CustomQuery,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            """Return the stored clause and values; arguments are ignored."""
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            """Do nothing; the stored clause is never rewritten."""
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
        (INNER|LEFT) JOIN <join_table> AS <alias>
            ON (<this table>.<join_from_key> = <join_table>.<join_key>
                and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from.
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
        LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
        instead of an INNER JOIN regardless of other options
        @returns a copy of query_set with the join added
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            join_from_key = self.model._meta.pk.name
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # rows with no match on the joined table have a NULL join key
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns a dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values describing the join
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        rhs_where = join_to_query.query.where
        # note that this relabels the where clause of join_to_query in place
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        # NOTE(review): compiler.as_sql() looks like it compiles the whole
        # query rather than only its WHERE clause -- confirm that the result
        # is valid inside a join condition.
        initial_clause, values = compiler.as_sql()
        all_clauses = (initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if it is declared on the related model
        @returns a dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values describing the join
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join.  we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column),
                                               rhs_id)
        info['values'] = ()
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship.  join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @param query_set: query set over this model to add the join to
        @param join_to_query: query over the related model
        @param alias: alias of joined table
        @param left_join: passed to add_join() as force_left_join
        @returns a copy of query_set with the join added
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        else:
            # same shape as _get_pivot_iterator(): anything else must be the
            # remaining relationship type, so info is always bound below
            assert relationship_type == self.M2M_ON_THIS_MODEL
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def key_on_joined_table(self, join_to_query):
        """Get a non-null column on the table joined for the given query.

        This analyzes the join that would be produced if join_to_query were
        passed to join_custom_field.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)
        if relationship_type == self.MANY_TO_ONE:
            return join_to_query.model._meta.pk.column
        return field.m2m_column_name() # any column on the M2M table will do


    def add_where(self, query_set, where, values=()):
        """
        AND an arbitrary SQL condition onto a copy of query_set.

        @param where: SQL text of the condition
        @param values: values to substitute into where
        @returns the modified copy of query_set
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return '<table>.<field>' with both names quoted."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted, table-qualified name of a column on this model's
        table.

        @param key_field: column name; defaults to the primary key column
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Return sql with every '%' doubled to '%%'."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run the SQL of query_set with its select list replaced by selects.

        Everything from ' FROM' onwards in the compiled SQL is kept, the
        query is executed through the readonly connection, and all rows are
        returned.

        @param selects: list of SQL expressions to select
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Return a true value if field has a relation to model_class."""
        return field.rel and field.rel.to is model_class


    # relationship types returned by determine_relationship()
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no such relationship is found
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @param foreign_key_field: the field on related_model that refers to
                this model
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict whose keys are the IDs of the
                self.model objects to include
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_ in base_objects_by_id))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        # fetch every related object with a single in_bulk() call
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        # every base object gets a list, even if it has no related objects
        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
528
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet that handles delete() properly for models with an "invalid" bit
    """
    def delete(self):
        """Call delete() on each object in this query set, one at a time."""
        for instance in self:
            instance.delete()
536
537
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model."""
        query_set = ModelWithInvalidQuerySet(self.model)
        return query_set
544
545
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager returning only objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent manager's query set filtered to invalid=False."""
        return super(ValidObjectsManager, self).get_query_set().filter(
                invalid=False)
showard7c785282008-05-29 19:45:12 +0000553
554
555class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000556 """\
557 Mixin with convenience functions for models, built on top of the
558 default Django model functions.
559 """
560 # TODO: at least some of these functions really belong in a custom
561 # Manager class
showard7c785282008-05-29 19:45:12 +0000562
jadmanski0afbb632008-06-06 21:10:57 +0000563 field_dict = None
564 # subclasses should override if they want to support smart_get() by name
565 name_field = None
showard7c785282008-05-29 19:45:12 +0000566
567
jadmanski0afbb632008-06-06 21:10:57 +0000568 @classmethod
569 def get_field_dict(cls):
570 if cls.field_dict is None:
571 cls.field_dict = {}
572 for field in cls._meta.fields:
573 cls.field_dict[field.name] = field
574 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000575
576
jadmanski0afbb632008-06-06 21:10:57 +0000577 @classmethod
578 def clean_foreign_keys(cls, data):
579 """\
580 -Convert foreign key fields in data from <field>_id to just
581 <field>.
582 -replace foreign key objects with their IDs
583 This method modifies data in-place.
584 """
585 for field in cls._meta.fields:
586 if not field.rel:
587 continue
588 if (field.attname != field.name and
589 field.attname in data):
590 data[field.name] = data[field.attname]
591 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000592 if field.name not in data:
593 continue
jadmanski0afbb632008-06-06 21:10:57 +0000594 value = data[field.name]
595 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000596 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000597
598
showard21baa452008-10-21 00:08:39 +0000599 @classmethod
600 def _convert_booleans(cls, data):
601 """
602 Ensure BooleanFields actually get bool values. The Django MySQL
603 backend returns ints for BooleanFields, which is almost always not
604 a problem, but it can be annoying in certain situations.
605 """
606 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000607 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000608 data[field.name] = bool(data[field.name])
609
610
jadmanski0afbb632008-06-06 21:10:57 +0000611 # TODO(showard) - is there a way to not have to do this?
612 @classmethod
613 def provide_default_values(cls, data):
614 """\
615 Provide default values for fields with default values which have
616 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000617
jadmanski0afbb632008-06-06 21:10:57 +0000618 For CharField and TextField fields with "blank=True", if nothing
619 is passed, we fill in an empty string value, even if there's no
620 default set.
621 """
622 new_data = dict(data)
623 field_dict = cls.get_field_dict()
624 for name, obj in field_dict.iteritems():
625 if data.get(name) is not None:
626 continue
627 if obj.default is not dbmodels.fields.NOT_PROVIDED:
628 new_data[name] = obj.default
629 elif (isinstance(obj, dbmodels.CharField) or
630 isinstance(obj, dbmodels.TextField)):
631 new_data[name] = ''
632 return new_data
showard7c785282008-05-29 19:45:12 +0000633
634
jadmanski0afbb632008-06-06 21:10:57 +0000635 @classmethod
636 def convert_human_readable_values(cls, data, to_human_readable=False):
637 """\
638 Performs conversions on user-supplied field data, to make it
639 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000640
jadmanski0afbb632008-06-06 21:10:57 +0000641 For all fields that have choice sets, convert their values
642 from human-readable strings to enum values, if necessary. This
643 allows users to pass strings instead of the corresponding
644 integer values.
showard7c785282008-05-29 19:45:12 +0000645
jadmanski0afbb632008-06-06 21:10:57 +0000646 For all foreign key fields, call smart_get with the supplied
647 data. This allows the user to pass either an ID value or
648 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000649
jadmanski0afbb632008-06-06 21:10:57 +0000650 If to_human_readable=True, perform the inverse - i.e. convert
651 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000652
jadmanski0afbb632008-06-06 21:10:57 +0000653 This method modifies data in-place.
654 """
655 field_dict = cls.get_field_dict()
656 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000657 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000658 continue
659 field_obj = field_dict[field_name]
660 # convert enum values
661 if field_obj.choices:
662 for choice_data in field_obj.choices:
663 # choice_data is (value, name)
664 if to_human_readable:
665 from_val, to_val = choice_data
666 else:
667 to_val, from_val = choice_data
668 if from_val == data[field_name]:
669 data[field_name] = to_val
670 break
671 # convert foreign key values
672 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000673 dest_obj = field_obj.rel.to.smart_get(data[field_name],
674 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000675 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800676 # parameterized_jobs do not have a name_field
677 if (field_name != 'parameterized_job' and
678 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000679 data[field_name] = getattr(dest_obj,
680 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000681 else:
showardb0a73032009-03-27 18:35:41 +0000682 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000683
684
jadmanski0afbb632008-06-06 21:10:57 +0000685 @classmethod
686 def validate_field_names(cls, data):
687 'Checks for extraneous fields in data.'
688 errors = {}
689 field_dict = cls.get_field_dict()
690 for field_name in data:
691 if field_name not in field_dict:
692 errors[field_name] = 'No field of this name'
693 return errors
showard7c785282008-05-29 19:45:12 +0000694
695
jadmanski0afbb632008-06-06 21:10:57 +0000696 @classmethod
697 def prepare_data_args(cls, data, kwargs):
698 'Common preparation for add_object and update_object'
699 data = dict(data) # don't modify the default keyword arg
700 data.update(kwargs)
701 # must check for extraneous field names here, while we have the
702 # data in a dict
703 errors = cls.validate_field_names(data)
704 if errors:
705 raise ValidationError(errors)
706 cls.convert_human_readable_values(data)
707 return data
showard7c785282008-05-29 19:45:12 +0000708
709
def _validate_unique(self):
    """\
    Validate that unique fields are unique.  Django manipulators do
    this too, but they're a huge pain to use manually.  Trust me.

    For every field flagged unique, the valid manager is queried for
    objects holding the same value; a single match with this object's own
    id is not a conflict.

    @returns a dict mapping field names to error strings (empty if there
            are no conflicts).
    """
    errors = {}
    manager = type(self).get_valid_manager()
    for field_name, field_obj in self.get_field_dict().iteritems():
        if not field_obj.unique:
            continue

        value = getattr(self, field_name)
        if value is None and field_obj.auto_created:
            # don't bother checking autoincrement fields about to be
            # generated
            continue

        existing_objs = manager.filter(**{field_name: value})
        num_existing = existing_objs.count()

        if num_existing == 0:
            continue
        if num_existing == 1 and existing_objs[0].id == self.id:
            # the only match is this very object
            continue
        # format with a real 1-tuple so a tuple-valued field cannot be
        # mistaken for the argument list of the % operator
        errors[field_name] = 'This value must be unique (%s)' % (value,)
    return errors
showard7c785282008-05-29 19:45:12 +0000739
740
showarda5288b42009-07-28 20:06:08 +0000741 def _validate(self):
742 """
743 First coerces all fields on this instance to their proper Python types.
744 Then runs validation on every field. Returns a dictionary of
745 field_name -> error_list.
746
747 Based on validate() from django.db.models.Model in Django 0.96, which
748 was removed in Django 1.0. It should reappear in a later version. See:
749 http://code.djangoproject.com/ticket/6845
750 """
751 error_dict = {}
752 for f in self._meta.fields:
753 try:
754 python_value = f.to_python(
755 getattr(self, f.attname, f.get_default()))
756 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000757 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000758 continue
759
760 if not f.blank and not python_value:
761 error_dict[f.name] = 'This field is required.'
762 continue
763
764 setattr(self, f.attname, python_value)
765
766 return error_dict
767
768
def do_validate(self):
    """
    Run field validation followed by uniqueness validation.

    A uniqueness error is only recorded for a field that has no field
    validation error already.

    @raises ValidationError carrying the combined error dict, if any
            error was found.
    """
    errors = self._validate()
    for name, message in self._validate_unique().iteritems():
        if name not in errors:
            errors[name] = message
    if errors:
        raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000776
777
jadmanski0afbb632008-06-06 21:10:57 +0000778 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000779
@classmethod
def add_object(cls, data={}, **kwargs):
    """\
    Create, validate and save a new object.

    @param data: dict mapping field names to values (never modified).
    @param kwargs: extra field values, merged into data.
    @returns the newly saved object.
    @raises ValidationError if the field names or values are invalid.
    """
    prepared = cls.provide_default_values(
            cls.prepare_data_args(data, kwargs))
    new_object = cls(**prepared)
    new_object.do_validate()
    new_object.save()
    return new_object
showard7c785282008-05-29 19:45:12 +0000793
794
def update_object(self, data={}, **kwargs):
    """\
    Assign the given field values to this object, then validate and save.

    @param data: dict mapping field names to values (never modified).
    @param kwargs: extra field values, merged into data.
    @raises ValidationError if the field names or values are invalid.
    """
    updates = self.prepare_data_args(data, kwargs)
    for name in updates:
        setattr(self, name, updates[name])
    self.do_validate()
    self.save()
showard7c785282008-05-29 19:45:12 +0000806
807
# Keys of filter_data that _extract_special_params() separates from the
# ordinary Django filters.  See query_objects() for what each one means.
_SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
                        'extra_args', 'extra_where', 'no_distinct')
811
812
@classmethod
def _extract_special_params(cls, filter_data):
    """
    Split filter_data into specially handled parameters and plain filters.

    filter_data itself is not modified; the split works on a copy.

    @returns a tuple of dicts (special_params, regular_filters), where
            special_params holds the entries named in _SPECIAL_FILTER_KEYS
            and regular_filters is the remaining data to be handled by
            Django.
    """
    regular_filters = dict(filter_data)
    special_params = dict((key, regular_filters.pop(key))
                          for key in cls._SPECIAL_FILTER_KEYS
                          if key in regular_filters)
    return special_params, regular_filters
826
827
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply presentation parameters -- sorting and paging -- to the given
    query, using the sort_by, query_start and query_limit entries of
    filter_data.

    @returns new query with presentation applied
    @raises ValueError if query_start is given without query_limit.
    """
    presentation = cls._extract_special_params(filter_data)[0]

    ordering = presentation.get('sort_by')
    if ordering:
        assert isinstance(ordering, (list, tuple))
        query = query.extra(order_by=ordering)

    first = presentation.get('query_start')
    limit = presentation.get('query_limit')
    if first is not None:
        if limit is None:
            raise ValueError('Cannot pass query_start without query_limit')
        # query_limit is a page size, so turn it into an end index
        limit = limit + first
    return query[first:limit]
showard8bfb5cb2009-10-07 20:49:15 +0000849
850
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """\
    Returns a QuerySet object for querying the given model_class
    with the given filter_data.  Optional special arguments in
    filter_data include:
    -query_start: index of first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on.  prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    Neither filter_data nor the extra_args dict inside it is modified.

    @param valid_only: when no initial_query is given, start from
            get_valid_manager() if True, from cls.objects otherwise.
    @param initial_query: optional manager/query to filter instead.
    @param apply_presentation: if True, sorting and paging are applied
            through cls.apply_presentation().
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # The extra_args dict (and any 'where' list in it) belongs to the
    # caller.  Work on copies so that reusing the same filter_data for
    # several queries does not pile up duplicate WHERE clauses.
    extra_args = dict(special_params.get('extra_args', {}))
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = (list(extra_args.get('where', []))
                               + [extra_where])
    if extra_args:
        query = query.extra(**extra_args)
        # NOTE(review): nesting of this line under "if extra_args" is
        # reconstructed from a flattened listing -- confirm against the
        # original file.
        query = query._clone(klass=ReadonlyQuerySet)

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000896
897
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but retrieve only the count of results.

    The paging parameters query_start and query_limit are ignored, so the
    count covers every matching object.  filter_data is left untouched.
    """
    # strip the paging keys from a copy rather than from the caller's dict
    count_filters = dict(filter_data)
    count_filters.pop('query_start', None)
    count_filters.pop('query_limit', None)
    query = cls.query_objects(count_filters, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000907
908
@classmethod
def clean_object_dicts(cls, field_dicts):
    """\
    Clean, in place, a list of dicts that each describe one object (as
    returned by query.values()) so the data is more suitable for returning
    to the user.

    Each dict is passed through clean_foreign_keys(), _convert_booleans()
    and convert_human_readable_values(to_human_readable=True), in that
    order.
    """
    for record in field_dicts:
        cls.clean_foreign_keys(record)
        cls._convert_booleans(record)
        cls.convert_human_readable_values(record, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000921
922
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but return a list of dictionaries, one per
    matching object, built with get_object_dict().  Columns selected
    through the query's extra_select are included in each dictionary.
    """
    matches = cls.query_objects(filter_data, initial_query=initial_query)
    extra_fields = matches.query.extra_select.keys()
    field_dicts = []
    for model_object in matches:
        field_dicts.append(
                model_object.get_object_dict(extra_fields=extra_fields))
    return field_dicts
showard7c785282008-05-29 19:45:12 +0000933
934
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field

    @param valid_only: look the object up through get_valid_manager() if
            True, through cls.objects otherwise.
    @raises ValueError if id_or_name is neither an integer nor (for models
            that define name_field) a string.
    """
    if not valid_only:
        manager = cls.objects
    else:
        manager = cls.get_valid_manager()

    if isinstance(id_or_name, (int, long)):
        lookup = {'pk': id_or_name}
    elif isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
        lookup = {cls.name_field: id_or_name}
    else:
        raise ValueError(
                'Invalid positional argument: %s (%s)' % (id_or_name,
                                                          type(id_or_name)))
    return manager.get(**lookup)
showard7c785282008-05-29 19:45:12 +0000953
954
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """
    Look up every entry of id_or_name_list with smart_get().

    @returns the list of objects found, in input order.
    @raises cls.DoesNotExist naming all entries that could not be found.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # entries may be integer IDs, which str.join() rejects, so format
        # each one to text before joining
        missing = ', '.join(['%s' % (entry,) for entry in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
969
970
def get_object_dict(self, extra_fields=None):
    """\
    Return a dictionary mapping fields to this object's values.

    The dictionary is cleaned with clean_object_dicts() and then handed to
    _postprocess_object_dict() before being returned.

    @param extra_fields: list of extra attribute names to include, in
            addition to the fields defined on this object.
    """
    field_names = list(self.get_field_dict().keys())
    if extra_fields:
        field_names.extend(extra_fields)

    object_dict = {}
    for name in field_names:
        object_dict[name] = getattr(self, name)

    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000985
986
def _postprocess_object_dict(self, object_dict):
    """
    Hook called by get_object_dict() with the cleaned dictionary.  Does
    nothing here; for subclasses to override.
    """
990
991
@classmethod
def get_valid_manager(cls):
    """Return the manager used for "valid" lookups: cls.objects here."""
    manager = cls.objects
    return manager
showard7c785282008-05-29 19:45:12 +0000995
996
def _record_attributes(self, attributes):
    """
    Remember the current values of the named attributes in
    self._recorded_attributes.  See on_attribute_changed.

    @param attributes: iterable of attribute names (not a single string).
    """
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for name in attributes:
        snapshot[name] = getattr(self, name)
    self._recorded_attributes = snapshot
1004
1005
def _check_for_updated_attributes(self):
    """
    Call on_attribute_changed() for every recorded attribute whose value
    differs from the recorded one, then record the current values afresh.
    See on_attribute_changed.
    """
    recorded = self._recorded_attributes
    for attribute, original_value in recorded.iteritems():
        if original_value != getattr(self, attribute):
            self.on_attribute_changed(attribute, original_value)
    self._record_attributes(self._recorded_attributes.keys())
1015
1016
def on_attribute_changed(self, attribute, old_value):
    """
    Called whenever an attribute is updated.  To be overridden; this
    implementation does nothing.

    To use this method, you must:
    * call _record_attributes() from __init__() (after making the super
    call) with a list of attributes for which you want to be notified upon
    change.
    * call _check_for_updated_attributes() from save().

    @param attribute: name of the attribute that changed.
    @param old_value: the value recorded before the change.
    """
1028
1029
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field.
    """

    def save(self, *args, **kwargs):
        """
        Save the object.  An object without an id that has the same name as
        an invalidated object is passed to resurrect_object() first.
        """
        if self.id is None:
            # was an object of this name added and invalidated before?
            name_field = self.name_field
            lookup = {name_field: getattr(self, name_field),
                      'invalid': True}
            try:
                previous = self.__class__.objects.get(**lookup)
                self.resurrect_object(previous)
            except self.DoesNotExist:
                # nothing to resurrect
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object.  Takes over the
        id of old_object.  Can be overridden by subclasses to copy data as
        desired from the deleted entry (but this superclass implementation
        must normally be called).
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """


    def delete(self):
        """
        Mark the object invalid and save it instead of deleting it, then
        call clean_object().  The object must not be invalid already.
        """
        # NOTE(review): this self-assignment has no effect visible here;
        # presumably it makes sure the attribute is loaded on the
        # instance -- confirm before removing.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager used for "valid" lookups: cls.valid_objects."""
        manager = cls.valid_objects
        return manager


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001095
1096
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and
        keyword_args is a dict of args to pass to
        attribute_model.objects.get() to get an instance of the given
        attribute on this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Fetch or create the attribute object, set its value and save it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        record, _created = model.objects.get_or_create(**lookup)
        record.value = value
        record.save()


    def delete_attribute(self, attribute):
        """Delete the attribute object; a missing one is silently ignored."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            record = model.objects.get(**lookup)
            record.delete()
        except model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute if value is None, otherwise set it."""
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
showard26b7ec72009-12-21 22:43:57 +00001135
1136
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always refuses: objects must be made through get_or_create()."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """
        Add the_hash, computed by the model's _compute_hash() from the given
        keyword args, and delegate to the standard get_or_create().
        """
        digest = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = digest
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1148
1149
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # unique hash column; filled in by ModelWithHashManager.get_or_create()
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given field values; must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work.  These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row.  However, the preferred way to make instances
        of these models is through the get_or_create() method.
        """
        if force_insert:
            # A forced insert may go ahead; if it's a duplicate, the unique
            # constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')