blob: 32b4898d718efa8746d1326d02da186bd09ba326 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
5from django.db import models as dbmodels, backend, connection
6from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +00007from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +00008
class ValidationError(Exception):
    """
    Raised when data fails validation while adding or updating an object.

    The value associated with the exception is a dictionary mapping field
    names to error strings.
    """
showard7c785282008-05-29 19:45:12 +000014
15
def _wrap_with_readonly(method):
    """
    Return a callable that invokes method between calls to
    set_django_connection() and unset_django_connection() on the read-only
    connection.  The returned callable keeps method's __name__.
    """
    def readonly_call(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            # always undo the switch, even if the wrapped call raised
            readonly_connection.connection().unset_django_connection()
        return result

    readonly_call.__name__ = method.__name__
    return readonly_call
25
26
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first item is fetched between
    calls to set_django_connection() and unset_django_connection() on the
    read-only connection.

    Generators have to be wrapped specially.  We assume the query is
    performed on the first call to next(), so only that call is bracketed.
    """
    def readonly_generator(*args, **kwargs):
        source = generator(*args, **kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            head = source.next()
        finally:
            readonly_connection.connection().unset_django_connection()
        yield head

        # the remaining items are produced outside the read-only bracket
        for item in source:
            yield item

    readonly_generator.__name__ = generator.__name__
    return readonly_generator
46
47
def _make_queryset_readonly(queryset):
    """
    Replace, on this queryset instance, every method that performs a
    database query with a version wrapped to use the read-only connection.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator, so it needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
60
61
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries are all performed with the read-only
    connection.
    """
    def __init__(self, model=None):
        super(ReadonlyQuerySet, self).__init__(model)
        # patch the query-performing methods on this instance
        _make_queryset_readonly(self)


    def values(self, *fields):
        # clone into the read-only flavour of the values query set
        readonly_values = self._clone(klass=ReadonlyValuesQuerySet,
                                      _fields=fields)
        return readonly_values
74
75
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet whose query-performing methods are wrapped to use the
    read-only connection.
    """
    def __init__(self, model=None):
        super(ReadonlyValuesQuerySet, self).__init__(model)
        # patch the query-performing methods on this instance
        _make_queryset_readonly(self)
80
81
class ExtendedManager(dbmodels.Manager):
    """
    Extended manager supporting subquery filtering.
    """

    class _CustomJoinQ(dbmodels.Q):
        """
        Django "Q" object supporting a custom suffix for join aliases. See
        filter_custom_join() for why this can be useful.
        """

        def __init__(self, join_suffix, **kwargs):
            """
            @param join_suffix string appended to every join alias produced
                    by this filter
            @param kwargs filter arguments, passed through to the Q
                    constructor
            """
            super(ExtendedManager._CustomJoinQ, self).__init__(**kwargs)
            self._join_suffix = join_suffix


        @staticmethod
        def _substitute_aliases(renamed_aliases, condition):
            """
            Replace every quoted old alias in condition with its quoted new
            alias.
            @param renamed_aliases list of (old_alias, new_alias) pairs
            @param condition SQL string to rewrite
            @returns the rewritten SQL string
            """
            for old_alias, new_alias in renamed_aliases:
                condition = condition.replace(backend.quote_name(old_alias),
                                              backend.quote_name(new_alias))
            return condition


        @staticmethod
        def _unquote_name(name):
            """
            Strip the first and last character from name if it is already
            quoted (i.e. quoting it again leaves it unchanged); otherwise
            return it untouched.  This may be MySQL specific.
            """
            if backend.quote_name(name) == name:
                return name[1:-1]
            return name


        def get_sql(self, opts):
            """
            Same as the parent get_sql(), but with every join alias renamed
            to carry the join suffix and all references to those aliases in
            the join conditions and WHERE clauses updated to match.
            @returns tuple (joins, where, params)
            """
            joins, where, params = (
                super(ExtendedManager._CustomJoinQ, self).get_sql(opts))

            new_joins = datastructures.SortedDict()

            # rename all join aliases and correct references in later joins
            renamed_tables = []
            # using iteritems seems to mess up the ordering here
            for alias, (table, join_type, condition) in joins.items():
                alias = self._unquote_name(alias)
                new_alias = alias + self._join_suffix
                renamed_tables.append((alias, new_alias))
                condition = self._substitute_aliases(renamed_tables, condition)
                new_alias = backend.quote_name(new_alias)
                new_joins[new_alias] = (table, join_type, condition)

            # correct references in where
            new_where = [self._substitute_aliases(renamed_tables, clause)
                         for clause in where]

            return new_joins, new_where, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Q object carrying explicitly supplied joins, WHERE clauses and
        parameters, which get_sql() returns as-is.
        """
        def __init__(self):
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """
            Register a join.
            @param table table to join to
            @param condition SQL for the ON clause
            @param join_type SQL join keyword(s), e.g. 'LEFT JOIN'
            @param alias alias for the joined table; defaults to table
            """
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=()):
            """
            Register a WHERE clause.
            @param where SQL condition string
            @param params sequence of query parameters for the condition
                    (an immutable default is used so that no list is shared
                    between calls)
            """
            self._where.append(where)
            self._params.extend(params)


        def get_sql(self, opts):
            """
            @returns tuple (joins, where, params) exactly as accumulated by
            add_join() and add_where(); opts is ignored.
            """
            return self._joins, self._where, self._params


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the
                join
        @param join_condition extra condition for the ON clause of the join
        @param suffix suffix to add to join_table for the join alias
        @param exclude if true, exclude rows that match this join (will use a
                LEFT JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT JOIN will be used instead of
                an INNER JOIN regardless of other options
        @returns the filtered query set, with distinct() applied
        """
        join_from_table = self.model._meta.db_table
        join_from_key = self.model._meta.pk.name
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = 'LEFT JOIN'
        else:
            join_type = 'INNER JOIN'

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # with a LEFT JOIN, unmatched rows have NULL in the joined key
            filter_object.add_where(full_join_key + ' IS NULL')
        return query_set.filter(filter_object).distinct()


    def filter_custom_join(self, join_suffix, **kwargs):
        """
        Just like Django filter(), but allows the user to specify a custom
        suffix for the join aliases involved in the filter.  This makes it
        possible to join against a table multiple times (as long as a
        different suffix is used each time), which is necessary for certain
        queries.
        """
        filter_object = self._CustomJoinQ(join_suffix, **kwargs)
        return self.complex_filter(filter_object)


    def _get_quoted_field(self, table, field):
        """
        @returns the string <quoted table>.<quoted field>
        """
        return (backend.quote_name(table) + '.' + backend.quote_name(field))


    def get_key_on_this_table(self, key_field=None):
        """
        @param key_field column name on this model's table; defaults to the
                primary key column
        @returns the quoted, table-qualified name of that column
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """
        @returns sql with every '%' doubled (presumably so that the database
        layer does not treat it as a parameter placeholder)
        """
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Execute a SELECT of the given expressions on the read-only
        connection, reusing the SQL and parameters of query_set.
        @param query_set query set supplying the rest of the SQL statement
        @param selects list of SQL select expressions
        @returns all rows, as returned by cursor.fetchall()
        """
        # the select list generated by Django is discarded in favor of ours;
        # the returned SQL text is appended directly after the select list
        _, where, params = query_set._get_sql_clause()
        if query_set._distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''
        sql_query = 'SELECT ' + distinct + ','.join(selects) + where
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """
        @returns a true value iff field is a relation whose target is
        exactly model_class
        """
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
                their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from
        base_objects_by_id.
        @raises ValueError if no relation between the two models is found
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @param foreign_key_field the field on related_model that refers to
                this model
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            fresh_base_object = getattr(related_object, foreign_key_field.name)
            # lookup base object in the dict -- we need to return instances
            # from the dict, not fresh instances of the same models
            base_object = base_objects_by_id[fresh_base_object._get_pk_val()]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id dict whose keys are the IDs of the
                self.model objects to include
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        id_list = ','.join(str(id_) for id_ in base_objects_by_id)
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=id_list)
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        # fetch each related object only once, however often it is referenced
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type
        related_model.  related_model must be in a many-to-one or
        many-to-many relationship with this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the
                related object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        # every base object gets a list, even if it has no related objects
        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000362
363
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose base query set contains only objects with invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        valid_objects = all_objects.filter(invalid=False)
        return valid_objects
showard7c785282008-05-29 19:45:12 +0000371
372
class ModelExtensions(object):
    """
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # cache for get_field_dict(), filled in on first use
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """
        @returns dict mapping field names to the field objects listed in
        cls._meta.fields.  The dict is built once and cached on the class.
        """
        if cls.field_dict is None:
            cls.field_dict = {}
            for field in cls._meta.fields:
                cls.field_dict[field.name] = field
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """
        -Convert foreign key fields in data from <field>_id to just
         <field>.
        -replace foreign key objects with their IDs
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                    field.attname in data):
                data[field.name] = data[field.attname]
                del data[field.attname]
            if field.name not in data:
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value._get_pk_val()


    @classmethod
    def _convert_booleans(cls, data):
        """
        Ensure BooleanFields actually get bool values.  The Django MySQL
        backend returns ints for BooleanFields, which is almost always not
        a problem, but it can be annoying in certain situations.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            # exact type match: subclasses of BooleanField are not converted
            if type(field) == dbmodels.BooleanField and field.name in data:
                data[field.name] = bool(data[field.name])


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """
        Provide default values for fields with default values which have
        nothing passed in.

        For CharField and TextField fields without a default, if nothing
        is passed, we fill in an empty string value.  (Note that the field's
        "blank" setting is not consulted.)

        @returns a new dict; data itself is not modified
        """
        new_data = dict(data)
        field_dict = cls.get_field_dict()
        for name, obj in field_dict.items():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                new_data[name] = obj.default
            elif isinstance(obj, (dbmodels.CharField, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary.  This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data.  This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        # only values of existing keys are replaced below, so iterating over
        # data while assigning into it is safe
        for field_name in data:
            if field_name not in field_dict or data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(data[field_name],
                                                      valid_only=False)
                if to_human_readable:
                    # without a name_field the original value is left as is
                    if dest_obj.name_field is not None:
                        data[field_name] = getattr(dest_obj,
                                                   dest_obj.name_field)
                else:
                    data[field_name] = dest_obj


    @classmethod
    def validate_field_names(cls, data):
        """
        Checks for extraneous fields in data.
        @returns dict mapping each unknown field name to an error string
        """
        errors = {}
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict:
                errors[field_name] = 'No field of this name'
        return errors


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """
        Common preparation for add_object and update_object.
        @returns a new dict merging data and kwargs, with human-readable
        values converted
        @raises ValidationError if any key is not a field of this model
        """
        data = dict(data) # don't modify the default keyword arg
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """
        Validate that unique fields are unique.  Django manipulators do
        this too, but they're a huge pain to use manually.  Trust me.
        @returns dict mapping field names to error strings
        """
        errors = {}
        cls = type(self)
        field_dict = self.get_field_dict()
        manager = cls.get_valid_manager()
        for field_name, field_obj in field_dict.items():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            if num_existing == 1 and existing_objs[0].id == self.id:
                # the only match is this very object
                continue
            # wrap value in a tuple so that tuple values format correctly
            errors[field_name] = (
                'This value must be unique (%s)' % (value,))
        return errors


    def do_validate(self):
        """
        Run validate() and validate_unique().
        @raises ValidationError with the combined errors, if there are any;
        for a field with both kinds of error, the validate() error wins
        """
        errors = self.validate()
        unique_errors = self.validate_unique()
        for field_name, error in unique_errors.items():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data={}, **kwargs):
        """
        Returns a new object created with the given data (a dictionary
        mapping field names to values).  Merges any extra keyword args
        into data.  The object is validated and saved before it is returned.
        """
        # the default dict is never modified: prepare_data_args() copies it
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data={}, **kwargs):
        """
        Updates the object with the given data (a dictionary mapping
        field names to values).  Merges any extra keyword args into
        data.  The object is validated and saved afterwards.
        """
        # the default dict is never modified: prepare_data_args() copies it
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.items():
            setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True, initial_query=None):
        """
        Returns a QuerySet object for querying the given model_class
        with the given filter_data.  Optional special arguments in
        filter_data include:
        -query_start: index of first result to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on.  prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append
        -no_distinct: if true, distinct() is not applied to the query

        Note that the special arguments are removed from filter_data
        in-place.

        @param valid_only if true (and initial_query is not given), start
                from get_valid_manager() rather than cls.objects
        @param initial_query object to call filter() on instead of a manager
        @raises ValueError if query_start is given without query_limit
        """
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        extra_args = filter_data.pop('extra_args', {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            # escape %'s
            extra_where = cls.objects.escape_user_sql(extra_where)
            extra_args.setdefault('where', []).append(extra_where)
        use_distinct = not filter_data.pop('no_distinct', False)

        if initial_query is None:
            if valid_only:
                initial_query = cls.get_valid_manager()
            else:
                initial_query = cls.objects
        query = initial_query.filter(**filter_data)
        if use_distinct:
            query = query.distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)
            # NOTE(review): queries carrying caller-supplied SQL are switched
            # to the read-only query set -- confirm this placement against
            # the original indentation
            query = query._clone(klass=ReadonlyQuerySet)

        # sorting + paging
        assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # turn the limit into an end index for the slice
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data, initial_query=None):
        """
        Like query_objects, but retrieve only the count of results.  Any
        query_start and query_limit in filter_data are discarded.
        """
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        query = cls.query_objects(filter_data, initial_query=initial_query)
        return query.count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user.  The dicts are modified in-place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls._convert_booleans(field_dict)
            cls.convert_human_readable_values(field_dict,
                                              to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data, initial_query=None, fields=None):
        """
        Like query_objects, but return a list of dictionaries.
        @param fields names of the fields to include in each dictionary;
                defaults to all fields (see get_object_dict())
        """
        query = cls.query_objects(filter_data, initial_query=initial_query)
        field_dicts = [model_object.get_object_dict(fields)
                       for model_object in query]
        return field_dicts


    @classmethod
    def smart_get(cls, id_or_name, valid_only=True):
        """
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        @param valid_only if true, look the object up through
                get_valid_manager() rather than cls.objects
        @raises ValueError if id_or_name is neither an integer nor a string
        """
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects

        if isinstance(id_or_name, (int, long)):
            return manager.get(pk=id_or_name)
        if isinstance(id_or_name, basestring):
            return manager.get(**{cls.name_field : id_or_name})
        raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                      type(id_or_name)))


    @classmethod
    def smart_get_bulk(cls, id_or_name_list):
        """
        Call smart_get() on every item of id_or_name_list.
        @returns list of the objects found, in input order
        @raises cls.DoesNotExist naming all the inputs that could not be
        found, if there are any
        """
        invalid_inputs = []
        result_objects = []
        for id_or_name in id_or_name_list:
            try:
                result_objects.append(cls.smart_get(id_or_name))
            except cls.DoesNotExist:
                invalid_inputs.append(id_or_name)
        if invalid_inputs:
            # inputs may be integer IDs, which join() would reject, so
            # format each one as a string first
            invalid_names = ', '.join(['%s' % (invalid_input,)
                                       for invalid_input in invalid_inputs])
            raise cls.DoesNotExist('The following %ss do not exist: %s'
                                   % (cls.__name__.lower(), invalid_names))
        return result_objects


    def get_object_dict(self, fields=None):
        """
        Return a dictionary mapping fields to this object's values, cleaned
        with clean_object_dicts() and passed through
        _postprocess_object_dict().
        @param fields names of the fields to include; defaults to all fields
        """
        if fields is None:
            # iterating over the field dict yields the field names
            fields = self.get_field_dict()
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in fields)
        self.clean_object_dicts([object_dict])
        self._postprocess_object_dict(object_dict)
        return object_dict


    def _postprocess_object_dict(self, object_dict):
        """For subclasses to override."""
        pass


    @classmethod
    def get_valid_manager(cls):
        """
        @returns the manager to use for "valid only" lookups; cls.objects
        unless overridden by a subclass
        """
        return cls.objects


    def _record_attributes(self, attributes):
        """
        Remember the current values of the given attributes.
        See on_attribute_changed.
        """
        # guard against a single name being passed instead of a list
        assert not isinstance(attributes, basestring)
        self._recorded_attributes = dict((attribute, getattr(self, attribute))
                                         for attribute in attributes)


    def _check_for_updated_attributes(self):
        """
        Call on_attribute_changed() for every recorded attribute whose value
        has changed, then record the current values.
        See on_attribute_changed.
        """
        for attribute, original_value in self._recorded_attributes.items():
            new_value = getattr(self, attribute)
            if original_value != new_value:
                self.on_attribute_changed(attribute, original_value)
        self._record_attributes(self._recorded_attributes.keys())


    def on_attribute_changed(self, attribute, old_value):
        """
        Called whenever an attribute is updated.  To be overridden.

        To use this method, you must:
        * call _record_attributes() from __init__() (after making the super
        call) with a list of attributes for which you want to be notified upon
        change.
        * call _check_for_updated_attributes() from save().
        """
        pass
767
768
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field.
    """

    def save(self):
        """
        Save the object.  An object without an ID takes over the ID of an
        existing invalid object with the same name, if there is one.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            lookup = {self.name_field : getattr(self, self.name_field),
                      'invalid' : True}
            try:
                previous = self.__class__.objects.get(**lookup)
            except self.DoesNotExist:
                # no existing object
                pass
            else:
                self.id = previous.id

        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """
        Mark the object invalid and save it instead of deleting it, then
        call clean_object().
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        @returns cls.valid_objects
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            # replace the manager with the valid-objects one
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000822
823
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        @raises NotImplementedError unless overridden
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # NotImplementedError is the exception to raise here
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """
        Set the given attribute of this object to value, creating the
        attribute object if it does not exist yet.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute of this object.  Does nothing if the
        attribute does not exist.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # catch the exception of whichever attribute model is in use,
            # rather than naming one specific model class
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the attribute if value is None, otherwise set it to value.
        """
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)