blob: 0c2bafc0298e9d63cc9798892e487426b1328259 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
Michael Liang8864e862014-07-22 08:36:05 -07006import time
showarda5288b42009-07-28 20:06:08 +00007import django.core.exceptions
showard7c785282008-05-29 19:45:12 +00008from django.db import models as dbmodels, backend, connection
showarda5288b42009-07-28 20:06:08 +00009from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +000010import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000011from django.utils import datastructures
Prashanth B489b91d2014-03-15 12:17:16 -070012from autotest_lib.frontend.afe import rdb_model_extensions
showard56e93772008-10-06 10:06:22 +000013from autotest_lib.frontend.afe import readonly_connection
showard7c785282008-05-29 19:45:12 +000014
Prashanth B489b91d2014-03-15 12:17:16 -070015
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Raised when data fails validation while adding or updating an object.

    The value associated with the exception is a dictionary that maps field
    names to error strings.
    """
showard7c785282008-05-29 19:45:12 +000021
22
def _wrap_with_readonly(method):
    """
    Return a wrapper that calls method with the read-only connection set.

    The read-only connection is installed as the Django connection right
    before the call and uninstalled right after it, even when method raises.
    The wrapper takes over method's __name__.
    """
    def run_readonly(*args, **kwargs):
        readonly_connection.connection().set_django_connection()
        try:
            result = method(*args, **kwargs)
        finally:
            readonly_connection.connection().unset_django_connection()
        return result

    run_readonly.__name__ = method.__name__
    return run_readonly
showard09096d82008-07-07 23:20:49 +000032
33
def _quote_name(name):
    """Return name as quoted by connection.ops.quote_name()."""
    quote = connection.ops.quote_name
    return quote(name)
37
38
def _wrap_generator_with_readonly(generator):
    """
    Wrap a generator function so its first item is fetched read-only.

    We have to wrap generators specially: nothing runs until the first call
    to next(), so that call is the one made with the read-only connection
    set.  We assume the generator performs its query on that first call; the
    remaining items are produced without touching the connection setting.

    @param generator: the generator function to wrap
    @returns a generator function yielding the same items, carrying the
            wrapped function's __name__
    """
    def wrapper_generator(*args, **kwargs):
        generator_obj = generator(*args, **kwargs)
        readonly_connection.connection().set_django_connection()
        try:
            try:
                first_value = next(generator_obj)
            except StopIteration:
                # Nothing to yield.  End quietly instead of letting
                # StopIteration escape the generator body, which PEP 479
                # turns into a RuntimeError.
                return
        finally:
            readonly_connection.connection().unset_django_connection()
        yield first_value

        # a for loop ends cleanly when the wrapped generator is exhausted
        for value in generator_obj:
            yield value

    wrapper_generator.__name__ = generator.__name__
    return wrapper_generator
58
59
def _make_queryset_readonly(queryset):
    """
    Make the query-executing methods of queryset use the read-only
    connection.

    Each listed method is replaced on the instance by a wrapped version, and
    iterator() is replaced by a wrapped generator.
    """
    for name in ('count', 'get', 'get_or_create', 'latest', 'in_bulk',
                 'delete'):
        setattr(queryset, name, _wrap_with_readonly(getattr(queryset, name)))

    queryset.iterator = _wrap_generator_with_readonly(queryset.iterator)
72
73
class ReadonlyQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet that runs its database queries through the read-only
    connection (see _make_queryset_readonly()).
    """
    def __init__(self, model=None, *args, **kwargs):
        super(ReadonlyQuerySet, self).__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)


    def values(self, *fields):
        """Clone this query set as a ReadonlyValuesQuerySet over fields."""
        clone_options = dict(klass=ReadonlyValuesQuerySet, setup=True,
                             _fields=fields)
        return self._clone(**clone_options)
showard09096d82008-07-07 23:20:49 +000087
88
class ReadonlyValuesQuerySet(dbmodels.query.ValuesQuerySet):
    """
    ValuesQuerySet counterpart of ReadonlyQuerySet; it is made read-only in
    the same way on construction.
    """
    def __init__(self, model=None, *args, **kwargs):
        parent = super(ReadonlyValuesQuerySet, self)
        parent.__init__(model, *args, **kwargs)
        _make_queryset_readonly(self)
93
94
class LeasedHostManager(dbmodels.Manager):
    """Query manager that only returns hosts with leased=0 and locked=0.
    """
    def get_query_set(self):
        all_hosts = super(LeasedHostManager, self).get_query_set()
        return all_hosts.filter(leased=0, locked=0)
101
102
showard7c785282008-05-29 19:45:12 +0000103class ExtendedManager(dbmodels.Manager):
jadmanski0afbb632008-06-06 21:10:57 +0000104 """\
105 Extended manager supporting subquery filtering.
106 """
showard7c785282008-05-29 19:45:12 +0000107
showardf828c772010-01-25 21:49:42 +0000108 class CustomQuery(query.Query):
showard7e67b432010-01-20 01:13:04 +0000109 def __init__(self, *args, **kwargs):
showardf828c772010-01-25 21:49:42 +0000110 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
showard7e67b432010-01-20 01:13:04 +0000111 self._custom_joins = []
112
113
showarda5288b42009-07-28 20:06:08 +0000114 def clone(self, klass=None, **kwargs):
showardf828c772010-01-25 21:49:42 +0000115 obj = super(ExtendedManager.CustomQuery, self).clone(klass)
showard7e67b432010-01-20 01:13:04 +0000116 obj._custom_joins = list(self._custom_joins)
showarda5288b42009-07-28 20:06:08 +0000117 return obj
showard08f981b2008-06-24 21:59:03 +0000118
showard7e67b432010-01-20 01:13:04 +0000119
120 def combine(self, rhs, connector):
showardf828c772010-01-25 21:49:42 +0000121 super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
showard7e67b432010-01-20 01:13:04 +0000122 if hasattr(rhs, '_custom_joins'):
123 self._custom_joins.extend(rhs._custom_joins)
124
125
126 def add_custom_join(self, table, condition, join_type,
127 condition_values=(), alias=None):
128 if alias is None:
129 alias = table
130 join_dict = dict(table=table,
131 condition=condition,
132 condition_values=condition_values,
133 join_type=join_type,
134 alias=alias)
135 self._custom_joins.append(join_dict)
136
137
showard7e67b432010-01-20 01:13:04 +0000138 @classmethod
139 def convert_query(self, query_set):
140 """
showardf828c772010-01-25 21:49:42 +0000141 Convert the query set's "query" attribute to a CustomQuery.
showard7e67b432010-01-20 01:13:04 +0000142 """
143 # Make a copy of the query set
144 query_set = query_set.all()
145 query_set.query = query_set.query.clone(
showardf828c772010-01-25 21:49:42 +0000146 klass=ExtendedManager.CustomQuery,
showard7e67b432010-01-20 01:13:04 +0000147 _custom_joins=[])
148 return query_set
showard43a3d262008-11-12 18:17:05 +0000149
150
showard7e67b432010-01-20 01:13:04 +0000151 class _WhereClause(object):
152 """Object allowing us to inject arbitrary SQL into Django queries.
showard43a3d262008-11-12 18:17:05 +0000153
showard7e67b432010-01-20 01:13:04 +0000154 By using this instead of extra(where=...), we can still freely combine
155 queries with & and |.
showarda5288b42009-07-28 20:06:08 +0000156 """
showard7e67b432010-01-20 01:13:04 +0000157 def __init__(self, clause, values=()):
158 self._clause = clause
159 self._values = values
showarda5288b42009-07-28 20:06:08 +0000160
showard7e67b432010-01-20 01:13:04 +0000161
Dale Curtis74a314b2011-06-23 14:55:46 -0700162 def as_sql(self, qn=None, connection=None):
showard7e67b432010-01-20 01:13:04 +0000163 return self._clause, self._values
164
165
166 def relabel_aliases(self, change_map):
167 return
showard43a3d262008-11-12 18:17:05 +0000168
169
showard8b0ea222009-12-23 19:23:03 +0000170 def add_join(self, query_set, join_table, join_key, join_condition='',
showard7e67b432010-01-20 01:13:04 +0000171 join_condition_values=(), join_from_key=None, alias=None,
172 suffix='', exclude=False, force_left_join=False):
173 """Add a join to query_set.
174
175 Join looks like this:
176 (INNER|LEFT) JOIN <join_table> AS <alias>
177 ON (<this table>.<join_from_key> = <join_table>.<join_key>
178 and <join_condition>)
179
showard0957a842009-05-11 19:25:08 +0000180 @param join_table table to join to
181 @param join_key field referencing back to this model to use for the join
182 @param join_condition extra condition for the ON clause of the join
showard7e67b432010-01-20 01:13:04 +0000183 @param join_condition_values values to substitute into join_condition
184 @param join_from_key column on this model to join from.
showard8b0ea222009-12-23 19:23:03 +0000185 @param alias alias to use for for join
186 @param suffix suffix to add to join_table for the join alias, if no
187 alias is provided
showard0957a842009-05-11 19:25:08 +0000188 @param exclude if true, exclude rows that match this join (will use a
showarda5288b42009-07-28 20:06:08 +0000189 LEFT OUTER JOIN and an appropriate WHERE condition)
showardc4780402009-08-31 18:31:34 +0000190 @param force_left_join - if true, a LEFT OUTER JOIN will be used
191 instead of an INNER JOIN regardless of other options
showard0957a842009-05-11 19:25:08 +0000192 """
showard7e67b432010-01-20 01:13:04 +0000193 join_from_table = query_set.model._meta.db_table
194 if join_from_key is None:
195 join_from_key = self.model._meta.pk.name
196 if alias is None:
197 alias = join_table + suffix
198 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
199 full_join_condition = '%s = %s.%s' % (full_join_key,
200 _quote_name(join_from_table),
201 _quote_name(join_from_key))
showard43a3d262008-11-12 18:17:05 +0000202 if join_condition:
203 full_join_condition += ' AND (' + join_condition + ')'
204 if exclude or force_left_join:
showarda5288b42009-07-28 20:06:08 +0000205 join_type = query_set.query.LOUTER
showard43a3d262008-11-12 18:17:05 +0000206 else:
showarda5288b42009-07-28 20:06:08 +0000207 join_type = query_set.query.INNER
showard43a3d262008-11-12 18:17:05 +0000208
showardf828c772010-01-25 21:49:42 +0000209 query_set = self.CustomQuery.convert_query(query_set)
showard7e67b432010-01-20 01:13:04 +0000210 query_set.query.add_custom_join(join_table,
211 full_join_condition,
212 join_type,
213 condition_values=join_condition_values,
214 alias=alias)
showard43a3d262008-11-12 18:17:05 +0000215
showard7e67b432010-01-20 01:13:04 +0000216 if exclude:
217 query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
218
219 return query_set
220
221
222 def _info_for_many_to_one_join(self, field, join_to_query, alias):
223 """
224 @param field: the ForeignKey field on the related model
225 @param join_to_query: the query over the related model that we're
226 joining to
227 @param alias: alias of joined table
228 """
229 info = {}
230 rhs_table = join_to_query.model._meta.db_table
231 info['rhs_table'] = rhs_table
232 info['rhs_column'] = field.column
233 info['lhs_column'] = field.rel.get_related_field().column
234 rhs_where = join_to_query.query.where
235 rhs_where.relabel_aliases({rhs_table: alias})
Dale Curtis74a314b2011-06-23 14:55:46 -0700236 compiler = join_to_query.query.get_compiler(using=join_to_query.db)
237 initial_clause, values = compiler.as_sql()
238 all_clauses = (initial_clause,)
239 if hasattr(join_to_query.query, 'extra_where'):
240 all_clauses += join_to_query.query.extra_where
241 info['where_clause'] = (
242 ' AND '.join('(%s)' % clause for clause in all_clauses))
showard7e67b432010-01-20 01:13:04 +0000243 info['values'] = values
244 return info
245
246
247 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
248 m2m_is_on_this_model):
249 """
250 @param m2m_field: a Django field representing the M2M relationship.
251 It uses a pivot table with the following structure:
252 this model table <---> M2M pivot table <---> joined model table
253 @param join_to_query: the query over the related model that we're
254 joining to.
255 @param alias: alias of joined table
256 """
257 if m2m_is_on_this_model:
258 # referenced field on this model
259 lhs_id_field = self.model._meta.pk
260 # foreign key on the pivot table referencing lhs_id_field
261 m2m_lhs_column = m2m_field.m2m_column_name()
262 # foreign key on the pivot table referencing rhd_id_field
263 m2m_rhs_column = m2m_field.m2m_reverse_name()
264 # referenced field on related model
265 rhs_id_field = m2m_field.rel.get_related_field()
266 else:
267 lhs_id_field = m2m_field.rel.get_related_field()
268 m2m_lhs_column = m2m_field.m2m_reverse_name()
269 m2m_rhs_column = m2m_field.m2m_column_name()
270 rhs_id_field = join_to_query.model._meta.pk
271
272 info = {}
273 info['rhs_table'] = m2m_field.m2m_db_table()
274 info['rhs_column'] = m2m_lhs_column
275 info['lhs_column'] = lhs_id_field.column
276
277 # select the ID of related models relevant to this join. we can only do
278 # a single join, so we need to gather this information up front and
279 # include it in the join condition.
280 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
281 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
282 'match a single related object.')
283 rhs_id = rhs_ids[0]
284
285 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
286 _quote_name(m2m_rhs_column),
287 rhs_id)
288 info['values'] = ()
289 return info
290
291
292 def join_custom_field(self, query_set, join_to_query, alias,
293 left_join=True):
294 """Join to a related model to create a custom field in the given query.
295
296 This method is used to construct a custom field on the given query based
297 on a many-valued relationsip. join_to_query should be a simple query
298 (no joins) on the related model which returns at most one related row
299 per instance of this model.
300
301 For many-to-one relationships, the joined table contains the matching
302 row from the related model it one is related, NULL otherwise.
303
304 For many-to-many relationships, the joined table contains the matching
305 row if it's related, NULL otherwise.
306 """
307 relationship_type, field = self.determine_relationship(
308 join_to_query.model)
309
310 if relationship_type == self.MANY_TO_ONE:
311 info = self._info_for_many_to_one_join(field, join_to_query, alias)
312 elif relationship_type == self.M2M_ON_RELATED_MODEL:
313 info = self._info_for_many_to_many_join(
314 m2m_field=field, join_to_query=join_to_query, alias=alias,
315 m2m_is_on_this_model=False)
316 elif relationship_type ==self.M2M_ON_THIS_MODEL:
317 info = self._info_for_many_to_many_join(
318 m2m_field=field, join_to_query=join_to_query, alias=alias,
319 m2m_is_on_this_model=True)
320
321 return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
322 join_from_key=info['lhs_column'],
323 join_condition=info['where_clause'],
324 join_condition_values=info['values'],
325 alias=alias,
326 force_left_join=left_join)
327
328
showardf828c772010-01-25 21:49:42 +0000329 def key_on_joined_table(self, join_to_query):
330 """Get a non-null column on the table joined for the given query.
331
332 This analyzes the join that would be produced if join_to_query were
333 passed to join_custom_field.
334 """
335 relationship_type, field = self.determine_relationship(
336 join_to_query.model)
337 if relationship_type == self.MANY_TO_ONE:
338 return join_to_query.model._meta.pk.column
339 return field.m2m_column_name() # any column on the M2M table will do
340
341
showard7e67b432010-01-20 01:13:04 +0000342 def add_where(self, query_set, where, values=()):
343 query_set = query_set.all()
344 query_set.query.where.add(self._WhereClause(where, values),
345 django.db.models.sql.where.AND)
showardc4780402009-08-31 18:31:34 +0000346 return query_set
showard7c785282008-05-29 19:45:12 +0000347
348
showardeaccf8f2009-04-16 03:11:33 +0000349 def _get_quoted_field(self, table, field):
showarda5288b42009-07-28 20:06:08 +0000350 return _quote_name(table) + '.' + _quote_name(field)
showard5ef36e92008-07-02 16:37:09 +0000351
352
showard7c199df2008-10-03 10:17:15 +0000353 def get_key_on_this_table(self, key_field=None):
showard5ef36e92008-07-02 16:37:09 +0000354 if key_field is None:
355 # default to primary key
356 key_field = self.model._meta.pk.column
357 return self._get_quoted_field(self.model._meta.db_table, key_field)
358
359
showardeaccf8f2009-04-16 03:11:33 +0000360 def escape_user_sql(self, sql):
361 return sql.replace('%', '%%')
362
showard5ef36e92008-07-02 16:37:09 +0000363
showard0957a842009-05-11 19:25:08 +0000364 def _custom_select_query(self, query_set, selects):
Dale Curtis74a314b2011-06-23 14:55:46 -0700365 compiler = query_set.query.get_compiler(using=query_set.db)
366 sql, params = compiler.as_sql()
showarda5288b42009-07-28 20:06:08 +0000367 from_ = sql[sql.find(' FROM'):]
368
369 if query_set.query.distinct:
showard0957a842009-05-11 19:25:08 +0000370 distinct = 'DISTINCT '
371 else:
372 distinct = ''
showarda5288b42009-07-28 20:06:08 +0000373
374 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
showard0957a842009-05-11 19:25:08 +0000375 cursor = readonly_connection.connection().cursor()
376 cursor.execute(sql_query, params)
377 return cursor.fetchall()
378
379
showard68693f72009-05-20 00:31:53 +0000380 def _is_relation_to(self, field, model_class):
381 return field.rel and field.rel.to is model_class
showard0957a842009-05-11 19:25:08 +0000382
383
showard7e67b432010-01-20 01:13:04 +0000384 MANY_TO_ONE = object()
385 M2M_ON_RELATED_MODEL = object()
386 M2M_ON_THIS_MODEL = object()
387
388 def determine_relationship(self, related_model):
389 """
390 Determine the relationship between this model and related_model.
391
392 related_model must have some sort of many-valued relationship to this
393 manager's model.
394 @returns (relationship_type, field), where relationship_type is one of
395 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
396 is the Django field object for the relationship.
397 """
398 # look for a foreign key field on related_model relating to this model
399 for field in related_model._meta.fields:
400 if self._is_relation_to(field, self.model):
401 return self.MANY_TO_ONE, field
402
403 # look for an M2M field on related_model relating to this model
404 for field in related_model._meta.many_to_many:
405 if self._is_relation_to(field, self.model):
406 return self.M2M_ON_RELATED_MODEL, field
407
408 # maybe this model has the many-to-many field
409 for field in self.model._meta.many_to_many:
410 if self._is_relation_to(field, related_model):
411 return self.M2M_ON_THIS_MODEL, field
412
413 raise ValueError('%s has no relation to %s' %
414 (related_model, self.model))
415
416
showard68693f72009-05-20 00:31:53 +0000417 def _get_pivot_iterator(self, base_objects_by_id, related_model):
showard0957a842009-05-11 19:25:08 +0000418 """
showard68693f72009-05-20 00:31:53 +0000419 Determine the relationship between this model and related_model, and
420 return a pivot iterator.
421 @param base_objects_by_id: dict of instances of this model indexed by
422 their IDs
423 @returns a pivot iterator, which yields a tuple (base_object,
424 related_object) for each relationship between a base object and a
425 related object. all base_object instances come from base_objects_by_id.
showard7e67b432010-01-20 01:13:04 +0000426 Note -- this depends on Django model internals.
showard0957a842009-05-11 19:25:08 +0000427 """
showard7e67b432010-01-20 01:13:04 +0000428 relationship_type, field = self.determine_relationship(related_model)
429 if relationship_type == self.MANY_TO_ONE:
430 return self._many_to_one_pivot(base_objects_by_id,
431 related_model, field)
432 elif relationship_type == self.M2M_ON_RELATED_MODEL:
433 return self._many_to_many_pivot(
showard68693f72009-05-20 00:31:53 +0000434 base_objects_by_id, related_model, field.m2m_db_table(),
435 field.m2m_reverse_name(), field.m2m_column_name())
showard7e67b432010-01-20 01:13:04 +0000436 else:
437 assert relationship_type == self.M2M_ON_THIS_MODEL
438 return self._many_to_many_pivot(
showard68693f72009-05-20 00:31:53 +0000439 base_objects_by_id, related_model, field.m2m_db_table(),
440 field.m2m_column_name(), field.m2m_reverse_name())
showard0957a842009-05-11 19:25:08 +0000441
showard0957a842009-05-11 19:25:08 +0000442
showard68693f72009-05-20 00:31:53 +0000443 def _many_to_one_pivot(self, base_objects_by_id, related_model,
444 foreign_key_field):
445 """
446 @returns a pivot iterator - see _get_pivot_iterator()
447 """
448 filter_data = {foreign_key_field.name + '__pk__in':
449 base_objects_by_id.keys()}
450 for related_object in related_model.objects.filter(**filter_data):
showarda5a72c92009-08-20 23:35:21 +0000451 # lookup base object in the dict, rather than grabbing it from the
452 # related object. we need to return instances from the dict, not
453 # fresh instances of the same models (and grabbing model instances
454 # from the related models incurs a DB query each time).
455 base_object_id = getattr(related_object, foreign_key_field.attname)
456 base_object = base_objects_by_id[base_object_id]
showard68693f72009-05-20 00:31:53 +0000457 yield base_object, related_object
458
459
460 def _query_pivot_table(self, base_objects_by_id, pivot_table,
461 pivot_from_field, pivot_to_field):
showard0957a842009-05-11 19:25:08 +0000462 """
463 @param id_list list of IDs of self.model objects to include
464 @param pivot_table the name of the pivot table
465 @param pivot_from_field a field name on pivot_table referencing
466 self.model
467 @param pivot_to_field a field name on pivot_table referencing the
468 related model.
showard68693f72009-05-20 00:31:53 +0000469 @returns pivot list of IDs (base_id, related_id)
showard0957a842009-05-11 19:25:08 +0000470 """
471 query = """
472 SELECT %(from_field)s, %(to_field)s
473 FROM %(table)s
474 WHERE %(from_field)s IN (%(id_list)s)
475 """ % dict(from_field=pivot_from_field,
476 to_field=pivot_to_field,
477 table=pivot_table,
showard68693f72009-05-20 00:31:53 +0000478 id_list=','.join(str(id_) for id_
479 in base_objects_by_id.iterkeys()))
showard0957a842009-05-11 19:25:08 +0000480 cursor = readonly_connection.connection().cursor()
481 cursor.execute(query)
showard68693f72009-05-20 00:31:53 +0000482 return cursor.fetchall()
showard0957a842009-05-11 19:25:08 +0000483
484
showard68693f72009-05-20 00:31:53 +0000485 def _many_to_many_pivot(self, base_objects_by_id, related_model,
486 pivot_table, pivot_from_field, pivot_to_field):
487 """
488 @param pivot_table: see _query_pivot_table
489 @param pivot_from_field: see _query_pivot_table
490 @param pivot_to_field: see _query_pivot_table
491 @returns a pivot iterator - see _get_pivot_iterator()
492 """
493 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
494 pivot_from_field, pivot_to_field)
495
496 all_related_ids = list(set(related_id for base_id, related_id
497 in id_pivot))
498 related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
499
500 for base_id, related_id in id_pivot:
501 yield base_objects_by_id[base_id], related_objects_by_id[related_id]
502
503
504 def populate_relationships(self, base_objects, related_model,
showard0957a842009-05-11 19:25:08 +0000505 related_list_name):
506 """
showard68693f72009-05-20 00:31:53 +0000507 For each instance of this model in base_objects, add a field named
508 related_list_name listing all the related objects of type related_model.
509 related_model must be in a many-to-one or many-to-many relationship with
510 this model.
511 @param base_objects - list of instances of this model
512 @param related_model - model class related to this model
513 @param related_list_name - attribute name in which to store the related
514 object list.
showard0957a842009-05-11 19:25:08 +0000515 """
showard68693f72009-05-20 00:31:53 +0000516 if not base_objects:
showard0957a842009-05-11 19:25:08 +0000517 # if we don't bail early, we'll get a SQL error later
518 return
showard0957a842009-05-11 19:25:08 +0000519
showard68693f72009-05-20 00:31:53 +0000520 base_objects_by_id = dict((base_object._get_pk_val(), base_object)
521 for base_object in base_objects)
522 pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
523 related_model)
showard0957a842009-05-11 19:25:08 +0000524
showard68693f72009-05-20 00:31:53 +0000525 for base_object in base_objects:
526 setattr(base_object, related_list_name, [])
527
528 for base_object, related_object in pivot_iterator:
529 getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000530
531
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet for models with an "invalid" bit; delete() is carried out by
    calling delete() on every object in the query set.
    """
    def delete(self):
        for instance in self:
            instance.delete()
539
540
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit; its query sets are
    ModelWithInvalidQuerySet instances.
    """
    def get_query_set(self):
        query_set = ModelWithInvalidQuerySet(self.model)
        return query_set
547
548
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager that leaves out invalid objects: only those with invalid=False
    are returned.
    """
    def get_query_set(self):
        everything = super(ValidObjectsManager, self).get_query_set()
        return everything.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000556
557
Prashanth B489b91d2014-03-15 12:17:16 -0700558class ModelExtensions(rdb_model_extensions.ModelValidators):
jadmanski0afbb632008-06-06 21:10:57 +0000559 """\
Prashanth B489b91d2014-03-15 12:17:16 -0700560 Mixin with convenience functions for models, built on top of
561 the model validators in rdb_model_extensions.
jadmanski0afbb632008-06-06 21:10:57 +0000562 """
563 # TODO: at least some of these functions really belong in a custom
564 # Manager class
showard7c785282008-05-29 19:45:12 +0000565
Jakob Juelich3bb7c802014-09-02 16:31:11 -0700566
567 SERIALIZATION_LINKS_TO_FOLLOW = set()
568 """
569 To be able to send jobs and hosts to shards, it's necessary to find their
570 dependencies.
571 The most generic approach for this would be to traverse all relationships
572 to other objects recursively. This would list all objects that are related
573 in any way.
574 But this approach finds too many objects: If a host should be transferred,
    all its relationships would be traversed. This would find an acl group.
576 If then the acl group's relationships are traversed, the relationship
577 would be followed backwards and many other hosts would be found.
578
579 This mapping tells that algorithm which relations to follow explicitly.
580 """
581
Jakob Juelichf865d332014-09-29 10:47:49 -0700582
583 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
584 """
    On deserialization, if the object to persist already exists, local fields
    will only be updated if their name is in this set.
587 """
588
589
jadmanski0afbb632008-06-06 21:10:57 +0000590 @classmethod
591 def convert_human_readable_values(cls, data, to_human_readable=False):
592 """\
593 Performs conversions on user-supplied field data, to make it
594 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000595
jadmanski0afbb632008-06-06 21:10:57 +0000596 For all fields that have choice sets, convert their values
597 from human-readable strings to enum values, if necessary. This
598 allows users to pass strings instead of the corresponding
599 integer values.
showard7c785282008-05-29 19:45:12 +0000600
jadmanski0afbb632008-06-06 21:10:57 +0000601 For all foreign key fields, call smart_get with the supplied
602 data. This allows the user to pass either an ID value or
603 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000604
jadmanski0afbb632008-06-06 21:10:57 +0000605 If to_human_readable=True, perform the inverse - i.e. convert
606 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000607
jadmanski0afbb632008-06-06 21:10:57 +0000608 This method modifies data in-place.
609 """
610 field_dict = cls.get_field_dict()
611 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000612 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000613 continue
614 field_obj = field_dict[field_name]
615 # convert enum values
616 if field_obj.choices:
617 for choice_data in field_obj.choices:
618 # choice_data is (value, name)
619 if to_human_readable:
620 from_val, to_val = choice_data
621 else:
622 to_val, from_val = choice_data
623 if from_val == data[field_name]:
624 data[field_name] = to_val
625 break
626 # convert foreign key values
627 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000628 dest_obj = field_obj.rel.to.smart_get(data[field_name],
629 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000630 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800631 # parameterized_jobs do not have a name_field
632 if (field_name != 'parameterized_job' and
633 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000634 data[field_name] = getattr(dest_obj,
635 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000636 else:
showardb0a73032009-03-27 18:35:41 +0000637 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000638
639
showard7c785282008-05-29 19:45:12 +0000640
641
Dale Curtis74a314b2011-06-23 14:55:46 -0700642 def _validate_unique(self):
jadmanski0afbb632008-06-06 21:10:57 +0000643 """\
644 Validate that unique fields are unique. Django manipulators do
645 this too, but they're a huge pain to use manually. Trust me.
646 """
647 errors = {}
648 cls = type(self)
649 field_dict = self.get_field_dict()
650 manager = cls.get_valid_manager()
651 for field_name, field_obj in field_dict.iteritems():
652 if not field_obj.unique:
653 continue
showard7c785282008-05-29 19:45:12 +0000654
jadmanski0afbb632008-06-06 21:10:57 +0000655 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000656 if value is None and field_obj.auto_created:
657 # don't bother checking autoincrement fields about to be
658 # generated
659 continue
660
jadmanski0afbb632008-06-06 21:10:57 +0000661 existing_objs = manager.filter(**{field_name : value})
662 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000663
jadmanski0afbb632008-06-06 21:10:57 +0000664 if num_existing == 0:
665 continue
666 if num_existing == 1 and existing_objs[0].id == self.id:
667 continue
668 errors[field_name] = (
669 'This value must be unique (%s)' % (value))
670 return errors
showard7c785282008-05-29 19:45:12 +0000671
672
showarda5288b42009-07-28 20:06:08 +0000673 def _validate(self):
674 """
675 First coerces all fields on this instance to their proper Python types.
676 Then runs validation on every field. Returns a dictionary of
677 field_name -> error_list.
678
679 Based on validate() from django.db.models.Model in Django 0.96, which
680 was removed in Django 1.0. It should reappear in a later version. See:
681 http://code.djangoproject.com/ticket/6845
682 """
683 error_dict = {}
684 for f in self._meta.fields:
685 try:
686 python_value = f.to_python(
687 getattr(self, f.attname, f.get_default()))
688 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000689 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000690 continue
691
692 if not f.blank and not python_value:
693 error_dict[f.name] = 'This field is required.'
694 continue
695
696 setattr(self, f.attname, python_value)
697
698 return error_dict
699
700
jadmanski0afbb632008-06-06 21:10:57 +0000701 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000702 errors = self._validate()
Dale Curtis74a314b2011-06-23 14:55:46 -0700703 unique_errors = self._validate_unique()
jadmanski0afbb632008-06-06 21:10:57 +0000704 for field_name, error in unique_errors.iteritems():
705 errors.setdefault(field_name, error)
706 if errors:
707 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000708
709
jadmanski0afbb632008-06-06 21:10:57 +0000710 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000711
@classmethod
def add_object(cls, data={}, **kwargs):
    """
    Create, validate and save a new object.

    @param data: dictionary mapping field names to values. It is copied,
            never modified.
    @param kwargs: extra field values, merged over data.

    @returns the newly saved object.
    """
    merged = dict(data, **kwargs)
    merged = cls.prepare_data_args(merged)
    cls.convert_human_readable_values(merged)
    merged = cls.provide_default_values(merged)

    new_object = cls(**merged)
    new_object.do_validate()
    new_object.save()
    return new_object
showard7c785282008-05-29 19:45:12 +0000729
730
def update_object(self, data={}, **kwargs):
    """
    Set the given fields on this object, then validate and save it.

    @param data: dictionary mapping field names to values. It is copied,
            never modified.
    @param kwargs: extra field values, merged over data.
    """
    merged = dict(data, **kwargs)
    merged = self.prepare_data_args(merged)
    self.convert_human_readable_values(merged)
    for name, new_value in merged.iteritems():
        setattr(self, name, new_value)
    self.do_validate()
    self.save()
showard7c785282008-05-29 19:45:12 +0000745
746
# Keys of filter_data that _extract_special_params() separates from the
# regular Django filters; their meaning is described in query_objects().
_SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
                        'extra_args', 'extra_where', 'no_distinct')
750
751
jadmanski0afbb632008-06-06 21:10:57 +0000752 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000753 def _extract_special_params(cls, filter_data):
754 """
755 @returns a tuple of dicts (special_params, regular_filters), where
756 special_params contains the parameters we handle specially and
757 regular_filters is the remaining data to be handled by Django.
758 """
759 regular_filters = dict(filter_data)
760 special_params = {}
761 for key in cls._SPECIAL_FILTER_KEYS:
762 if key in regular_filters:
763 special_params[key] = regular_filters.pop(key)
764 return special_params, regular_filters
765
766
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply presentation parameters -- sorting and paging -- to a query.

    @param query: query to order and slice.
    @param filter_data: filter dictionary; only its 'sort_by',
            'query_start' and 'query_limit' entries are used here.

    @returns new query with presentation applied.
    @raises ValueError if query_start is given without query_limit.
    """
    presentation = cls._extract_special_params(filter_data)[0]

    ordering = presentation.get('sort_by')
    if ordering:
        assert isinstance(ordering, (list, tuple))
        query = query.extra(order_by=ordering)

    first = presentation.get('query_start')
    page_size = presentation.get('query_limit')
    if first is None:
        return query[:page_size]
    if page_size is None:
        raise ValueError('Cannot pass query_start without query_limit')
    # query_limit is a page size, so the slice ends at start + size
    return query[first:page_size + first]
showard8bfb5cb2009-10-07 20:49:15 +0000788
789
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """
    Returns a QuerySet object for querying this model with the given
    filter_data. Optional special arguments in filter_data include:
    -query_start: index of first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on.  prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    Neither filter_data nor the extra_args dict inside it is modified.

    @param filter_data: dictionary of Django filters plus the special
            arguments above.
    @param valid_only: when no initial_query is given, start from
            get_valid_manager() if True, from cls.objects otherwise.
    @param initial_query: query or manager to start filtering from.
    @param apply_presentation: if True, sorting and paging are applied
            through cls.apply_presentation().

    @returns the resulting query.
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # Copy extra_args: it is the very dict object the caller passed in
    # filter_data, and appending to it in place would make the extra
    # WHERE clause pile up across repeated calls with the same filters.
    extra_args = dict(special_params.get('extra_args') or {})
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = (list(extra_args.get('where') or [])
                               + [extra_where])
    if extra_args:
        query = query.extra(**extra_args)
        # NOTE(review): presumably caller-supplied SQL is what warrants
        # the read-only queryset here -- confirm.
        query = query._clone(klass=ReadonlyQuerySet)

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000835
836
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """
    Like query_objects, but retrieve only the count of results.

    Paging arguments (query_start, query_limit) are ignored so that the
    whole result set is counted. filter_data is not modified.

    @param filter_data: filter dictionary, as for query_objects().
    @param initial_query: passed through to query_objects().

    @returns the number of matching objects.
    """
    # Work on a copy: popping from the caller's dict would silently
    # strip the paging keys from any later query using the same filters.
    unpaged_filters = dict(filter_data)
    unpaged_filters.pop('query_start', None)
    unpaged_filters.pop('query_limit', None)
    query = cls.query_objects(unpaged_filters, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000846
847
@classmethod
def clean_object_dicts(cls, field_dicts):
    """
    Clean dicts describing objects (as returned by query.values()) so
    that they are more suitable for returning to the user.

    Each dict is passed, in this order, to clean_foreign_keys(),
    _convert_booleans() and convert_human_readable_values().

    @param field_dicts: list of dicts to clean.
    """
    for record in field_dicts:
        cls.clean_foreign_keys(record)
        cls._convert_booleans(record)
        cls.convert_human_readable_values(record, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000860
861
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """
    Like query_objects, but return a list of dictionaries.

    Each dictionary comes from get_object_dict() and also includes the
    query's extra select columns.

    @param filter_data: filter dictionary, as for query_objects().
    @param initial_query: passed through to query_objects().
    """
    matches = cls.query_objects(filter_data, initial_query=initial_query)
    extra_columns = matches.query.extra_select.keys()
    object_dicts = []
    for match in matches:
        object_dicts.append(
                match.get_object_dict(extra_fields=extra_columns))
    return object_dicts
showard7c785282008-05-29 19:45:12 +0000872
873
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """
    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field

    @param id_or_name: primary key (int/long), or a string looked up via
            cls.name_field when the class defines one.
    @param valid_only: look the object up through get_valid_manager() if
            True, through cls.objects otherwise.

    @raises ValueError if id_or_name is neither an integer nor a usable
            string.
    """
    manager = cls.get_valid_manager() if valid_only else cls.objects

    if isinstance(id_or_name, (int, long)):
        return manager.get(pk=id_or_name)

    by_name = (isinstance(id_or_name, basestring)
               and hasattr(cls, 'name_field'))
    if by_name:
        return manager.get(**{cls.name_field: id_or_name})

    raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                      type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000892
893
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """
    Look up several objects with smart_get().

    All inputs are tried before failing, so the error names every input
    that could not be found.

    @param id_or_name_list: iterable of ids and/or names, each accepted
            by smart_get().

    @returns list of the objects found, in input order.
    @raises cls.DoesNotExist listing all inputs that do not exist.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer ids; str.join() only accepts strings, so
        # format each one first or the join itself raises TypeError.
        invalid_names = ', '.join(['%s' % (item,)
                                   for item in invalid_inputs])
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), invalid_names))
    return result_objects
908
909
def get_object_dict(self, extra_fields=None):
    """
    Return a dictionary mapping fields to this object's values.

    The dictionary is run through clean_object_dicts() and then
    _postprocess_object_dict() before being returned.

    @param extra_fields: list of extra attribute names to include, in
            addition to the fields defined on this object.
    """
    names = list(self.get_field_dict().keys())
    if extra_fields:
        names.extend(extra_fields)

    values = {}
    for name in names:
        values[name] = getattr(self, name)

    self.clean_object_dicts([values])
    self._postprocess_object_dict(values)
    return values
showard7c785282008-05-29 19:45:12 +0000924
925
showardd3dc1992009-04-22 21:01:40 +0000926 def _postprocess_object_dict(self, object_dict):
927 """For subclasses to override."""
928 pass
929
930
@classmethod
def get_valid_manager(cls):
    """
    Return the manager to use when only valid objects are wanted.

    This implementation returns the default manager, cls.objects.
    """
    default_manager = cls.objects
    return default_manager
showard7c785282008-05-29 19:45:12 +0000934
935
def _record_attributes(self, attributes):
    """
    Remember the current values of the given attributes, so that
    _check_for_updated_attributes() can detect changes later.
    See on_attribute_changed.

    @param attributes: iterable of attribute names; must not be a single
            string.
    """
    assert not isinstance(attributes, basestring)
    snapshot = {}
    for name in attributes:
        snapshot[name] = getattr(self, name)
    self._recorded_attributes = snapshot
943
944
def _check_for_updated_attributes(self):
    """
    Call on_attribute_changed() for every recorded attribute whose value
    differs from the recorded one, then record the current values again.
    See on_attribute_changed.
    """
    recorded = self._recorded_attributes
    for name, previous in recorded.iteritems():
        if previous != getattr(self, name):
            self.on_attribute_changed(name, previous)
    self._record_attributes(self._recorded_attributes.keys())
954
955
def on_attribute_changed(self, attribute, old_value):
    """
    Called whenever a recorded attribute is updated.  To be overridden;
    this implementation does nothing.

    To use this method, you must:
    * call _record_attributes() from __init__() (after making the super
      call) with a list of attributes for which you want to be notified
      upon change.
    * call _check_for_updated_attributes() from save().

    @param attribute: name of the attribute that changed.
    @param old_value: value the attribute had when it was recorded.
    """
967
968
def serialize(self, include_dependencies=True):
    """Serializes the object, optionally together with its dependencies.

    Local fields that are not relations are always included. The class
    attribute SERIALIZATION_LINKS_TO_FOLLOW names the relations that are
    serialized along with the object.

    @param include_dependencies: Whether or not to follow relations to
                                 objects this object depends on.
                                 This parameter is used when uploading
                                 jobs from a shard to the master, as the
                                 master already has all the dependent
                                 objects.

    @returns: Dictionary representation of the object.
    """
    local_fields = self._meta.concrete_model._meta.local_fields
    result = dict((field.name, field._get_val_from_obj(self))
                  for field in local_fields
                  if field.rel is None)

    if not include_dependencies:
        return result

    for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
        result[link] = self._serialize_relation(link)
    return result
994
995
996 def _serialize_relation(self, link):
997 """Serializes dependent objects given the name of the relation.
998
999 @param link: Name of the relation to take objects from.
1000
1001 @returns For To-Many relationships a list of the serialized related
1002 objects, for To-One relationships the serialized related object.
1003 """
1004 try:
1005 attr = getattr(self, link)
1006 except AttributeError:
1007 # One-To-One relationships that point to None may raise this
1008 return None
1009
1010 if attr is None:
1011 return None
1012 if hasattr(attr, 'all'):
1013 return [obj.serialize() for obj in attr.all()]
1014 return attr.serialize()
1015
1016
Jakob Juelichf88fa932014-09-03 17:58:04 -07001017 @classmethod
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001018 def _split_local_from_foreign_values(cls, data):
1019 """This splits local from foreign values in a serialized object.
1020
1021 @param data: The serialized object.
1022
1023 @returns A tuple of two lists, both containing tuples in the form
1024 (link_name, link_value). The first list contains all links
1025 for local fields, the second one contains those for foreign
1026 fields/objects.
1027 """
1028 links_to_local_values, links_to_related_values = [], []
1029 for link, value in data.iteritems():
1030 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
1031 # It's a foreign key
1032 links_to_related_values.append((link, value))
1033 else:
1034 # It's a local attribute
1035 links_to_local_values.append((link, value))
1036 return links_to_local_values, links_to_related_values
1037
1038
Jakob Juelichf865d332014-09-29 10:47:49 -07001039 @classmethod
1040 def _filter_update_allowed_fields(cls, data):
1041 """Filters data and returns only files that updates are allowed on.
1042
1043 This is i.e. needed for syncing aborted bits from the master to shards.
1044
1045 Local links are only allowed to be updated, if they are in
1046 SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1047 Overwriting existing values is allowed in order to be able to sync i.e.
1048 the aborted bit from the master to a shard.
1049
1050 The whitelisting mechanism is in place to prevent overwriting local
1051 status: If all fields were overwritten, jobs would be completely be
1052 set back to their original (unstarted) state.
1053
1054 @param data: List with tuples of the form (link_name, link_value), as
1055 returned by _split_local_from_foreign_values.
1056
1057 @returns List of the same format as data, but only containing data for
1058 fields that updates are allowed on.
1059 """
1060 return [pair for pair in data
1061 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1062
1063
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001064 def _deserialize_local(self, data):
1065 """Set local attributes from a list of tuples.
1066
1067 @param data: List of tuples like returned by
1068 _split_local_from_foreign_values.
1069 """
1070 for link, value in data:
1071 setattr(self, link, value)
1072 # Overwridden save() methods are prone to errors, so don't execute them.
1073 # This is because:
1074 # - the overwritten methods depend on ACL groups that don't yet exist
1075 # and don't handle errors
1076 # - the overwritten methods think this object already exists in the db
1077 # because the id is already set
1078 super(type(self), self).save()
1079
1080
1081 def _deserialize_relations(self, data):
1082 """Set foreign attributes from a list of tuples.
1083
1084 This deserialized the related objects using their own deserialize()
1085 function and then sets the relation.
1086
1087 @param data: List of tuples like returned by
1088 _split_local_from_foreign_values.
1089 """
1090 for link, value in data:
1091 self._deserialize_relation(link, value)
1092 # See comment in _deserialize_local
1093 super(type(self), self).save()
1094
1095
@classmethod
def deserialize(cls, data):
    """Recursively deserializes and saves an object with its dependencies.

    This takes the result of the serialize method and creates objects
    in the database that are just like the original.

    If an object of the same type with the same id already exists, its
    local values will be left untouched, unless they are explicitly
    whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.

    Deserialize will always recursively propagate to all related objects
    present in data though.
    E.g. this is necessary to add users to an already existing acl-group.

    @param data: Representation of an object and its dependencies, as
                 returned by serialize. May be None.

    @returns: None if data is None. Otherwise the object represented by
              data if it didn't exist before, or else the object that
              existed before and has the same type and id as the one
              described by data.
    """
    if data is None:
        return None

    local_pairs, related_pairs = cls._split_local_from_foreign_values(data)

    try:
        target = cls.objects.get(id=data['id'])
        # Existing object: only whitelisted local fields get overwritten.
        local_pairs = cls._filter_update_allowed_fields(local_pairs)
    except cls.DoesNotExist:
        target = cls()

    target._deserialize_local(local_pairs)
    target._deserialize_relations(related_pairs)
    return target
1133
1134
def sanity_check_update_from_shard(self, shard, updated_serialized,
                                   *args, **kwargs):
    """Check if an update sent from a shard is legitimate.

    This base implementation only raises NotImplementedError; the
    message asks subclasses to provide the actual check.

    @param shard: presumably the shard that sent the update -- unused
                  here, confirm against overriding subclasses.
    @param updated_serialized: presumably the serialized update -- unused
                               here, confirm against overriding
                               subclasses.

    @raises NotImplementedError always, in this base implementation.
    @raises error.UnallowedRecordsSentToMaster in overriding
            implementations, if an update is not legitimate.
    """
    # The message used to contain two %s placeholders for one argument,
    # which raised TypeError instead of the intended exception.
    raise NotImplementedError(
        'sanity_check_update_from_shard must be implemented by subclass '
        '%s' % (type(self),))
1145
1146
def update_from_serialized(self, serialized):
    """Updates local fields of an existing object from a serialized form.

    Unlike deserialize(), this overwrites local values, and it does not
    recursively propagate to related objects.

    The use case of this function is to update job records on the master
    after the jobs have been executed on a slave, as the master is not
    interested in updates for users, labels, specialtasks, etc.

    @param serialized: Representation of an object, as returned by
                       serialize, containing local fields only.

    @raises ValueError: if serialized contains related objects, i.e. not
                        only local fields.
    """
    split_values = self._split_local_from_foreign_values(serialized)
    local_pairs, related_pairs = split_values
    if related_pairs:
        raise ValueError('Serialized must not contain foreign '
                         'objects: %s' % related_pairs)

    self._deserialize_local(local_pairs)
1171
1172
def custom_deserialize_relation(self, link, data):
    """Allows overriding the deserialization behaviour by subclasses.

    _deserialize_relation() calls this for relations it does not handle
    as to-many relations. This base implementation always raises.

    @param link: Name of the relation.
    @param data: Serialized representation of the related object(s).

    @raises NotImplementedError always, in this base implementation.
    """
    message = ('custom_deserialize_relation must be implemented by '
               'subclass %s for relation %s' % (type(self), link))
    raise NotImplementedError(message)
1178
1179
1180 def _deserialize_relation(self, link, data):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001181 """Deserializes related objects and sets references on this object.
1182
1183 Relations that point to a list of objects are handled automatically.
1184 For many-to-one or one-to-one relations custom_deserialize_relation
1185 must be overridden by the subclass.
1186
1187 Related objects are deserialized using their deserialize() method.
1188 Thereby they and their dependencies are created if they don't exist
1189 and saved to the database.
1190
1191 @param link: Name of the relation.
1192 @param data: Serialized representation of the related object(s).
1193 This means a list of dictionaries for to-many relations,
1194 just a dictionary for to-one relations.
1195 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001196 field = getattr(self, link)
1197
1198 if field and hasattr(field, 'all'):
1199 self._deserialize_2m_relation(link, data, field.model)
1200 else:
1201 self.custom_deserialize_relation(link, data)
1202
1203
1204 def _deserialize_2m_relation(self, link, data, related_class):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001205 """Deserialize related objects for one to-many relationship.
1206
1207 @param link: Name of the relation.
1208 @param data: Serialized representation of the related objects.
1209 This is a list with of dictionaries.
1210 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001211 relation_set = getattr(self, link)
1212 for serialized in data:
1213 relation_set.add(related_class.deserialize(serialized))
1214
1215
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field. save() additionally reads the object's name through the class's
    name_field attribute.
    """

    def save(self, *args, **kwargs):
        """
        Save the object; a brand-new object (id is None) first takes over
        a previously invalidated object with the same name, if one exists.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            name_value = getattr(self, self.name_field)
            lookup = {self.name_field: name_value, 'invalid': True}
            try:
                invalidated = self.__class__.objects.get(**lookup)
                self.resurrect_object(invalidated)
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object.  Can be
        overridden by subclasses to copy data as desired from the deleted
        entry (but this superclass implementation must normally be called).

        This implementation takes over the id of old_object.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """


    def delete(self):
        """
        Mark this object invalid and save it instead of deleting the row,
        then call clean_object().  The object must not be invalid already.
        """
        # NOTE(review): this self-assignment has no visible purpose; it is
        # kept as it was in the original code.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that only sees valid objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            """Run the inherited _prepare(), then use model.valid_objects."""
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001281
1282
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with
    them.  The attribute model is assumed to have its value field named
    "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and
        keyword_args is a dict of args to pass to
        attribute_model.objects.get() to get an instance of the given
        attribute on this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create the attribute if needed, then store value on it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        stored = model.objects.get_or_create(**lookup)[0]
        stored.value = value
        stored.save()


    def delete_attribute(self, attribute):
        """Delete the attribute; do nothing if it does not exist."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute when value is None, otherwise set it."""
        if value is not None:
            self.set_attribute(attribute, value)
            return
        self.delete_attribute(attribute)
showard26b7ec72009-12-21 22:43:57 +00001321
1322
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always raises: get_or_create() must be used instead."""
        message = ('ModelWithHash manager should use get_or_create() '
                   'instead of create()')
        raise Exception(message)


    def get_or_create(self, **kwargs):
        """
        Compute the hash of the given arguments with the model's
        _compute_hash(), add it as 'the_hash', and delegate to the
        inherited get_or_create().
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        parent = super(ModelWithHashManager, self)
        return parent.get_or_create(**kwargs)
1334
1335
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # Unique hash column; ModelWithHashManager.get_or_create() fills it
    # from _compute_hash().
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given arguments; must be overridden."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save()
        operation will not work. These models should be instantiated
        through their model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed,
        since that creates a new row. However, the preferred way to make
        instances of these models is through the get_or_create() method.

        @raises Exception if force_insert is not set.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the
            # unique constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')