blob: 3e83abb21bc9420b242cac7bbe4b151ba73f467d [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
6import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00007from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00008from django.db.models.sql import query
showard7c785282008-05-29 19:45:12 +00009from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +000010from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000011
showarda5288b42009-07-28 20:06:08 +000012
class ValidationError(Exception):
    """\
    Data validation error in adding or updating an object.  The associated
    value is a dictionary mapping field names to error strings.

    Raised in this module by ModelExtensions.prepare_data_args() (unknown
    field names) and ModelExtensions.do_validate() (field and uniqueness
    errors).
    """
showard7c785282008-05-29 19:45:12 +000018
19
def _wrap_with_readonly(method):
    """
    Build a wrapper around method that runs it with the read-only connection
    installed as the Django connection.
    @param method: the callable to wrap
    @returns a function with the same __name__ as method.  Each call installs
            the read-only connection, invokes method with the given
            arguments, and uninstalls the connection again even if method
            raises.
    """
    def readonly_call(*call_args, **call_kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*call_args, **call_kwargs)
        finally:
            # always restore, whether or not the call succeeded
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_call.__name__ = method.__name__
    return readonly_call
showard09096d82008-07-07 23:20:49 +000029
30
def _quote_name(name):
    """Return name as quoted by connection.ops.quote_name()."""
    quote = connection.ops.quote_name
    return quote(name)
34
35
def _wrap_generator_with_readonly(generator):
    """
    Generator functions need special wrapping: only the first call to next()
    is made with the read-only connection installed, on the assumption that
    this is when the query is actually performed.
    @param generator: a generator function
    @returns a generator function with the same __name__ that yields the
            same values as generator
    """
    def readonly_generator(*call_args, **call_kwargs):
        source = generator(*call_args, **call_kwargs)

        readonly_connection.connection().set_django_connection()
        try:
            try:
                first_value = source.next()
            except StopIteration:
                # the wrapped generator produced nothing at all
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        # the remaining values are fetched without the read-only connection
        for value in source:
            yield value

    readonly_generator.__name__ = generator.__name__
    return readonly_generator
55
56
def _make_queryset_readonly(queryset):
    """
    Replace the query-performing methods of this queryset instance with
    versions that run on the read-only connection.
    @param queryset: the queryset object to modify in place
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator, so it needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
69
70
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods are wrapped, per instance, to
    use the read-only connection (see _make_queryset_readonly()).
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """
        Like QuerySet.values(), but clones into a ReadonlyValuesQuerySet so
        the result keeps using the read-only connection.
        """
        clone_options = dict(klass=ReadonlyValuesQuerySet, setup=True,
                             _fields=fields)
        return self._clone(**clone_options)
showard09096d82008-07-07 23:20:49 +000084
85
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: every instance has its
    query methods wrapped by _make_queryset_readonly() on construction.
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyValuesQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
90
91
showard7c785282008-05-29 19:45:12 +000092class ExtendedManager(dbmodels.Manager):
jadmanski0afbb632008-06-06 21:10:57 +000093 """\
94 Extended manager supporting subquery filtering.
95 """
showard7c785282008-05-29 19:45:12 +000096
    class _CustomQuery(query.Query):
        """
        Query subclass that carries a _CustomSqlQ object in the attribute
        _customSqlQ and appends that object's joins to the FROM clause.
        """
        def clone(self, klass=None, **kwargs):
            """
            Clone this query, handing self._customSqlQ on to the clone.  If a
            '_customSqlQ' keyword argument is given, its joins, where clauses
            and params are merged into the clone's _customSqlQ.
            """
            # NOTE(review): the base clone() receives self._customSqlQ itself,
            # not a copy -- if the base class stores it as-is, the merge below
            # also modifies self._customSqlQ.  Confirm against the Django
            # version in use.
            # NOTE(review): keyword arguments other than '_customSqlQ' are not
            # forwarded to the base clone().
            obj = super(ExtendedManager._CustomQuery, self).clone(
                klass, _customSqlQ=self._customSqlQ)

            customQ = kwargs.get('_customSqlQ', None)
            if customQ is not None:
                obj._customSqlQ._joins.update(customQ._joins)
                obj._customSqlQ._where.extend(customQ._where)
                obj._customSqlQ._params.extend(customQ._params)

            return obj

        def get_from_clause(self):
            """
            Return the base class's (from_, params) with one extra entry
            appended to from_ containing all custom joins, if there are any.
            """
            from_, params = super(
                ExtendedManager._CustomQuery, self).get_from_clause()

            # each join is stored as alias -> (table, join_type, condition)
            join_clause = ''
            for join_alias, join in self._customSqlQ._joins.iteritems():
                join_table, join_type, condition = join
                join_clause += ' %s %s AS %s ON (%s)' % (
                    join_type, join_table, join_alias, condition)

            if join_clause:
                from_.append(join_clause)

            return from_, params
showard7c785282008-05-29 19:45:12 +0000124
125
showard43a3d262008-11-12 18:17:05 +0000126 class _CustomSqlQ(dbmodels.Q):
127 def __init__(self):
128 self._joins = datastructures.SortedDict()
129 self._where, self._params = [], []
130
131
132 def add_join(self, table, condition, join_type, alias=None):
133 if alias is None:
134 alias = table
showard43a3d262008-11-12 18:17:05 +0000135 self._joins[alias] = (table, join_type, condition)
136
137
138 def add_where(self, where, params=[]):
139 self._where.append(where)
140 self._params.extend(params)
141
142
showarda5288b42009-07-28 20:06:08 +0000143 def add_to_query(self, query, aliases):
144 if self._where:
145 where = ' AND '.join(self._where)
146 query.add_extra(None, None, (where,), self._params, None, None)
147
148
149 def _add_customSqlQ(self, query_set, filter_object):
150 """\
151 Add a _CustomSqlQ to the query set.
152 """
153 # Make a copy of the query set
154 query_set = query_set.all()
155
156 query_set.query = query_set.query.clone(
157 ExtendedManager._CustomQuery, _customSqlQ=filter_object)
158 return query_set.filter(filter_object)
showard43a3d262008-11-12 18:17:05 +0000159
160
161 def add_join(self, query_set, join_table, join_key,
showard0957a842009-05-11 19:25:08 +0000162 join_condition='', suffix='', exclude=False,
163 force_left_join=False):
164 """
165 Add a join to query_set.
166 @param join_table table to join to
167 @param join_key field referencing back to this model to use for the join
168 @param join_condition extra condition for the ON clause of the join
169 @param suffix suffix to add to join_table for the join alias
170 @param exclude if true, exclude rows that match this join (will use a
showarda5288b42009-07-28 20:06:08 +0000171 LEFT OUTER JOIN and an appropriate WHERE condition)
172 @param force_left_join - if true, a LEFT OUTER JOIN will be used instead of an
showard0957a842009-05-11 19:25:08 +0000173 INNER JOIN regardless of other options
174 """
175 join_from_table = self.model._meta.db_table
176 join_from_key = self.model._meta.pk.name
showard43a3d262008-11-12 18:17:05 +0000177 join_alias = join_table + suffix
178 full_join_key = join_alias + '.' + join_key
showard0957a842009-05-11 19:25:08 +0000179 full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
180 join_from_key)
showard43a3d262008-11-12 18:17:05 +0000181 if join_condition:
182 full_join_condition += ' AND (' + join_condition + ')'
183 if exclude or force_left_join:
showarda5288b42009-07-28 20:06:08 +0000184 join_type = query_set.query.LOUTER
showard43a3d262008-11-12 18:17:05 +0000185 else:
showarda5288b42009-07-28 20:06:08 +0000186 join_type = query_set.query.INNER
showard43a3d262008-11-12 18:17:05 +0000187
188 filter_object = self._CustomSqlQ()
189 filter_object.add_join(join_table,
190 full_join_condition,
191 join_type,
192 alias=join_alias)
193 if exclude:
194 filter_object.add_where(full_join_key + ' IS NULL')
showard43a3d262008-11-12 18:17:05 +0000195
showarda5288b42009-07-28 20:06:08 +0000196 query_set = self._add_customSqlQ(query_set, filter_object)
197 return query_set.distinct()
showard7c785282008-05-29 19:45:12 +0000198
199
showardeaccf8f2009-04-16 03:11:33 +0000200 def _get_quoted_field(self, table, field):
showarda5288b42009-07-28 20:06:08 +0000201 return _quote_name(table) + '.' + _quote_name(field)
showard5ef36e92008-07-02 16:37:09 +0000202
203
showard7c199df2008-10-03 10:17:15 +0000204 def get_key_on_this_table(self, key_field=None):
showard5ef36e92008-07-02 16:37:09 +0000205 if key_field is None:
206 # default to primary key
207 key_field = self.model._meta.pk.column
208 return self._get_quoted_field(self.model._meta.db_table, key_field)
209
210
showardeaccf8f2009-04-16 03:11:33 +0000211 def escape_user_sql(self, sql):
212 return sql.replace('%', '%%')
213
showard5ef36e92008-07-02 16:37:09 +0000214
showard0957a842009-05-11 19:25:08 +0000215 def _custom_select_query(self, query_set, selects):
showarda5288b42009-07-28 20:06:08 +0000216 sql, params = query_set.query.as_sql()
217 from_ = sql[sql.find(' FROM'):]
218
219 if query_set.query.distinct:
showard0957a842009-05-11 19:25:08 +0000220 distinct = 'DISTINCT '
221 else:
222 distinct = ''
showarda5288b42009-07-28 20:06:08 +0000223
224 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
showard0957a842009-05-11 19:25:08 +0000225 cursor = readonly_connection.connection().cursor()
226 cursor.execute(sql_query, params)
227 return cursor.fetchall()
228
229
showard68693f72009-05-20 00:31:53 +0000230 def _is_relation_to(self, field, model_class):
231 return field.rel and field.rel.to is model_class
showard0957a842009-05-11 19:25:08 +0000232
233
showard68693f72009-05-20 00:31:53 +0000234 def _get_pivot_iterator(self, base_objects_by_id, related_model):
showard0957a842009-05-11 19:25:08 +0000235 """
showard68693f72009-05-20 00:31:53 +0000236 Determine the relationship between this model and related_model, and
237 return a pivot iterator.
238 @param base_objects_by_id: dict of instances of this model indexed by
239 their IDs
240 @returns a pivot iterator, which yields a tuple (base_object,
241 related_object) for each relationship between a base object and a
242 related object. all base_object instances come from base_objects_by_id.
showard0957a842009-05-11 19:25:08 +0000243 Note -- this depends on Django model internals and will likely need to
244 be updated when we move to Django 1.x.
245 """
showard68693f72009-05-20 00:31:53 +0000246 # look for a field on related_model relating to this model
247 for field in related_model._meta.fields:
showard0957a842009-05-11 19:25:08 +0000248 if self._is_relation_to(field, self.model):
showard68693f72009-05-20 00:31:53 +0000249 # many-to-one
250 return self._many_to_one_pivot(base_objects_by_id,
251 related_model, field)
showard0957a842009-05-11 19:25:08 +0000252
showard68693f72009-05-20 00:31:53 +0000253 for field in related_model._meta.many_to_many:
showard0957a842009-05-11 19:25:08 +0000254 if self._is_relation_to(field, self.model):
255 # many-to-many
showard68693f72009-05-20 00:31:53 +0000256 return self._many_to_many_pivot(
257 base_objects_by_id, related_model, field.m2m_db_table(),
258 field.m2m_reverse_name(), field.m2m_column_name())
showard0957a842009-05-11 19:25:08 +0000259
260 # maybe this model has the many-to-many field
261 for field in self.model._meta.many_to_many:
showard68693f72009-05-20 00:31:53 +0000262 if self._is_relation_to(field, related_model):
263 return self._many_to_many_pivot(
264 base_objects_by_id, related_model, field.m2m_db_table(),
265 field.m2m_column_name(), field.m2m_reverse_name())
showard0957a842009-05-11 19:25:08 +0000266
267 raise ValueError('%s has no relation to %s' %
showard68693f72009-05-20 00:31:53 +0000268 (related_model, self.model))
showard0957a842009-05-11 19:25:08 +0000269
270
showard68693f72009-05-20 00:31:53 +0000271 def _many_to_one_pivot(self, base_objects_by_id, related_model,
272 foreign_key_field):
273 """
274 @returns a pivot iterator - see _get_pivot_iterator()
275 """
276 filter_data = {foreign_key_field.name + '__pk__in':
277 base_objects_by_id.keys()}
278 for related_object in related_model.objects.filter(**filter_data):
279 fresh_base_object = getattr(related_object, foreign_key_field.name)
280 # lookup base object in the dict -- we need to return instances from
281 # the dict, not fresh instances of the same models
282 base_object = base_objects_by_id[fresh_base_object._get_pk_val()]
283 yield base_object, related_object
284
285
286 def _query_pivot_table(self, base_objects_by_id, pivot_table,
287 pivot_from_field, pivot_to_field):
showard0957a842009-05-11 19:25:08 +0000288 """
289 @param id_list list of IDs of self.model objects to include
290 @param pivot_table the name of the pivot table
291 @param pivot_from_field a field name on pivot_table referencing
292 self.model
293 @param pivot_to_field a field name on pivot_table referencing the
294 related model.
showard68693f72009-05-20 00:31:53 +0000295 @returns pivot list of IDs (base_id, related_id)
showard0957a842009-05-11 19:25:08 +0000296 """
297 query = """
298 SELECT %(from_field)s, %(to_field)s
299 FROM %(table)s
300 WHERE %(from_field)s IN (%(id_list)s)
301 """ % dict(from_field=pivot_from_field,
302 to_field=pivot_to_field,
303 table=pivot_table,
showard68693f72009-05-20 00:31:53 +0000304 id_list=','.join(str(id_) for id_
305 in base_objects_by_id.iterkeys()))
showard0957a842009-05-11 19:25:08 +0000306 cursor = readonly_connection.connection().cursor()
307 cursor.execute(query)
showard68693f72009-05-20 00:31:53 +0000308 return cursor.fetchall()
showard0957a842009-05-11 19:25:08 +0000309
310
showard68693f72009-05-20 00:31:53 +0000311 def _many_to_many_pivot(self, base_objects_by_id, related_model,
312 pivot_table, pivot_from_field, pivot_to_field):
313 """
314 @param pivot_table: see _query_pivot_table
315 @param pivot_from_field: see _query_pivot_table
316 @param pivot_to_field: see _query_pivot_table
317 @returns a pivot iterator - see _get_pivot_iterator()
318 """
319 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
320 pivot_from_field, pivot_to_field)
321
322 all_related_ids = list(set(related_id for base_id, related_id
323 in id_pivot))
324 related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
325
326 for base_id, related_id in id_pivot:
327 yield base_objects_by_id[base_id], related_objects_by_id[related_id]
328
329
330 def populate_relationships(self, base_objects, related_model,
showard0957a842009-05-11 19:25:08 +0000331 related_list_name):
332 """
showard68693f72009-05-20 00:31:53 +0000333 For each instance of this model in base_objects, add a field named
334 related_list_name listing all the related objects of type related_model.
335 related_model must be in a many-to-one or many-to-many relationship with
336 this model.
337 @param base_objects - list of instances of this model
338 @param related_model - model class related to this model
339 @param related_list_name - attribute name in which to store the related
340 object list.
showard0957a842009-05-11 19:25:08 +0000341 """
showard68693f72009-05-20 00:31:53 +0000342 if not base_objects:
showard0957a842009-05-11 19:25:08 +0000343 # if we don't bail early, we'll get a SQL error later
344 return
showard0957a842009-05-11 19:25:08 +0000345
showard68693f72009-05-20 00:31:53 +0000346 base_objects_by_id = dict((base_object._get_pk_val(), base_object)
347 for base_object in base_objects)
348 pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
349 related_model)
showard0957a842009-05-11 19:25:08 +0000350
showard68693f72009-05-20 00:31:53 +0000351 for base_object in base_objects:
352 setattr(base_object, related_list_name, [])
353
354 for base_object, related_object in pivot_iterator:
355 getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000356
357
class ValidObjectsManager(ExtendedManager):
    """
    Manager returning only objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent manager's query set filtered on invalid=False."""
        all_objects = super(ValidObjectsManager, self).get_query_set()
        return all_objects.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000365
366
367class ModelExtensions(object):
jadmanski0afbb632008-06-06 21:10:57 +0000368 """\
369 Mixin with convenience functions for models, built on top of the
370 default Django model functions.
371 """
372 # TODO: at least some of these functions really belong in a custom
373 # Manager class
showard7c785282008-05-29 19:45:12 +0000374
jadmanski0afbb632008-06-06 21:10:57 +0000375 field_dict = None
376 # subclasses should override if they want to support smart_get() by name
377 name_field = None
showard7c785282008-05-29 19:45:12 +0000378
379
jadmanski0afbb632008-06-06 21:10:57 +0000380 @classmethod
381 def get_field_dict(cls):
382 if cls.field_dict is None:
383 cls.field_dict = {}
384 for field in cls._meta.fields:
385 cls.field_dict[field.name] = field
386 return cls.field_dict
showard7c785282008-05-29 19:45:12 +0000387
388
jadmanski0afbb632008-06-06 21:10:57 +0000389 @classmethod
390 def clean_foreign_keys(cls, data):
391 """\
392 -Convert foreign key fields in data from <field>_id to just
393 <field>.
394 -replace foreign key objects with their IDs
395 This method modifies data in-place.
396 """
397 for field in cls._meta.fields:
398 if not field.rel:
399 continue
400 if (field.attname != field.name and
401 field.attname in data):
402 data[field.name] = data[field.attname]
403 del data[field.attname]
showarde732ee72008-09-23 19:15:43 +0000404 if field.name not in data:
405 continue
jadmanski0afbb632008-06-06 21:10:57 +0000406 value = data[field.name]
407 if isinstance(value, dbmodels.Model):
showardf8b19042009-05-12 17:22:49 +0000408 data[field.name] = value._get_pk_val()
showard7c785282008-05-29 19:45:12 +0000409
410
showard21baa452008-10-21 00:08:39 +0000411 @classmethod
412 def _convert_booleans(cls, data):
413 """
414 Ensure BooleanFields actually get bool values. The Django MySQL
415 backend returns ints for BooleanFields, which is almost always not
416 a problem, but it can be annoying in certain situations.
417 """
418 for field in cls._meta.fields:
showardf8b19042009-05-12 17:22:49 +0000419 if type(field) == dbmodels.BooleanField and field.name in data:
showard21baa452008-10-21 00:08:39 +0000420 data[field.name] = bool(data[field.name])
421
422
jadmanski0afbb632008-06-06 21:10:57 +0000423 # TODO(showard) - is there a way to not have to do this?
424 @classmethod
425 def provide_default_values(cls, data):
426 """\
427 Provide default values for fields with default values which have
428 nothing passed in.
showard7c785282008-05-29 19:45:12 +0000429
jadmanski0afbb632008-06-06 21:10:57 +0000430 For CharField and TextField fields with "blank=True", if nothing
431 is passed, we fill in an empty string value, even if there's no
432 default set.
433 """
434 new_data = dict(data)
435 field_dict = cls.get_field_dict()
436 for name, obj in field_dict.iteritems():
437 if data.get(name) is not None:
438 continue
439 if obj.default is not dbmodels.fields.NOT_PROVIDED:
440 new_data[name] = obj.default
441 elif (isinstance(obj, dbmodels.CharField) or
442 isinstance(obj, dbmodels.TextField)):
443 new_data[name] = ''
444 return new_data
showard7c785282008-05-29 19:45:12 +0000445
446
jadmanski0afbb632008-06-06 21:10:57 +0000447 @classmethod
448 def convert_human_readable_values(cls, data, to_human_readable=False):
449 """\
450 Performs conversions on user-supplied field data, to make it
451 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000452
jadmanski0afbb632008-06-06 21:10:57 +0000453 For all fields that have choice sets, convert their values
454 from human-readable strings to enum values, if necessary. This
455 allows users to pass strings instead of the corresponding
456 integer values.
showard7c785282008-05-29 19:45:12 +0000457
jadmanski0afbb632008-06-06 21:10:57 +0000458 For all foreign key fields, call smart_get with the supplied
459 data. This allows the user to pass either an ID value or
460 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000461
jadmanski0afbb632008-06-06 21:10:57 +0000462 If to_human_readable=True, perform the inverse - i.e. convert
463 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000464
jadmanski0afbb632008-06-06 21:10:57 +0000465 This method modifies data in-place.
466 """
467 field_dict = cls.get_field_dict()
468 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000469 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000470 continue
471 field_obj = field_dict[field_name]
472 # convert enum values
473 if field_obj.choices:
474 for choice_data in field_obj.choices:
475 # choice_data is (value, name)
476 if to_human_readable:
477 from_val, to_val = choice_data
478 else:
479 to_val, from_val = choice_data
480 if from_val == data[field_name]:
481 data[field_name] = to_val
482 break
483 # convert foreign key values
484 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000485 dest_obj = field_obj.rel.to.smart_get(data[field_name],
486 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000487 if to_human_readable:
488 if dest_obj.name_field is not None:
489 data[field_name] = getattr(dest_obj,
490 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000491 else:
showardb0a73032009-03-27 18:35:41 +0000492 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000493
494
jadmanski0afbb632008-06-06 21:10:57 +0000495 @classmethod
496 def validate_field_names(cls, data):
497 'Checks for extraneous fields in data.'
498 errors = {}
499 field_dict = cls.get_field_dict()
500 for field_name in data:
501 if field_name not in field_dict:
502 errors[field_name] = 'No field of this name'
503 return errors
showard7c785282008-05-29 19:45:12 +0000504
505
jadmanski0afbb632008-06-06 21:10:57 +0000506 @classmethod
507 def prepare_data_args(cls, data, kwargs):
508 'Common preparation for add_object and update_object'
509 data = dict(data) # don't modify the default keyword arg
510 data.update(kwargs)
511 # must check for extraneous field names here, while we have the
512 # data in a dict
513 errors = cls.validate_field_names(data)
514 if errors:
515 raise ValidationError(errors)
516 cls.convert_human_readable_values(data)
517 return data
showard7c785282008-05-29 19:45:12 +0000518
519
jadmanski0afbb632008-06-06 21:10:57 +0000520 def validate_unique(self):
521 """\
522 Validate that unique fields are unique. Django manipulators do
523 this too, but they're a huge pain to use manually. Trust me.
524 """
525 errors = {}
526 cls = type(self)
527 field_dict = self.get_field_dict()
528 manager = cls.get_valid_manager()
529 for field_name, field_obj in field_dict.iteritems():
530 if not field_obj.unique:
531 continue
showard7c785282008-05-29 19:45:12 +0000532
jadmanski0afbb632008-06-06 21:10:57 +0000533 value = getattr(self, field_name)
534 existing_objs = manager.filter(**{field_name : value})
535 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000536
jadmanski0afbb632008-06-06 21:10:57 +0000537 if num_existing == 0:
538 continue
539 if num_existing == 1 and existing_objs[0].id == self.id:
540 continue
541 errors[field_name] = (
542 'This value must be unique (%s)' % (value))
543 return errors
showard7c785282008-05-29 19:45:12 +0000544
545
showarda5288b42009-07-28 20:06:08 +0000546 def _validate(self):
547 """
548 First coerces all fields on this instance to their proper Python types.
549 Then runs validation on every field. Returns a dictionary of
550 field_name -> error_list.
551
552 Based on validate() from django.db.models.Model in Django 0.96, which
553 was removed in Django 1.0. It should reappear in a later version. See:
554 http://code.djangoproject.com/ticket/6845
555 """
556 error_dict = {}
557 for f in self._meta.fields:
558 try:
559 python_value = f.to_python(
560 getattr(self, f.attname, f.get_default()))
561 except django.core.exceptions.ValidationError, e:
562 error_dict[f.name] = str(e.message)
563 continue
564
565 if not f.blank and not python_value:
566 error_dict[f.name] = 'This field is required.'
567 continue
568
569 setattr(self, f.attname, python_value)
570
571 return error_dict
572
573
jadmanski0afbb632008-06-06 21:10:57 +0000574 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000575 errors = self._validate()
jadmanski0afbb632008-06-06 21:10:57 +0000576 unique_errors = self.validate_unique()
577 for field_name, error in unique_errors.iteritems():
578 errors.setdefault(field_name, error)
579 if errors:
580 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000581
582
jadmanski0afbb632008-06-06 21:10:57 +0000583 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000584
jadmanski0afbb632008-06-06 21:10:57 +0000585 @classmethod
586 def add_object(cls, data={}, **kwargs):
587 """\
588 Returns a new object created with the given data (a dictionary
589 mapping field names to values). Merges any extra keyword args
590 into data.
591 """
592 data = cls.prepare_data_args(data, kwargs)
593 data = cls.provide_default_values(data)
594 obj = cls(**data)
595 obj.do_validate()
596 obj.save()
597 return obj
showard7c785282008-05-29 19:45:12 +0000598
599
jadmanski0afbb632008-06-06 21:10:57 +0000600 def update_object(self, data={}, **kwargs):
601 """\
602 Updates the object with the given data (a dictionary mapping
603 field names to values). Merges any extra keyword args into
604 data.
605 """
606 data = self.prepare_data_args(data, kwargs)
607 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000608 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000609 self.do_validate()
610 self.save()
showard7c785282008-05-29 19:45:12 +0000611
612
jadmanski0afbb632008-06-06 21:10:57 +0000613 @classmethod
showard7ac7b7a2008-07-21 20:24:29 +0000614 def query_objects(cls, filter_data, valid_only=True, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000615 """\
616 Returns a QuerySet object for querying the given model_class
617 with the given filter_data. Optional special arguments in
618 filter_data include:
619 -query_start: index of first return to return
620 -query_limit: maximum number of results to return
621 -sort_by: list of fields to sort on. prefixing a '-' onto a
622 field name changes the sort to descending order.
623 -extra_args: keyword args to pass to query.extra() (see Django
624 DB layer documentation)
showarda5288b42009-07-28 20:06:08 +0000625 -extra_where: extra WHERE clause to append
jadmanski0afbb632008-06-06 21:10:57 +0000626 """
showardc0ac3a72009-07-08 21:14:45 +0000627 filter_data = dict(filter_data) # copy so we don't mutate the original
jadmanski0afbb632008-06-06 21:10:57 +0000628 query_start = filter_data.pop('query_start', None)
629 query_limit = filter_data.pop('query_limit', None)
630 if query_start and not query_limit:
631 raise ValueError('Cannot pass query_start without '
632 'query_limit')
633 sort_by = filter_data.pop('sort_by', [])
634 extra_args = filter_data.pop('extra_args', {})
635 extra_where = filter_data.pop('extra_where', None)
636 if extra_where:
showard1e935f12008-07-11 00:11:36 +0000637 # escape %'s
showardeaccf8f2009-04-16 03:11:33 +0000638 extra_where = cls.objects.escape_user_sql(extra_where)
jadmanski0afbb632008-06-06 21:10:57 +0000639 extra_args.setdefault('where', []).append(extra_where)
showard7ac7b7a2008-07-21 20:24:29 +0000640 use_distinct = not filter_data.pop('no_distinct', False)
showard7c785282008-05-29 19:45:12 +0000641
showard7ac7b7a2008-07-21 20:24:29 +0000642 if initial_query is None:
643 if valid_only:
644 initial_query = cls.get_valid_manager()
645 else:
646 initial_query = cls.objects
647 query = initial_query.filter(**filter_data)
648 if use_distinct:
649 query = query.distinct()
showard7c785282008-05-29 19:45:12 +0000650
jadmanski0afbb632008-06-06 21:10:57 +0000651 # other arguments
652 if extra_args:
653 query = query.extra(**extra_args)
showard09096d82008-07-07 23:20:49 +0000654 query = query._clone(klass=ReadonlyQuerySet)
showard7c785282008-05-29 19:45:12 +0000655
jadmanski0afbb632008-06-06 21:10:57 +0000656 # sorting + paging
657 assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
658 query = query.order_by(*sort_by)
659 if query_start is not None and query_limit is not None:
660 query_limit += query_start
661 return query[query_start:query_limit]
showard7c785282008-05-29 19:45:12 +0000662
663
jadmanski0afbb632008-06-06 21:10:57 +0000664 @classmethod
showard585c2ab2008-07-23 19:29:49 +0000665 def query_count(cls, filter_data, initial_query=None):
jadmanski0afbb632008-06-06 21:10:57 +0000666 """\
667 Like query_objects, but retreive only the count of results.
668 """
669 filter_data.pop('query_start', None)
670 filter_data.pop('query_limit', None)
showard585c2ab2008-07-23 19:29:49 +0000671 query = cls.query_objects(filter_data, initial_query=initial_query)
672 return query.count()
showard7c785282008-05-29 19:45:12 +0000673
674
jadmanski0afbb632008-06-06 21:10:57 +0000675 @classmethod
676 def clean_object_dicts(cls, field_dicts):
677 """\
678 Take a list of dicts corresponding to object (as returned by
679 query.values()) and clean the data to be more suitable for
680 returning to the user.
681 """
showarde732ee72008-09-23 19:15:43 +0000682 for field_dict in field_dicts:
683 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000684 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000685 cls.convert_human_readable_values(field_dict,
686 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000687
688
jadmanski0afbb632008-06-06 21:10:57 +0000689 @classmethod
showarde732ee72008-09-23 19:15:43 +0000690 def list_objects(cls, filter_data, initial_query=None, fields=None):
jadmanski0afbb632008-06-06 21:10:57 +0000691 """\
692 Like query_objects, but return a list of dictionaries.
693 """
showard7ac7b7a2008-07-21 20:24:29 +0000694 query = cls.query_objects(filter_data, initial_query=initial_query)
showarde732ee72008-09-23 19:15:43 +0000695 field_dicts = [model_object.get_object_dict(fields)
696 for model_object in query]
jadmanski0afbb632008-06-06 21:10:57 +0000697 return field_dicts
showard7c785282008-05-29 19:45:12 +0000698
699
jadmanski0afbb632008-06-06 21:10:57 +0000700 @classmethod
showarda4ea5742009-02-17 20:56:23 +0000701 def smart_get(cls, id_or_name, valid_only=True):
jadmanski0afbb632008-06-06 21:10:57 +0000702 """\
703 smart_get(integer) -> get object by ID
704 smart_get(string) -> get object by name_field
jadmanski0afbb632008-06-06 21:10:57 +0000705 """
showarda4ea5742009-02-17 20:56:23 +0000706 if valid_only:
707 manager = cls.get_valid_manager()
708 else:
709 manager = cls.objects
710
711 if isinstance(id_or_name, (int, long)):
712 return manager.get(pk=id_or_name)
713 if isinstance(id_or_name, basestring):
714 return manager.get(**{cls.name_field : id_or_name})
715 raise ValueError(
716 'Invalid positional argument: %s (%s)' % (id_or_name,
717 type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000718
719
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """
    Look up every entry of id_or_name_list with cls.smart_get().

    Returns the list of found objects, in input order.  If any entry
    cannot be found, raises a single cls.DoesNotExist naming all of the
    missing entries; the whole input is checked before raising.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)

    if invalid_inputs:
        # Entries may be integer IDs as well as names, and str.join()
        # rejects non-string items, so format each entry first.
        missing = ', '.join(['%s' % (invalid_input,)
                             for invalid_input in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
734
735
def get_object_dict(self, fields=None):
    """
    Return a dictionary mapping field names to this object's values.

    fields is an iterable of attribute names to include; when it is None
    the keys of self.get_field_dict() are used.  The dictionary is passed
    through clean_object_dicts() and then _postprocess_object_dict()
    before it is returned.
    """
    if fields is None:
        fields = self.get_field_dict().iterkeys()

    object_dict = {}
    for field_name in fields:
        object_dict[field_name] = getattr(self, field_name)

    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000747
748
showardd3dc1992009-04-22 21:01:40 +0000749 def _postprocess_object_dict(self, object_dict):
750 """For subclasses to override."""
751 pass
752
753
@classmethod
def get_valid_manager(cls):
    """
    Return the manager used for lookups restricted to valid objects.
    The default is simply cls.objects.
    """
    default_manager = cls.objects
    return default_manager
showard7c785282008-05-29 19:45:12 +0000757
758
def _record_attributes(self, attributes):
    """
    Snapshot the current values of the given attributes into
    self._recorded_attributes.  See on_attribute_changed.

    attributes must be a collection of attribute names, not a single
    string.
    """
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for attribute in attributes:
        snapshot[attribute] = getattr(self, attribute)
    # Assign only once the snapshot is complete.
    self._recorded_attributes = snapshot
766
767
768 def _check_for_updated_attributes(self):
769 """
770 See on_attribute_changed.
771 """
772 for attribute, original_value in self._recorded_attributes.iteritems():
773 new_value = getattr(self, attribute)
774 if original_value != new_value:
775 self.on_attribute_changed(attribute, original_value)
776 self._record_attributes(self._recorded_attributes.keys())
777
778
def on_attribute_changed(self, attribute, old_value):
    """
    Notification hook, to be overridden: called with the attribute name
    and its previously recorded value whenever a tracked attribute is
    found to have changed.  The default does nothing.

    To use this method, you must:
    * call _record_attributes() from __init__() (after making the super
      call) with a list of attributes for which you want to be notified
      upon change.
    * call _check_for_updated_attributes() from save().
    """
    return None
790
791
class ModelWithInvalid(ModelExtensions):
    """
    Model base that supports invalidation in place of actual deletion by
    overriding save() and delete().  Subclasses must have a boolean
    "invalid" field.
    """

    def save(self, *args, **kwargs):
        """
        Save the object.  If it has no id yet and an invalidated object
        with the same name_field value exists, adopt that object's id
        before saving.
        """
        if self.id is None:
            # First save: see if this object was previously added and
            # invalidated.
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.id = self.__class__.objects.get(**lookup).id
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def clean_object(self):
        """
        Called by delete() after the object has been marked invalid and
        saved.  Subclasses should override this to clean up relationships
        that should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Mark the object invalid, save it, and then call clean_object().
        The object must not already be invalid.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return cls.valid_objects rather than cls.objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            # Point the manipulator at the valid-objects manager.
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000845
846
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        """
        # NotImplemented is a comparison sentinel, not an exception;
        # raising it produces a TypeError.
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """
        Set the named attribute to value, creating the attribute object
        via get_or_create() if it does not exist yet.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the named attribute.  Does nothing if it does not exist.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # Catch the exception of whichever model the subclass
            # returned; nothing to delete.
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the named attribute when value is None; otherwise set it
        to value.
        """
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)