blob: 7fbdb76590fc71bfc4d1274684c033264f4770e1 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +00009import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000010from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000011from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000012
class ValidationError(Exception):
    """\
    Raised when the data for adding or updating an object is invalid.

    The exception's single argument is a dictionary that maps each
    offending field name to an error string.
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Return a callable that runs method while the readonly connection is
    installed as Django's connection.

    The readonly connection is set before the call and unset afterwards,
    even when method raises.  The wrapper takes over method's __name__.
    """
    def readonly_wrapper(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_wrapper.__name__ = method.__name__
    return readonly_wrapper
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Quote name as a SQL identifier (shorthand for connection.ops.quote_name())."""
    quote = connection.ops.quote_name
    return quote(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    We have to wrap generators specially. Assume it performs
    the query on the first call to next().

    Returns a generator function that yields exactly what generator
    yields.  The readonly connection is active only while the first item is
    being produced.

    Exhaustion is detected with for-loops instead of letting StopIteration
    escape from an explicit next() call.  Under PEP 479 (Python 3.7+) an
    escaping StopIteration would become a RuntimeError.  An empty generator
    therefore just produces an empty wrapper.
    """
    def wrapper_generator(*args, **kwargs):
        # a single iterator shared by both loops below, so iteration resumes
        # after the first item instead of restarting
        iterator = iter(generator(*args, **kwargs))
        readonly_connection.connection().set_django_connection()
        try:
            # pull just the first item under the readonly connection; the
            # for/else form ends the wrapper cleanly if nothing is produced
            for first_value in iterator:
                break
            else:
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        for value in iterator:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Make every query-issuing method of queryset use the readonly connection.

    The methods listed below are replaced on the instance with wrappers from
    _wrap_with_readonly().  queryset.iterator is replaced with a wrapper
    from _wrap_generator_with_readonly().
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    A QuerySet that runs all of its database queries over the read-only
    connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but returns a ReadonlyValuesQuerySet."""
        clone_kwargs = dict(setup=True, _fields=fields)
        return self._clone(klass=ReadonlyValuesQuerySet, **clone_kwargs)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet.  Its queries use the
    read-only connection.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
90
91
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Besides the normal Manager API this provides helpers to:
    - add hand-written JOINs and raw WHERE clauses to query sets;
    - join in a "custom field" from a related model;
    - run custom SELECTs over the readonly connection;
    - bulk-populate lists of related objects.
    """

    class _CustomQuery(query.Query):
        """
        Query subclass that can carry extra, hand-written JOINs.

        Joins registered with add_custom_join() are stored as dicts in
        self._custom_joins.  get_from_clause() appends them to the FROM
        clause.
        """
        def __init__(self, *args, **kwargs):
            super(ExtendedManager._CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """
            Return a copy of this query (as klass, if given) that has its own
            copy of the custom joins.

            Extra keyword arguments are passed on to the parent clone()
            instead of being discarded.
            """
            obj = super(ExtendedManager._CustomQuery, self).clone(klass,
                                                                  **kwargs)
            # give the copy its own list so joins added later to either
            # query don't show up in the other one
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """
            Merge rhs into this query and take over rhs's custom joins, if
            it has any.
            """
            super(ExtendedManager._CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """
            Register an extra JOIN for get_from_clause() to emit.

            @param table: name of the table to join.
            @param condition: SQL for the ON clause.  Any placeholders in it
                    are filled from condition_values.
            @param join_type: join keyword, e.g. query.INNER or query.LOUTER.
            @param condition_values: parameters for condition.
            @param alias: alias for the joined table; defaults to table.
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        def get_from_clause(self):
            """
            Return the parent's (from_, params), extended with the custom
            joins.

            Each custom join adds one "<join_type> <table> AS <alias> ON
            (<condition>)" entry to from_.  Its condition_values are added
            to params.
            """
            from_, params = (super(ExtendedManager._CustomQuery, self)
                             .get_from_clause())

            for join_dict in self._custom_joins:
                from_.append('%s %s AS %s ON (%s)'
                             % (join_dict['join_type'],
                                _quote_name(join_dict['table']),
                                _quote_name(join_dict['alias']),
                                join_dict['condition']))
                params.extend(join_dict['condition_values'])

            return from_, params


        @classmethod
        def convert_query(cls, query_set):
            """
            Return a copy of query_set whose "query" attribute is a
            _CustomQuery.  Custom joins already on the query are kept.
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(
                    klass=ExtendedManager._CustomQuery,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None):
            # the clause is used verbatim; qn (name quoting) is ignored
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            # raw SQL can't be relabeled, so alias changes are ignored
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set and return the resulting query set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from; defaults to
                the primary key column of query_set's model
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            # the key must be a column of join_from_table, so take the pk
            # *column* (not its attribute name) of that same model
            join_from_key = query_set.model._meta.pk.column
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self._CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # with a LEFT JOIN, unmatched rows have NULL in the join key
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to.  It is not modified.
        @param alias: alias of joined table
        @returns dict with rhs_table, rhs_column, lhs_column, where_clause and
                values (parameters for where_clause) for add_join().
        """
        import copy

        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        # relabel a copy: relabel_aliases() works in place and would
        # otherwise corrupt the caller's query
        rhs_where = copy.deepcopy(join_to_query.query.where)
        rhs_where.relabel_aliases({rhs_table: alias})
        initial_clause, values = rhs_where.as_sql()
        all_clauses = ((initial_clause,) +
                       tuple(join_to_query.query.extra_where))
        # an unfiltered query yields an empty clause; leave it out instead
        # of emitting "()" into the join condition
        info['where_clause'] = ' AND '.join('(%s)' % clause
                                            for clause in all_clauses
                                            if clause)
        info['values'] = list(values) + list(join_to_query.query.extra_params)
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to.  It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if it is declared on the related model
        @returns dict with rhs_table, rhs_column, lhs_column, where_clause and
                values (parameters for where_clause) for add_join().
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join.  we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        # pass the id as a query parameter instead of pasting it into the SQL
        info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
                                                _quote_name(m2m_rhs_column))
        info['values'] = (rhs_id,)
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship.  join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined (pivot) table contains the
        matching row if it's related, NULL otherwise.

        The NULL rows only appear with left_join=True.  With left_join=False
        an INNER JOIN is used, so unmatched rows are dropped instead.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        elif relationship_type == self.M2M_ON_THIS_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def add_where(self, query_set, where, values=()):
        """
        Return a copy of query_set with the raw SQL condition "where" ANDed
        into its WHERE clause.  The parameters for "where" are in values.

        Unlike extra(where=...), the result can still be combined with & and |.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return the quoted "<table>.<field>" reference."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted "<this table>.<key_field>" reference.  key_field
        defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """
        Double every '%' in sql, presumably so that user-supplied SQL
        survives %-style parameter substitution unchanged.
        """
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run "SELECT [DISTINCT] <selects>" followed by everything from " FROM"
        onwards in query_set's SQL, on the readonly connection.

        @param selects: list of SQL select expressions.
        @returns all result rows (cursor.fetchall()).
        """
        sql, params = query_set.query.as_sql()
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """True if field is a relation (FK or M2M) pointing at model_class."""
        return field.rel and field.rel.to is model_class


    # sentinel relationship types returned by determine_relationship()
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if there is no such relationship.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
                their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
                related_object) for each relationship between a base object and
                a related object.  all base_object instances come from
                base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict of self.model objects indexed by ID;
                only its keys (the IDs) are used
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        base_ids = list(base_objects_by_id)
        if not base_ids:
            # "IN ()" is a SQL syntax error, and there's nothing to find
            return []

        # identifiers are formatted in; the IDs are passed as parameters
        sql = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(placeholders)s)
        """ % dict(from_field=_quote_name(pivot_from_field),
                   to_field=_quote_name(pivot_to_field),
                   table=_quote_name(pivot_table),
                   placeholders=','.join(['%s'] * len(base_ids)))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql, base_ids)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000518
519
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose base query set contains only objects with invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        return all_objects.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000527
528
529class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000530 """\
531 Mixin with convenience functions for models, built on top of the
532 default Django model functions.
533 """
534 # TODO: at least some of these functions really belong in a custom
535 # Manager class
showard7c785282008-05-29 19:45:12 +0000536
jadmanski0afbb632008-06-06 21:10:57 +0000537 field_dict = None
538 # subclasses should override if they want to support smart_get() by name
539 name_field = None
showard7c785282008-05-29 19:45:12 +0000540
541
jadmanski0afbb632008-06-06 21:10:57 +0000542 @classmethod
543 def get_field_dict(cls):
544 if cls.field_dict is None:
545 cls.field_dict = {}
546 for field in cls._meta.fields:
547 cls.field_dict[field.name] = field
548 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000549
550
jadmanski0afbb632008-06-06 21:10:57 +0000551 @classmethod
552 def clean_foreign_keys(cls, data):
553 """\
554 -Convert foreign key fields in data from <field>_id to just
555 <field>.
556 -replace foreign key objects with their IDs
557 This method modifies data in-place.
558 """
559 for field in cls._meta.fields:
560 if not field.rel:
561 continue
562 if (field.attname != field.name and
563 field.attname in data):
564 data[field.name] = data[field.attname]
565 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000566 if field.name not in data:
567 continue
jadmanski0afbb632008-06-06 21:10:57 +0000568 value = data[field.name]
569 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000570 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000571
572
showard21baa452008-10-21 00:08:39 +0000573 @classmethod
574 def _convert_booleans(cls, data):
575 """
576 Ensure BooleanFields actually get bool values. The Django MySQL
577 backend returns ints for BooleanFields, which is almost always not
578 a problem, but it can be annoying in certain situations.
579 """
580 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000581 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000582 data[field.name] = bool(data[field.name])
583
584
jadmanski0afbb632008-06-06 21:10:57 +0000585 # TODO(showard) - is there a way to not have to do this?
586 @classmethod
587 def provide_default_values(cls, data):
588 """\
589 Provide default values for fields with default values which have
590 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000591
jadmanski0afbb632008-06-06 21:10:57 +0000592 For CharField and TextField fields with "blank=True", if nothing
593 is passed, we fill in an empty string value, even if there's no
594 default set.
595 """
596 new_data = dict(data)
597 field_dict = cls.get_field_dict()
598 for name, obj in field_dict.iteritems():
599 if data.get(name) is not None:
600 continue
601 if obj.default is not dbmodels.fields.NOT_PROVIDED:
602 new_data[name] = obj.default
603 elif (isinstance(obj, dbmodels.CharField) or
604 isinstance(obj, dbmodels.TextField)):
605 new_data[name] = ''
606 return new_data
showard7c785282008-05-29 19:45:12 +0000607
608
jadmanski0afbb632008-06-06 21:10:57 +0000609 @classmethod
610 def convert_human_readable_values(cls, data, to_human_readable=False):
611 """\
612 Performs conversions on user-supplied field data, to make it
613 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000614
jadmanski0afbb632008-06-06 21:10:57 +0000615 For all fields that have choice sets, convert their values
616 from human-readable strings to enum values, if necessary. This
617 allows users to pass strings instead of the corresponding
618 integer values.
showard7c785282008-05-29 19:45:12 +0000619
jadmanski0afbb632008-06-06 21:10:57 +0000620 For all foreign key fields, call smart_get with the supplied
621 data. This allows the user to pass either an ID value or
622 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000623
jadmanski0afbb632008-06-06 21:10:57 +0000624 If to_human_readable=True, perform the inverse - i.e. convert
625 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000626
jadmanski0afbb632008-06-06 21:10:57 +0000627 This method modifies data in-place.
628 """
629 field_dict = cls.get_field_dict()
630 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000631 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000632 continue
633 field_obj = field_dict[field_name]
634 # convert enum values
635 if field_obj.choices:
636 for choice_data in field_obj.choices:
637 # choice_data is (value, name)
638 if to_human_readable:
639 from_val, to_val = choice_data
640 else:
641 to_val, from_val = choice_data
642 if from_val == data[field_name]:
643 data[field_name] = to_val
644 break
645 # convert foreign key values
646 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000647 dest_obj = field_obj.rel.to.smart_get(data[field_name],
648 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000649 if to_human_readable:
650 if dest_obj.name_field is not None:
651 data[field_name] = getattr(dest_obj,
652 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000653 else:
showardb0a73032009-03-27 18:35:41 +0000654 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000655
656
jadmanski0afbb632008-06-06 21:10:57 +0000657 @classmethod
658 def validate_field_names(cls, data):
659 'Checks for extraneous fields in data.'
660 errors = {}
661 field_dict = cls.get_field_dict()
662 for field_name in data:
663 if field_name not in field_dict:
664 errors[field_name] = 'No field of this name'
665 return errors
showard7c785282008-05-29 19:45:12 +0000666
667
jadmanski0afbb632008-06-06 21:10:57 +0000668 @classmethod
669 def prepare_data_args(cls, data, kwargs):
670 'Common preparation for add_object and update_object'
671 data = dict(data) # don't modify the default keyword arg
672 data.update(kwargs)
673 # must check for extraneous field names here, while we have the
674 # data in a dict
675 errors = cls.validate_field_names(data)
676 if errors:
677 raise ValidationError(errors)
678 cls.convert_human_readable_values(data)
679 return data
showard7c785282008-05-29 19:45:12 +0000680
681
jadmanski0afbb632008-06-06 21:10:57 +0000682 def validate_unique(self):
683 """\
684 Validate that unique fields are unique. Django manipulators do
685 this too, but they're a huge pain to use manually. Trust me.
686 """
687 errors = {}
688 cls = type(self)
689 field_dict = self.get_field_dict()
690 manager = cls.get_valid_manager()
691 for field_name, field_obj in field_dict.iteritems():
692 if not field_obj.unique:
693 continue
showard7c785282008-05-29 19:45:12 +0000694
jadmanski0afbb632008-06-06 21:10:57 +0000695 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000696 if value is None and field_obj.auto_created:
697 # don't bother checking autoincrement fields about to be
698 # generated
699 continue
700
jadmanski0afbb632008-06-06 21:10:57 +0000701 existing_objs = manager.filter(**{field_name : value})
702 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000703
jadmanski0afbb632008-06-06 21:10:57 +0000704 if num_existing == 0:
705 continue
706 if num_existing == 1 and existing_objs[0].id == self.id:
707 continue
708 errors[field_name] = (
709 'This value must be unique (%s)' % (value))
710 return errors
showard7c785282008-05-29 19:45:12 +0000711
712
showarda5288b42009-07-28 20:06:08 +0000713 def _validate(self):
714 """
715 First coerces all fields on this instance to their proper Python types.
716 Then runs validation on every field. Returns a dictionary of
717 field_name -> error_list.
718
719 Based on validate() from django.db.models.Model in Django 0.96, which
720 was removed in Django 1.0. It should reappear in a later version. See:
721 http://code.djangoproject.com/ticket/6845
722 """
723 error_dict = {}
724 for f in self._meta.fields:
725 try:
726 python_value = f.to_python(
727 getattr(self, f.attname, f.get_default()))
728 except django.core.exceptions.ValidationError, e:
729 error_dict[f.name] = str(e.message)
730 continue
731
732 if not f.blank and not python_value:
733 error_dict[f.name] = 'This field is required.'
734 continue
735
736 setattr(self, f.attname, python_value)
737
738 return error_dict
739
740
def do_validate(self):
    """
    Validate this object and raise if anything is wrong.

    Per-field errors from _validate() are merged with uniqueness errors
    from validate_unique(); when both report a problem for the same field,
    the _validate() message is kept.

    @raises ValidationError carrying a dict of field name -> error message.
    """
    errors = self._validate()
    for name, message in self.validate_unique().items():
        if name not in errors:
            errors[name] = message
    if errors:
        raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000748
749
jadmanski0afbb632008-06-06 21:10:57 +0000750 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000751
jadmanski0afbb632008-06-06 21:10:57 +0000752 @classmethod
753 def add_object(cls, data={}, **kwargs):
754 """\
755 Returns a new object created with the given data (a dictionary
756 mapping field names to values). Merges any extra keyword args
757 into data.
758 """
759 data = cls.prepare_data_args(data, kwargs)
760 data = cls.provide_default_values(data)
761 obj = cls(**data)
762 obj.do_validate()
763 obj.save()
764 return obj
showard7c785282008-05-29 19:45:12 +0000765
766
def update_object(self, data=None, **kwargs):
    """\
    Updates the object with the given data (a dictionary mapping
    field names to values).  Merges any extra keyword args into
    data, assigns every resulting field onto this object, validates it
    with do_validate() and saves it.

    @param data: dict of field name -> value; None means no fields.
    @raises ValidationError (from do_validate()) if the new values are
            invalid.  The attributes have already been assigned by then.
    """
    if data is None:
        # Fresh dict per call rather than a shared mutable default.
        data = {}
    data = self.prepare_data_args(data, kwargs)
    for field_name, value in data.iteritems():
        setattr(self, field_name, value)
    self.do_validate()
    self.save()
showard7c785282008-05-29 19:45:12 +0000778
779
# Keys of query_objects()'s filter_data that this class handles itself
# (paging, sorting, extra SQL, DISTINCT control) instead of passing them to
# Django's filter(); see query_objects() for the meaning of each.
_SPECIAL_FILTER_KEYS = (
    'query_start',
    'query_limit',
    'sort_by',
    'extra_args',
    'extra_where',
    'no_distinct',
)
783
784
@classmethod
def _extract_special_params(cls, filter_data):
    """
    Split filter_data into the keys this class handles specially and the
    ordinary filters to be handed to Django.

    filter_data itself is not modified (a shallow copy is made).

    @returns a tuple (special_params, regular_filters) of dicts:
            special_params holds every key of filter_data listed in
            _SPECIAL_FILTER_KEYS; regular_filters holds everything else.
    """
    remaining = dict(filter_data)
    extracted = {}
    for name in cls._SPECIAL_FILTER_KEYS:
        if name not in remaining:
            continue
        extracted[name] = remaining.pop(name)
    return extracted, remaining
798
799
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply the presentation parameters in filter_data -- sorting and paging
    -- to query.

    - sort_by: list or tuple passed to query.extra(order_by=...).
    - query_start / query_limit: query_limit is a page size, so with a
      query_start the result covers [query_start, query_start +
      query_limit).  query_limit on its own caps the number of results.

    @returns new query with presentation applied
    @raises ValueError if query_start is given without query_limit.
    """
    params = cls._extract_special_params(filter_data)[0]

    ordering = params.get('sort_by', None)
    if ordering:
        assert isinstance(ordering, (list, tuple))
        query = query.extra(order_by=ordering)

    start = params.get('query_start', None)
    stop = params.get('query_limit', None)
    if start is not None:
        if stop is None:
            raise ValueError('Cannot pass query_start without query_limit')
        stop += start  # page size -> end index
    return query[start:stop]
showard8bfb5cb2009-10-07 20:49:15 +0000821
822
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """\
    Returns a QuerySet object for querying the given model_class
    with the given filter_data.  Optional special arguments in
    filter_data include:
    -query_start: index of first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on.  prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    filter_data (including any 'extra_args' dict inside it) is never
    modified, so the same filter_data may safely be reused.
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # Work on copies of the caller's extra_args dict and its 'where' list.
    # Appending to them in place would leak extra_where back into
    # filter_data, and repeated calls would accumulate duplicate clauses.
    extra_args = dict(special_params.get('extra_args') or {})
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = (list(extra_args.get('where') or [])
                               + [extra_where])
    if extra_args:
        query = query.extra(**extra_args)
    query = query._clone(klass=ReadonlyQuerySet)

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000868
869
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but retrieve only the count of results.

    The paging keys query_start and query_limit are ignored so that every
    matching row is counted.  filter_data itself is left unmodified.
    """
    # Drop the paging keys from a copy: popping them from the caller's dict
    # would silently remove pagination from any later query that reuses it.
    count_filters = dict(filter_data)
    count_filters.pop('query_start', None)
    count_filters.pop('query_limit', None)
    query = cls.query_objects(count_filters, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000879
880
@classmethod
def clean_object_dicts(cls, field_dicts):
    """\
    Clean, in place, a list of per-object dicts (such as those returned by
    query.values()) so the data is better suited to returning to the user.

    Each dict goes through clean_foreign_keys(), _convert_booleans() and
    convert_human_readable_values(to_human_readable=True), in that order.
    Returns None.
    """
    for object_dict in field_dicts:
        cls.clean_foreign_keys(object_dict)
        cls._convert_booleans(object_dict)
        cls.convert_human_readable_values(object_dict, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000893
894
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but return a list of dictionaries, one per object
    (built by get_object_dict()).  Any extra select columns on the query
    are included in each dictionary.
    """
    queryset = cls.query_objects(filter_data, initial_query=initial_query)
    extra_names = queryset.query.extra_select.keys()
    return [item.get_object_dict(extra_fields=extra_names)
            for item in queryset]
showard7c785282008-05-29 19:45:12 +0000905
906
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    Fetch a single object by primary key or by name.

    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field

    @param valid_only: look the object up through get_valid_manager() when
            true, otherwise through cls.objects.
    @raises ValueError if id_or_name is neither an integer nor a string.
    """
    if not valid_only:
        manager = cls.objects
    else:
        manager = cls.get_valid_manager()

    if isinstance(id_or_name, basestring):
        lookup = {cls.name_field : id_or_name}
    elif isinstance(id_or_name, (int, long)):
        lookup = {'pk' : id_or_name}
    else:
        raise ValueError(
                'Invalid positional argument: %s (%s)' % (id_or_name,
                                                          type(id_or_name)))
    return manager.get(**lookup)
showard7c785282008-05-29 19:45:12 +0000925
926
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """\
    Fetch several objects with smart_get() (using its default
    valid_only=True).

    @param id_or_name_list: iterable of integer IDs and/or name strings.
    @returns list of the matching objects, in input order.
    @raises cls.DoesNotExist naming every input that matched nothing.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs, which str.join() rejects with a
        # TypeError; format each one so the intended DoesNotExist is raised.
        invalid_text = ', '.join(['%s' % (item,) for item in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), invalid_text))
    return result_objects
941
942
def get_object_dict(self, extra_fields=None):
    """\
    Return a dictionary mapping field names to this object's values.

    @param extra_fields: optional list of additional attribute names to
            include alongside the fields from get_field_dict().
    @returns the dict, after it has been passed through
            clean_object_dicts() and _postprocess_object_dict().
    """
    names = list(self.get_field_dict().keys())
    if extra_fields:
        names.extend(extra_fields)
    result = {}
    for name in names:
        result[name] = getattr(self, name)
    self.clean_object_dicts([result])
    self._postprocess_object_dict(result)
    return result
showard7c785282008-05-29 19:45:12 +0000957
958
def _postprocess_object_dict(self, object_dict):
    """
    Hook for subclasses to adjust, in place, the dict built by
    get_object_dict() just before it is returned.  Does nothing here.
    """
962
963
@classmethod
def get_valid_manager(cls):
    """
    Return the manager used when only "valid" objects are wanted.

    Here that is simply cls.objects; subclasses may override it
    (ModelWithInvalid returns cls.valid_objects).
    """
    manager = cls.objects
    return manager
showard7c785282008-05-29 19:45:12 +0000967
968
def _record_attributes(self, attributes):
    """
    Snapshot the current values of the named attributes so that
    _check_for_updated_attributes() can detect later changes.  See
    on_attribute_changed().

    @param attributes: iterable of attribute names; a bare string is
            rejected by an assert.
    """
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for name in attributes:
        snapshot[name] = getattr(self, name)
    self._recorded_attributes = snapshot
976
977
def _check_for_updated_attributes(self):
    """
    Call on_attribute_changed() for every recorded attribute whose current
    value differs from its snapshot, then re-snapshot the same attribute
    names.  See on_attribute_changed().
    """
    for name, previous in self._recorded_attributes.iteritems():
        # Read each current value only when its turn comes, so a callback
        # that changes a later attribute is still observed.
        current = getattr(self, name)
        if previous != current:
            self.on_attribute_changed(name, previous)
    self._record_attributes(self._recorded_attributes.keys())
987
988
def on_attribute_changed(self, attribute, old_value):
    """
    Hook invoked for each recorded attribute whose value has changed.
    The default implementation does nothing; override it to react.

    To use this hook, a subclass must:
     * call _record_attributes() from __init__() (after the super call)
       with the names of the attributes it wants to watch, and
     * call _check_for_updated_attributes() from save().

    @param attribute: name of the changed attribute.
    @param old_value: the value that was recorded before the change.
    """
1000
1001
class ModelWithInvalid(ModelExtensions):
    """
    Mixin that replaces real deletion with invalidation.

    delete() only sets the object's "invalid" flag, and the first save()
    of a new object whose name matches an invalidated row reuses that row.
    Subclasses must have a boolean "invalid" field and a valid_objects
    manager (returned by get_valid_manager()).
    """

    def save(self, *args, **kwargs):
        """
        Save this object.  On the first save (id is None), an existing
        invalid row with the same name_field value is looked up and, if
        found, passed to resurrect_object() before saving.
        """
        if self.id is None:
            lookup = {self.name_field : getattr(self, self.name_field),
                      'invalid' : True}
            try:
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                pass  # no previously invalidated object with this name

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time but is
        really "undeleting" a previously invalidated object.  This base
        implementation copies old_object's id onto self.  Subclasses may
        override it to copy further data from the old entry, but must
        normally still call this implementation.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Called after this object has been marked invalid.  Subclasses
        should override it to clean up relationships that should no longer
        exist once the object is considered deleted.  Does nothing here.
        """


    def delete(self):
        """Invalidate (instead of deleting) this object, then clean it up."""
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        Return cls.valid_objects, the manager for valid objects (presumably
        one that excludes invalidated rows -- confirm in the subclasses).
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.

        NOTE(review): object has no _prepare(), so the super() call below
        only works when this class is combined (via multiple inheritance)
        with a class that provides one.  It appears to target Django's
        pre-1.0 manipulators -- confirm whether anything still uses it.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001066
1067
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".

    Subclasses must implement _get_attribute_model_and_args().
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError unless overridden.
        """
        # Raise NotImplementedError, not NotImplemented: the latter is the
        # rich-comparison sentinel, and raising it is itself a TypeError.
        raise NotImplementedError(
                'Subclasses must override _get_attribute_model_and_args()')


    def set_attribute(self, attribute, value):
        """
        Set the named attribute to value, creating its row in the attribute
        model if necessary, and save it.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """Delete the named attribute's row; a missing row is ignored."""
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Set the attribute to value, or delete it when value is None."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)
showard26b7ec72009-12-21 22:43:57 +00001106
1107
class ModelWithHashManager(dbmodels.Manager):
    """
    Manager for use with the ModelWithHash abstract model class.

    Rows must be obtained through get_or_create(), which fills in the_hash
    from the model's _compute_hash(); create() is disabled.
    """

    def create(self, **kwargs):
        """Always raises; use get_or_create() so that the_hash is computed."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """
        Like Manager.get_or_create(), but first adds a 'the_hash' entry
        computed by self.model._compute_hash(**kwargs).
        """
        lookup = dict(kwargs)
        lookup['the_hash'] = self.model._compute_hash(**kwargs)
        return super(ModelWithHashManager, self).get_or_create(**lookup)
1119
1120
class ModelWithHash(dbmodels.Model):
    """
    Abstract superclass for immutable models keyed by a unique hash column.

    Instances should come from objects.get_or_create(), which stores the
    value returned by the subclass's _compute_hash() in the_hash.
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the value to store in the_hash.  Must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """
        Refuse ordinary saves, since these models are meant to be immutable.

        Instances should be created through model.objects.get_or_create().
        The one permitted path is save(force_insert=True), which only ever
        inserts a new row.

        @raises Exception unless force_insert is true.
        """
        if force_insert:
            # A duplicate forced insert is caught by the unique constraint
            # on the_hash.
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')