blob: 5610e7a803d3af230341da00ddd6cc2e27afa7cf [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showard0957a842009-05-11 19:25:08 +00005import itertools
showard7c785282008-05-29 19:45:12 +00006from django.db import models as dbmodels, backend, connection
7from django.utils import datastructures
showard56e93772008-10-06 10:06:22 +00008from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +00009
class ValidationError(Exception):
    """\
    Raised when data being added to or updated on an object fails
    validation. It is constructed with a single argument: a dict mapping
    each offending field name to an error message string.
    """
showard7c785282008-05-29 19:45:12 +000015
16
def _wrap_with_readonly(method):
    """
    Return a callable that invokes method() while the read-only connection
    is installed as the Django connection.

    The read-only connection is switched in before the call and switched
    back out in a finally clause, so it is restored even if method raises.
    The returned callable carries method's __name__.
    """
    def _readonly_call(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    _readonly_call.__name__ = method.__name__
    return _readonly_call
26
27
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so that its first step runs on the read-only
    connection.

    Generators need special handling: calling generator() does no work, so
    the connection is swapped only around the first .next() call, on the
    assumption that the query is performed there. Remaining values are
    passed through without touching the connection.
    """
    def _readonly_generator(*args, **kwargs):
        inner = generator(*args, **kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            head = inner.next()
        finally:
            readonly_connection.connection().unset_django_connection()
        # yield outside the try/finally: the connection must be restored
        # before control returns to the consumer
        yield head

        for item in inner:
            yield item

    _readonly_generator.__name__ = generator.__name__
    return _readonly_generator
47
48
def _make_queryset_readonly(queryset):
    """
    Patch queryset in place so that each of its query-issuing methods
    (and its iterator) runs on the read-only connection.
    """
    for method_name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                        'delete'):
        original = getattr(queryset, method_name)
        setattr(queryset, method_name, _wrap_with_readonly(original))

    # iterator() is a generator and needs the generator-aware wrapper
    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
61
62
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose database queries are all issued through the read-only
    connection.
    """
    def __init__(self, model=None):
        super(ReadonlyQuerySet, self).__init__(model)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Like QuerySet.values(), but the result stays read-only too."""
        return self._clone(klass=ReadonlyValuesQuerySet, _fields=fields)
75
76
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet: its queries run on the
    read-only connection.
    """
    def __init__(self, model=None):
        super(ReadonlyValuesQuerySet, self).__init__(model)
        _make_queryset_readonly(self)
81
82
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.
    """

    class _CustomJoinQ(dbmodels.Q):
        """
        Django "Q" object supporting a custom suffix for join aliases. See
        filter_custom_join() for why this can be useful.
        """

        def __init__(self, join_suffix, **kwargs):
            super(ExtendedManager._CustomJoinQ, self).__init__(**kwargs)
            self._join_suffix = join_suffix


        @staticmethod
        def _substitute_aliases(renamed_aliases, condition):
            """
            Rewrite condition, replacing each quoted old alias with the
            quoted new alias.
            @param renamed_aliases list of (old_alias, new_alias) pairs,
                    applied in order
            @param condition SQL fragment to rewrite
            @returns the rewritten SQL fragment
            """
            for old_alias, new_alias in renamed_aliases:
                condition = condition.replace(backend.quote_name(old_alias),
                                              backend.quote_name(new_alias))
            return condition


        @staticmethod
        def _unquote_name(name):
            """
            Strip surrounding quote characters from name if it is already
            quoted. This may be MySQL specific: it relies on quote_name()
            returning an already-quoted name unchanged.
            """
            if backend.quote_name(name) == name:
                return name[1:-1]
            return name


        def get_sql(self, opts):
            """
            Build the normal Q SQL, then append self._join_suffix to every
            join alias. References to renamed aliases are corrected in later
            join conditions and in the WHERE clauses.
            @returns (joins, where, params) as Q.get_sql() does
            """
            joins, where, params = (
                super(ExtendedManager._CustomJoinQ, self).get_sql(opts))

            new_joins = datastructures.SortedDict()

            # rename all join aliases and correct references in later joins
            renamed_tables = []
            # using iteritems seems to mess up the ordering here
            for alias, (table, join_type, condition) in joins.items():
                alias = self._unquote_name(alias)
                new_alias = alias + self._join_suffix
                renamed_tables.append((alias, new_alias))
                # renamed_tables includes this join's own rename plus every
                # earlier one, so the ON condition refers to the new names
                condition = self._substitute_aliases(renamed_tables, condition)
                new_alias = backend.quote_name(new_alias)
                new_joins[new_alias] = (table, join_type, condition)

            # correct references in where
            new_where = []
            for clause in where:
                new_where.append(
                    self._substitute_aliases(renamed_tables, clause))

            return new_joins, new_where, params


    class _CustomSqlQ(dbmodels.Q):
        """
        Q object whose SQL (joins, WHERE clauses and params) is supplied
        directly by the caller rather than derived from field lookups.
        """
        def __init__(self):
            self._joins = datastructures.SortedDict()
            self._where, self._params = [], []


        def add_join(self, table, condition, join_type, alias=None):
            """
            Record a join against table; alias defaults to the table name.
            """
            if alias is None:
                alias = table
            self._joins[alias] = (table, join_type, condition)


        def add_where(self, where, params=None):
            """
            Append a raw WHERE clause, plus any query params it uses.
            @param params optional sequence of params (default: none)
            """
            self._where.append(where)
            if params:
                self._params.extend(params)


        def get_sql(self, opts):
            return self._joins, self._where, self._params


    def add_join(self, query_set, join_table, join_key,
                 join_condition='', suffix='', exclude=False,
                 force_left_join=False):
        """
        Add a join to query_set.
        @param query_set the QuerySet to add the join to
        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param suffix suffix to add to join_table for the join alias
        @param exclude if true, exclude rows that match this join (will use a
        LEFT JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT JOIN will be used instead of an
        INNER JOIN regardless of other options
        @returns a new distinct QuerySet filtered through the join
        """
        join_from_table = self.model._meta.db_table
        # This goes straight into SQL, so use the primary key's database
        # column rather than its Python attribute name (matches
        # get_key_on_this_table()).
        join_from_key = self.model._meta.pk.column
        join_alias = join_table + suffix
        full_join_key = join_alias + '.' + join_key
        full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
                                              join_from_key)
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = 'LEFT JOIN'
        else:
            join_type = 'INNER JOIN'

        filter_object = self._CustomSqlQ()
        filter_object.add_join(join_table,
                               full_join_condition,
                               join_type,
                               alias=join_alias)
        if exclude:
            # rows with no match have NULLs on the LEFT JOINed side
            filter_object.add_where(full_join_key + ' IS NULL')
        return query_set.filter(filter_object).distinct()


    def filter_custom_join(self, join_suffix, **kwargs):
        """
        Just like Django filter(), but allows the user to specify a custom
        suffix for the join aliases involved in the filter. This makes it
        possible to join against a table multiple times (as long as a different
        suffix is used each time), which is necessary for certain queries.
        """
        filter_object = self._CustomJoinQ(join_suffix, **kwargs)
        return self.complex_filter(filter_object)


    def _get_quoted_field(self, table, field):
        """Return "table"."field" quoted for the database backend."""
        return (backend.quote_name(table) + '.' + backend.quote_name(field))


    def get_key_on_this_table(self, key_field=None):
        """
        Return the quoted table.column reference for key_field (a column
        name) on this manager's model table. Defaults to the primary key
        column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """
        Double every '%' so user-supplied SQL survives the DB API's
        %-style parameter substitution.
        """
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """
        Run "SELECT [DISTINCT] <selects>" against the FROM/WHERE clause of
        query_set on the read-only connection.
        @param selects list of SQL select expressions
        @returns all result rows from cursor.fetchall()
        """
        # the query set's own select list is replaced by selects
        _, where, params = query_set._get_sql_clause()
        if query_set._distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''
        sql_query = 'SELECT ' + distinct + ','.join(selects) + where
        cursor = readonly_connection.connection().cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """True if field is a relation whose target is exactly model_class."""
        return field.rel and field.rel.to is model_class


    def _determine_pivot_table(self, related_model):
        """
        Determine the pivot table for this relationship and return a tuple
        (pivot_table, pivot_from_field, pivot_to_field). See
        _query_pivot_table() for more info.
        Note -- this depends on Django model internals and will likely need to
        be updated when we move to Django 1.x.
        @raises ValueError if related_model has no relation to this model
        """
        # look for a field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                # many-to-one -- the related table itself is the pivot table
                return (related_model._meta.db_table, field.column,
                        related_model.objects.get_key_on_this_table())

        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                # many-to-many
                return (field.m2m_db_table(), field.m2m_reverse_name(),
                        field.m2m_column_name())

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return (field.m2m_db_table(), field.m2m_column_name(),
                        field.m2m_reverse_name())

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _query_pivot_table(self, id_list, pivot_table, pivot_from_field,
                           pivot_to_field):
        """
        @param id_list list of IDs of self.model objects to include; must be
                non-empty, otherwise the generated "IN ()" is invalid SQL
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
        self.model
        @param pivot_to_field a field name on pivot_table referencing the
        related model.
        @returns a dict mapping each IDs from id_list to a list of IDs of
        related objects.
        """
        # IDs are interpolated directly; callers pass primary key values
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join([str(id_) for id_ in id_list]))
        cursor = readonly_connection.connection().cursor()
        cursor.execute(query)

        related_ids = {}
        for model_id, related_id in cursor.fetchall():
            related_ids.setdefault(model_id, []).append(related_id)
        return related_ids


    def populate_relationships(self, model_objects, related_model,
                               related_list_name):
        """
        For each instance in model_objects, add a field named related_list_name
        listing all the related objects of type related_model. related_model
        must be in a many-to-one or many-to-many relationship with this model.
        """
        if not model_objects:
            # if we don't bail early, we'll get a SQL error later
            return
        id_list = [item._get_pk_val() for item in model_objects]
        pivot_table, pivot_from_field, pivot_to_field = (
            self._determine_pivot_table(related_model))
        related_ids = self._query_pivot_table(id_list, pivot_table,
                                              pivot_from_field, pivot_to_field)

        all_related_ids = list(set(itertools.chain(*related_ids.itervalues())))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for item in model_objects:
            related_ids_for_item = related_ids.get(item._get_pk_val(), [])
            related_objects = [related_objects_by_id[related_id]
                               for related_id in related_ids_for_item]
            setattr(item, related_list_name, related_objects)
325
326
showard7c785282008-05-29 19:45:12 +0000327class ValidObjectsManager(ExtendedManager):
jadmanski0afbb632008-06-06 21:10:57 +0000328 """
329 Manager returning only objects with invalid=False.
330 """
331 def get_query_set(self):
332 queryset = super(ValidObjectsManager, self).get_query_set()
333 return queryset.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000334
335
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # lazily-built cache of field name -> Field object; see get_field_dict()
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """
        Return a dict mapping field names to Field objects for this model,
        built from cls._meta.fields on first use and cached on the class.
        NOTE(review): the cache check goes through normal attribute lookup,
        so a model subclassing another model whose field_dict is already
        populated would reuse the parent's dict.
        """
        if cls.field_dict is None:
            cls.field_dict = {}
            for field in cls._meta.fields:
                cls.field_dict[field.name] = field
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        -Convert foreign key fields in data from <field>_id to just
        <field>.
        -replace foreign key objects with their IDs
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data[field.attname]
                del data[field.attname]
            if field.name not in data:
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value._get_pk_val()


    @classmethod
    def _convert_booleans(cls, data):
        """
        Ensure BooleanFields actually get bool values. The Django MySQL
        backend returns ints for BooleanFields, which is almost always not
        a problem, but it can be annoying in certain situations.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            # exact type match on purpose: subclasses are left untouched
            if type(field) == dbmodels.BooleanField and field.name in data:
                data[field.name] = bool(data[field.name])


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Provide default values for fields with default values which have
        nothing passed in.

        For CharField and TextField fields with "blank=True", if nothing
        is passed, we fill in an empty string value, even if there's no
        default set.

        @returns a new dict; data itself is not modified
        """
        new_data = dict(data)
        field_dict = cls.get_field_dict()
        for name, obj in field_dict.iteritems():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                new_data[name] = obj.default
            elif isinstance(obj, (dbmodels.CharField, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary. This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data. This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict or data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(data[field_name],
                                                      valid_only=False)
                if to_human_readable:
                    # without a name_field the ID is left as-is
                    if dest_obj.name_field is not None:
                        data[field_name] = getattr(dest_obj,
                                                   dest_obj.name_field)
                else:
                    data[field_name] = dest_obj


    @classmethod
    def validate_field_names(cls, data):
        """
        Checks for extraneous fields in data.
        @returns dict mapping each unknown field name to an error string
        """
        errors = {}
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict:
                errors[field_name] = 'No field of this name'
        return errors


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """
        Common preparation for add_object and update_object: merge kwargs
        into a copy of data, reject unknown field names and convert
        human-readable values.
        @raises ValidationError if data contains unknown field names
        """
        data = dict(data) # don't modify the default keyword arg
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique. Django manipulators do
        this too, but they're a huge pain to use manually. Trust me.
        Only objects visible through get_valid_manager() are considered.
        @returns dict mapping field names to error strings
        """
        errors = {}
        cls = type(self)
        field_dict = self.get_field_dict()
        manager = cls.get_valid_manager()
        for field_name, field_obj in field_dict.iteritems():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            # the only match is this object itself
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value))
        return errors


    def do_validate(self):
        """
        Combine the errors from self.validate() (assumed to return a dict
        of field errors -- presumably Django's Model.validate()) with those
        from validate_unique(); errors from validate() take precedence.
        @raises ValidationError if any errors were found
        """
        errors = self.validate()
        unique_errors = self.validate_unique()
        for field_name, error in unique_errors.iteritems():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data={}, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values). Merges any extra keyword args
        into data.
        """
        # data is copied by prepare_data_args, so the shared default is safe
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data={}, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values). Merges any extra keyword args into
        data.
        """
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.iteritems():
            setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True, initial_query=None):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data. Optional special arguments in
        filter_data include:
        -query_start: index of first return to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on. prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append
        -no_distinct: if true, do not apply distinct() to the query
        The special arguments are popped from filter_data (it is modified);
        everything left is passed to filter().
        """
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        extra_args = filter_data.pop('extra_args', {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            # escape %'s
            extra_where = cls.objects.escape_user_sql(extra_where)
            extra_args.setdefault('where', []).append(extra_where)
        use_distinct = not filter_data.pop('no_distinct', False)

        if initial_query is None:
            if valid_only:
                initial_query = cls.get_valid_manager()
            else:
                initial_query = cls.objects
        query = initial_query.filter(**filter_data)
        if use_distinct:
            query = query.distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)
            # queries carrying user-supplied SQL run on the read-only
            # connection
            query = query._clone(klass=ReadonlyQuerySet)

        # sorting + paging
        assert isinstance(sort_by, (list, tuple))
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # convert (start, count) into a slice end index
            query_limit += query_start
        # None bounds in the slice mean "unbounded"
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data, initial_query=None):
        """\
        Like query_objects, but retrieve only the count of results.
        Paging arguments in filter_data are discarded.
        """
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        query = cls.query_objects(filter_data, initial_query=initial_query)
        return query.count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user. Each dict is modified in-place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls._convert_booleans(field_dict)
            cls.convert_human_readable_values(field_dict,
                                              to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data, initial_query=None, fields=None):
        """\
        Like query_objects, but return a list of dictionaries.
        @param fields optional list of field names to include in each dict
        """
        query = cls.query_objects(filter_data, initial_query=initial_query)
        field_dicts = [model_object.get_object_dict(fields)
                       for model_object in query]
        return field_dicts


    @classmethod
    def smart_get(cls, id_or_name, valid_only=True):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        @raises ValueError for any other argument type
        """
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects

        if isinstance(id_or_name, (int, long)):
            return manager.get(pk=id_or_name)
        if isinstance(id_or_name, basestring):
            return manager.get(**{cls.name_field : id_or_name})
        raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                      type(id_or_name)))


    @classmethod
    def smart_get_bulk(cls, id_or_name_list):
        """
        Call smart_get() on every ID or name in id_or_name_list.
        @returns the objects, in input order
        @raises cls.DoesNotExist naming every input that could not be found
        """
        invalid_inputs = []
        result_objects = []
        for id_or_name in id_or_name_list:
            try:
                result_objects.append(cls.smart_get(id_or_name))
            except cls.DoesNotExist:
                invalid_inputs.append(id_or_name)
        if invalid_inputs:
            # inputs may be integer IDs, so format each before joining
            invalid_strings = ['%s' % (item,) for item in invalid_inputs]
            raise cls.DoesNotExist('The following %ss do not exist: %s'
                                   % (cls.__name__.lower(),
                                      ', '.join(invalid_strings)))
        return result_objects


    def get_object_dict(self, fields=None):
        """\
        Return a dictionary mapping fields to this object's values.
        @param fields optional list of field names; defaults to all fields
        """
        if fields is None:
            fields = self.get_field_dict().iterkeys()
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in fields)
        self.clean_object_dicts([object_dict])
        self._postprocess_object_dict(object_dict)
        return object_dict


    def _postprocess_object_dict(self, object_dict):
        """For subclasses to override."""
        pass


    @classmethod
    def get_valid_manager(cls):
        """Manager used for "valid only" lookups; defaults to cls.objects."""
        return cls.objects


    def _record_attributes(self, attributes):
        """
        Snapshot the current values of the named attributes.
        See on_attribute_changed.
        """
        assert not isinstance(attributes, basestring)
        self._recorded_attributes = dict((attribute, getattr(self, attribute))
                                         for attribute in attributes)


    def _check_for_updated_attributes(self):
        """
        Call on_attribute_changed() for each recorded attribute whose value
        differs from the snapshot, then take a fresh snapshot.
        See on_attribute_changed.
        """
        for attribute, original_value in self._recorded_attributes.iteritems():
            new_value = getattr(self, attribute)
            if original_value != new_value:
                self.on_attribute_changed(attribute, original_value)
        self._record_attributes(self._recorded_attributes.keys())


    def on_attribute_changed(self, attribute, old_value):
        """
        Called whenever an attribute is updated. To be overridden.

        To use this method, you must:
        * call _record_attributes() from __init__() (after making the super
        call) with a list of attributes for which you want to be notified upon
        change.
        * call _check_for_updated_attributes() from save().
        """
        pass
730
731
class ModelWithInvalid(ModelExtensions):
    """
    Mixin that replaces real deletion with invalidation: delete() flags the
    object via its boolean "invalid" field instead of removing the row, and
    save() revives a previously invalidated object of the same name rather
    than inserting a duplicate. Subclasses must have a boolean "invalid"
    field and set name_field.
    """

    def save(self):
        if self.id is None:
            # New object: if an invalidated object with the same name
            # exists, take over its ID so save() updates that row.
            lookup = {self.name_field : getattr(self, self.name_field),
                      'invalid' : True}
            try:
                previous = self.__class__.objects.get(**lookup)
            except self.DoesNotExist:
                previous = None
            if previous is not None:
                self.id = previous.id

        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        Hook run by delete() after the object has been marked invalid.
        Subclasses override it to tear down relationships that would not
        exist had the object really been deleted.
        """
        pass


    def delete(self):
        """
        Mark this (currently valid) object invalid, save it, then run
        clean_object().
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Use the valid_objects manager for "valid only" lookups."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Point the default manipulators at valid_objects; otherwise their
        uniqueness checks would also match invalidated objects.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +0000785
786
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        @raises NotImplementedError if not overridden
        """
        # NotImplemented is a comparison sentinel, not an exception;
        # raising it would produce a TypeError instead
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create or update the named attribute and set its value."""
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        attribute_object, _ = attribute_model.objects.get_or_create(**get_args)
        attribute_object.value = value
        attribute_object.save()


    def delete_attribute(self, attribute):
        """Delete the named attribute; do nothing if it does not exist."""
        attribute_model, get_args = self._get_attribute_model_and_args(
            attribute)
        try:
            attribute_model.objects.get(**get_args).delete()
        except attribute_model.DoesNotExist:
            # catch the exception of the model actually queried, not of a
            # specific attribute model class
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute when value is None, otherwise set it."""
        if value is None:
            self.delete_attribute(attribute)
        else:
            self.set_attribute(attribute, value)