blob: c94233b76fbb168e665893e0c0c41f77f749292c [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
5from django.db import models as dbmodels, backend, connection
6from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +00007from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +00008
class ValidationError(Exception):
    """\
    Raised when data supplied for adding or updating an object fails
    validation.

    The exception carries a single argument: a dictionary whose keys are the
    offending field names and whose values are the error strings for them.
    """
showard7c785282008-05-29 19:45:12 +000014
15
showard09096d82008-07-07 23:20:49 +000016def _wrap_with_readonly(method):
17 def wrapper_method(*args, **kwargs):
showard56e93772008-10-06 10:06:22 +000018 readonly_connection.connection().set_django_connection()
showard09096d82008-07-07 23:20:49 +000019 try:
20 return method(*args, **kwargs)
21 finally:
showard56e93772008-10-06 10:06:22 +000022 readonly_connection.connection().unset_django_connection()
showard09096d82008-07-07 23:20:49 +000023 wrapper_method.__name__ = method.__name__
24 return wrapper_method
25
26
27def _wrap_generator_with_readonly(generator):
28 """
29 We have to wrap generators specially. Assume it performs
30 the query on the first call to next().
31 """
32 def wrapper_generator(*args, **kwargs):
33 generator_obj = generator(*args, **kwargs)
showard56e93772008-10-06 10:06:22 +000034 readonly_connection.connection().set_django_connection()
showard09096d82008-07-07 23:20:49 +000035 try:
36 first_value = generator_obj.next()
37 finally:
showard56e93772008-10-06 10:06:22 +000038 readonly_connection.connection().unset_django_connection()
showard09096d82008-07-07 23:20:49 +000039 yield first_value
40
41 while True:
42 yield generator_obj.next()
43
44 wrapper_generator.__name__ = generator.__name__
45 return wrapper_generator
46
47
def _make_queryset_readonly(queryset):
    """
    Rebind, on this queryset instance, every method that hits the database
    so that it runs on the read-only connection.
    """
    # NOTE(review): 'delete' is routed to the read-only connection as well;
    # confirm that is intended.
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    # iterator() is a generator method and needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
60
61
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database-querying methods all run on the read-only
    connection.
    """
    def __init__(self, model=None):
        super(ReadonlyQuerySet, self).__init__(model)
        # replace this instance's query methods with read-only wrappers
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but the resulting clone is read-only too."""
        return self._clone(_fields=fields, klass=ReadonlyValuesQuerySet)
74
75
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its database-querying
    methods run on the read-only connection.
    """
    def __init__(self, model=None):
        super(ReadonlyValuesQuerySet, self).__init__(model)
        # replace this instance's query methods with read-only wrappers
        _make_queryset_readonly(self)
80
81
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.
    """

    class _CustomJoinQ(dbmodels.Q):
        """
        Django "Q" object supporting a custom suffix for join aliases. See
        filter_custom_join() for why this can be useful.
        """

        def __init__(self, join_suffix, **kwargs):
            super(ExtendedManager._CustomJoinQ, self).__init__(**kwargs)
            self._join_suffix = join_suffix


        @staticmethod
        def _substitute_aliases(renamed_aliases, condition):
            """
            Rewrite condition, replacing each quoted old alias with the
            corresponding quoted new alias.
            @param renamed_aliases: sequence of (old_alias, new_alias) pairs.
            @param condition: SQL fragment to rewrite.
            @returns the rewritten SQL fragment.
            """
            for old_alias, new_alias in renamed_aliases:
                condition = condition.replace(backend.quote_name(old_alias),
                                              backend.quote_name(new_alias))
            return condition


        @staticmethod
        def _unquote_name(name):
            """
            Strip quoting from name if it already appears to be quoted.

            A name counts as already quoted when backend.quote_name() leaves
            it unchanged. In that case its first and last characters are
            removed. This may be MySQL specific.
            """
            if backend.quote_name(name) == name:
                return name[1:-1]
            return name


        def get_sql(self, opts):
            """
            Generate SQL as the base Q class does, then append
            self._join_suffix to every join alias and rewrite the join
            conditions and WHERE clauses to refer to the renamed aliases.
            @returns (joins, where, params), like the base get_sql().
            """
            joins, where, params = (
                super(ExtendedManager._CustomJoinQ, self).get_sql(opts))

            new_joins = datastructures.SortedDict()

            # rename all join aliases and correct references in later joins
            renamed_tables = []
            # using iteritems seems to mess up the ordering here
            for alias, (table, join_type, condition) in joins.items():
                alias = self._unquote_name(alias)
                new_alias = alias + self._join_suffix
                renamed_tables.append((alias, new_alias))
                condition = self._substitute_aliases(renamed_tables, condition)
                new_alias = backend.quote_name(new_alias)
                new_joins[new_alias] = (table, join_type, condition)

            # correct references in where
            new_where = [self._substitute_aliases(renamed_tables, clause)
                         for clause in where]

            return new_joins, new_where, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Q object assembled from hand-written SQL. Callers add joins and WHERE
        clauses directly, and get_sql() returns them as-is.
        """
        def __init__(self):
            # NOTE(review): dbmodels.Q.__init__ is not called here; get_sql()
            # below only reads the attributes initialized in this method.
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """
            Record a join of table under alias (default: the table name).
            @param table: name of the table to join.
            @param condition: SQL for the ON clause.
            @param join_type: e.g. 'INNER JOIN' or 'LEFT JOIN'.
            @param alias: alias for the joined table; defaults to table.
            """
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=()):
            """
            Append a WHERE clause and its query parameters.
            @param where: SQL fragment for the WHERE clause.
            @param params: sequence of parameters for placeholders in where.
            """
            # default is an immutable tuple so no list is shared across calls
            self._where.append(where)
            self._params.extend(params)


        def get_sql(self, opts):
            """@returns the (joins, where, params) accumulated so far."""
            return self._joins, self._where, self._params


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param query_set: the QuerySet to add the join to.
        @param join_table: table to join to.
        @param join_key: field on join_table referencing back to this model,
                used for the join.
        @param join_condition: extra SQL condition for the ON clause.
        @param suffix: suffix to add to join_table for the join alias.
        @param exclude: if true, exclude rows that match this join (uses a
                LEFT JOIN and an "IS NULL" WHERE condition).
        @param force_left_join: if true, a LEFT JOIN is used instead of an
                INNER JOIN regardless of other options.
        @returns a new, distinct QuerySet with the join applied.
        """
        join_from_table = self.model._meta.db_table
        join_from_key = self.model._meta.pk.name
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = 'LEFT JOIN'
        else:
            join_type = 'INNER JOIN'

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # with a LEFT JOIN, a NULL join key means "no matching row"
            filter_object.add_where(full_join_key + ' IS NULL')
        return query_set.filter(filter_object).distinct()


    def filter_custom_join(self, join_suffix, **kwargs):
        """
        Just like Django filter(), but allows the user to specify a custom
        suffix for the join aliases involved in the filter. This makes it
        possible to join against a table multiple times (as long as a different
        suffix is used each time), which is necessary for certain queries.
        """
        filter_object = self._CustomJoinQ(join_suffix, **kwargs)
        return self.complex_filter(filter_object)


    def _get_quoted_field(self, table, field):
        """@returns the SQL reference table.field with both parts quoted."""
        return (backend.quote_name(table) + '.' + backend.quote_name(field))


    def get_key_on_this_table(self, key_field=None):
        """
        @param key_field: column name on this model's table; defaults to the
                primary key column.
        @returns the quoted table.column reference for key_field.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' in sql so it survives %-style substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run query_set's query on the read-only connection with a custom
        select list.
        @param query_set: QuerySet supplying the clauses after SELECT (the
                string from _get_sql_clause(), presumably FROM ... WHERE ...)
                and the parameters.
        @param selects: list of SQL select expressions.
        @returns all rows fetched by the cursor.
        """
        query_selects, where, params = query_set._get_sql_clause()
        if query_set._distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''
        sql_query = 'SELECT ' + distinct + ','.join(selects) + where
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """@returns true if field is a relation whose target is model_class."""
        return field.rel and field.rel.to is model_class


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
                their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
                related_object) for each relationship between a base object
                and a related object. All base_object instances come from
                base_objects_by_id.
        @raises ValueError if related_model has no relation to this model.
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one
                return self._many_to_one_pivot(base_objects_by_id,
                                               related_model, field)

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_reverse_name(), field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self._many_to_many_pivot(
                        base_objects_by_id, related_model, field.m2m_db_table(),
                        field.m2m_column_name(), field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @param foreign_key_field: field on related_model pointing at this
                model.
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            fresh_base_object = getattr(related_object, foreign_key_field.name)
            # lookup base object in the dict -- we need to return instances from
            # the dict, not fresh instances of the same models
            base_object = base_objects_by_id[fresh_base_object._get_pk_val()]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field):
        """
        @param base_objects_by_id: dict of self.model objects to include,
                keyed by ID (only the keys are used). Must be non-empty, or
                the generated "IN ()" clause is invalid SQL.
        @param pivot_table: the name of the pivot table
        @param pivot_from_field: a field name on pivot_table referencing
                self.model
        @param pivot_to_field: a field name on pivot_table referencing the
                related model.
        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id.iterkeys()))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        if not all_related_ids:
            # Nothing is related. Bail out before calling in_bulk() with an
            # empty list, which older Django versions may reject.
            return
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000362
363
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose base query set is restricted to rows with invalid=False.
    """
    def get_query_set(self):
        all_objects = super(ValidObjectsManager, self).get_query_set()
        # hide rows that have been marked invalid
        return all_objects.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000371
372
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # per-class cache of {field name: field object}; see get_field_dict()
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """
        Return a dict mapping each field name of this model to its field
        object, built from cls._meta.fields and cached on the class.

        The cache is looked up in cls.__dict__ rather than through normal
        attribute lookup, so a subclass never reuses a dict cached on a parent
        class with a different set of fields.
        """
        field_dict = cls.__dict__.get('field_dict')
        if field_dict is None:
            field_dict = dict((field.name, field)
                              for field in cls._meta.fields)
            cls.field_dict = field_dict
        return field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        For every foreign key field:
        -Convert the key in data from <field>_id (the attname) to just
        <field>.
        -Replace a model instance value with its primary key.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data[field.attname]
                del data[field.attname]
            if field.name not in data:
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value._get_pk_val()


    @classmethod
    def _convert_booleans(cls, data):
        """
        Ensure BooleanFields actually get bool values. The Django MySQL
        backend returns ints for BooleanFields, which is almost always not
        a problem, but it can be annoying in certain situations.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if type(field) == dbmodels.BooleanField and field.name in data:
                data[field.name] = bool(data[field.name])


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Provide default values for fields with default values which have
        nothing passed in.

        For CharField and TextField fields with "blank=True", if nothing
        is passed, we fill in an empty string value, even if there's no
        default set.

        A callable default is called to produce the value, as Django itself
        does, rather than storing the callable object in the data.

        @returns a new dict; data is not modified.
        """
        new_data = dict(data)
        field_dict = cls.get_field_dict()
        for name, obj in field_dict.iteritems():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                default = obj.default
                if callable(default):
                    default = default()
                new_data[name] = default
            elif isinstance(obj, (dbmodels.CharField, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary. This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data. This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict or data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(data[field_name],
                                                      valid_only=False)
                if to_human_readable:
                    if dest_obj.name_field is not None:
                        data[field_name] = getattr(dest_obj,
                                                   dest_obj.name_field)
                else:
                    data[field_name] = dest_obj


    @classmethod
    def validate_field_names(cls, data):
        """
        Check for extraneous fields in data.
        @returns dict mapping each unknown field name to an error string
                (empty if all names are valid).
        """
        errors = {}
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict:
                errors[field_name] = 'No field of this name'
        return errors


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """
        Common preparation for add_object and update_object: merge kwargs
        into a copy of data, reject unknown field names and convert
        human-readable values.
        @raises ValidationError if data contains unknown field names.
        """
        data = dict(data) # don't modify the caller's dict
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique. Django manipulators do
        this too, but they're a huge pain to use manually. Trust me.
        @returns dict mapping field name to error string for each unique
                field whose value is already used by another valid object.
        """
        errors = {}
        cls = type(self)
        field_dict = self.get_field_dict()
        manager = cls.get_valid_manager()
        for field_name, field_obj in field_dict.iteritems():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            # the only match is this object itself
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value))
        return errors


    def do_validate(self):
        """
        Run self.validate() (not defined in this mixin; its result is used as
        a dict of field name -> error) plus validate_unique().
        @raises ValidationError with the merged errors if there are any.
        """
        errors = self.validate()
        unique_errors = self.validate_unique()
        for field_name, error in unique_errors.iteritems():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data=None, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values). Merges any extra keyword args
        into data. The object is validated and saved.
        @raises ValidationError if the data is invalid.
        """
        if data is None:
            data = {}
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data=None, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values). Merges any extra keyword args into
        data. The object is validated and saved.
        @raises ValidationError if the data is invalid.
        """
        if data is None:
            data = {}
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.iteritems():
            setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True, initial_query=None):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data. Optional special arguments in
        filter_data include:
        -query_start: index of first result to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on. prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append
        -no_distinct: if true, don't apply distinct() to the query
        Neither filter_data nor any value inside it is modified.
        @raises ValueError if query_start is given without query_limit.
        """
        filter_data = dict(filter_data) # copy so we don't mutate the original
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        # copy extra_args (and its 'where' list below) so appending
        # extra_where cannot modify the caller's objects
        extra_args = dict(filter_data.pop('extra_args', None) or {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            # escape %'s
            extra_where = cls.objects.escape_user_sql(extra_where)
            extra_args['where'] = (list(extra_args.get('where', [])) +
                                   [extra_where])
        use_distinct = not filter_data.pop('no_distinct', False)

        if initial_query is None:
            if valid_only:
                initial_query = cls.get_valid_manager()
            else:
                initial_query = cls.objects
        query = initial_query.filter(**filter_data)
        if use_distinct:
            query = query.distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)
        query = query._clone(klass=ReadonlyQuerySet)

        # sorting + paging
        assert isinstance(sort_by, (list, tuple))
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data, initial_query=None):
        """\
        Like query_objects, but retrieve only the count of results.
        Paging arguments are ignored; filter_data itself is not modified.
        """
        # copy, matching query_objects(), so the caller's paging keys survive
        filter_data = dict(filter_data)
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        query = cls.query_objects(filter_data, initial_query=initial_query)
        return query.count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user. Each dict is modified in-place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls._convert_booleans(field_dict)
            cls.convert_human_readable_values(field_dict,
                                              to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data, initial_query=None, fields=None):
        """\
        Like query_objects, but return a list of dictionaries.
        @param fields: field names to include; defaults to all fields.
        """
        query = cls.query_objects(filter_data, initial_query=initial_query)
        field_dicts = [model_object.get_object_dict(fields)
                       for model_object in query]
        return field_dicts


    @classmethod
    def smart_get(cls, id_or_name, valid_only=True):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        @param valid_only: if true, look up through get_valid_manager();
                otherwise through cls.objects.
        @raises ValueError for any other argument type.
        """
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects

        if isinstance(id_or_name, (int, long)):
            return manager.get(pk=id_or_name)
        if isinstance(id_or_name, basestring):
            return manager.get(**{cls.name_field : id_or_name})
        raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                      type(id_or_name)))


    @classmethod
    def smart_get_bulk(cls, id_or_name_list):
        """
        smart_get() each entry of id_or_name_list.
        @returns list of the matching objects, in input order.
        @raises cls.DoesNotExist naming every input that matched nothing.
        """
        invalid_inputs = []
        result_objects = []
        for id_or_name in id_or_name_list:
            try:
                result_objects.append(cls.smart_get(id_or_name))
            except cls.DoesNotExist:
                invalid_inputs.append(id_or_name)
        if invalid_inputs:
            # inputs may be integer IDs, so format each before joining
            raise cls.DoesNotExist('The following %ss do not exist: %s'
                                   % (cls.__name__.lower(),
                                      ', '.join('%s' % (invalid_input,)
                                                for invalid_input
                                                in invalid_inputs)))
        return result_objects


    def get_object_dict(self, fields=None):
        """\
        Return a dictionary mapping fields to this object's values.
        @param fields: field names to include; defaults to all fields.
        """
        if fields is None:
            fields = self.get_field_dict().iterkeys()
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in fields)
        self.clean_object_dicts([object_dict])
        self._postprocess_object_dict(object_dict)
        return object_dict


    def _postprocess_object_dict(self, object_dict):
        """For subclasses to override."""
        pass


    @classmethod
    def get_valid_manager(cls):
        """@returns the manager used for "valid objects only" lookups."""
        return cls.objects


    def _record_attributes(self, attributes):
        """
        Snapshot the current values of the given attribute names.
        See on_attribute_changed.
        """
        assert not isinstance(attributes, basestring)
        self._recorded_attributes = dict((attribute, getattr(self, attribute))
                                         for attribute in attributes)


    def _check_for_updated_attributes(self):
        """
        Call on_attribute_changed() for every recorded attribute whose value
        differs from the snapshot, then take a fresh snapshot.
        See on_attribute_changed.
        """
        for attribute, original_value in self._recorded_attributes.iteritems():
            new_value = getattr(self, attribute)
            if original_value != new_value:
                self.on_attribute_changed(attribute, original_value)
        self._record_attributes(self._recorded_attributes.keys())


    def on_attribute_changed(self, attribute, old_value):
        """
        Called whenever an attribute is updated. To be overridden.

        To use this method, you must:
        * call _record_attributes() from __init__() (after making the super
        call) with a list of attributes for which you want to be notified upon
        change.
        * call _check_for_updated_attributes() from save().
        """
        pass
768
769
class ModelWithInvalid(ModelExtensions):
    """
    Model base that invalidates objects instead of deleting them.

    delete() marks the object invalid rather than removing its row, and
    save() of a brand-new object reuses the row of an invalidated object
    with the same name. Requirements on subclasses, as used below:
      * a boolean "invalid" field;
      * a "name_field" attribute naming the field that holds the object's
        name;
      * a "valid_objects" manager.
    """

    def save(self):
        """
        Save this object.

        When the object has no id yet, first look for an invalidated object
        with the same name; if one exists, adopt its id so that this save
        writes to that existing row instead of adding a new one.
        """
        if self.id is None:
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.id = self.__class__.objects.get(**lookup).id
            except self.DoesNotExist:
                # Nothing was invalidated under this name; save as new.
                pass

        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        Hook called by delete() after the object has been marked invalid and
        saved. Subclasses should override it to remove relationships that
        should not survive a deletion. The default does nothing.
        """


    def delete(self):
        """
        "Delete" by marking the object invalid and saving it, then let the
        subclass clean up via clean_object(). The row itself is kept.
        """
        # Guard against deleting an object that is already invalid.
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """
        Return the valid_objects manager (presumably one that filters out
        invalidated rows -- defined by the subclass).
        """
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects;
        otherwise they would match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            # Point the manipulator's lookups at valid objects only.
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000823
824
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".

    Subclasses must implement _get_attribute_model_and_args().
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError if not overridden.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it yields a confusing TypeError instead of this intended error.
        raise NotImplementedError(
                '%s must implement _get_attribute_model_and_args()'
                % self.__class__.__name__)


    def set_attribute(self, attribute, value):
        """
        Create the attribute on this object if needed, then store value in
        its "value" field and save it.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute from this object. Deleting an attribute
        that does not exist is a no-op.
        """
        attribute_model, get_args = self._get_attribute_model_and_args(
                attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # Catch the DoesNotExist of the model actually queried, so this
            # mixin works for any attribute model.
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Set the attribute to value, or delete it when value is None.
        """
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)