blob: 2566e24ff7317eda8432c1dba0cdb28f6bfded46 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import re
Michael Liang8864e862014-07-22 08:36:05 -07006import time
showarda5288b42009-07-28 20:06:08 +00007import django.core.exceptions
Jakob Juelich7bef8412014-10-14 19:11:54 -07008from django.db import models as dbmodels, backend, connection, connections
showarda5288b42009-07-28 20:06:08 +00009from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +000010import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000011from django.utils import datastructures
Prashanth B489b91d2014-03-15 12:17:16 -070012from autotest_lib.frontend.afe import rdb_model_extensions
showard7c785282008-05-29 19:45:12 +000013
Prashanth B489b91d2014-03-15 12:17:16 -070014
class ValidationError(django.core.exceptions.ValidationError):
    """Raised when adding or updating an object fails data validation.

    The exception's value is a dict that maps each offending field name to
    its error string.
    """
showard7c785282008-05-29 19:45:12 +000020
def _quote_name(name):
    """Quote an SQL identifier using the default database connection.

    Thin shorthand for connection.ops.quote_name().
    """
    ops = connection.ops
    return ops.quote_name(name)
24
25
class LeasedHostManager(dbmodels.Manager):
    """Query manager restricted to hosts that are neither leased nor locked.

    NOTE(review): the class name suggests leased hosts, but the filter below
    keeps only unleased ones (leased=0).
    """
    def get_query_set(self):
        base_query_set = super(LeasedHostManager, self).get_query_set()
        unleased_unlocked = base_query_set.filter(leased=0, locked=0)
        return unleased_unlocked
32
33
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    On top of the standard Django manager this offers helpers to:
      * add hand-built joins and raw WHERE clauses to query sets,
      * build "custom fields" from a related model's rows, and
      * populate lists of related objects in bulk.
    """

    class CustomQuery(query.Query):
        """Django Query that also records hand-built ("custom") joins.

        The joins are stored as dicts in self._custom_joins (see
        add_custom_join()). Turning them into SQL is not done by this class.
        """

        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clone the query, carrying a copy of the custom joins along.

            Keyword arguments (such as memo, or attributes Django wants set
            on the clone) are forwarded to Query.clone() instead of being
            silently dropped.
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                 **kwargs)
            # Always copy our own joins, even over a _custom_joins kwarg
            # (convert_query() passes _custom_joins=[]), so that successive
            # add_join() calls accumulate joins instead of resetting them.
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Merge rhs into this query (used by & and |).

            Custom joins from rhs, if it has any, are appended to ours.
            """
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """Record a custom join on this query.

            @param table: name of the table to join.
            @param condition: SQL for the ON clause.
            @param join_type: join type, e.g. Query.INNER or Query.LOUTER.
            @param condition_values: values to substitute into condition.
            @param alias: alias for the joined table; defaults to table.
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @param query_set: query set to convert; it is left untouched.
            @returns a copy of query_set whose query is an instance of cls.
                    Custom joins already present on a CustomQuery are kept.
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(klass=cls,
                                                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            """Return the raw clause and its parameter values unchanged."""
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            """No-op: the raw clause is not rewritten when aliases change."""
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from. Defaults to
                the database column of this model's primary key.
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new, CustomQuery-backed query set; query_set is unchanged.
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            # The key is interpolated into SQL, so it must be the column name
            # (which can differ from the field name via db_column).
            join_from_key = self.model._meta.pk.column
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # With a LEFT JOIN, unmatched rows have NULL in the joined key.
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        # NOTE: this relabels join_to_query's WHERE node in place, so the
        # caller's query object is modified as a side effect.
        rhs_where = join_to_query.query.where
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        # compiler.as_sql() produces a complete SELECT statement, not a bare
        # condition. Wrap it in EXISTS so it is a valid boolean term inside
        # the join condition (a multi-column subquery is not).
        all_clauses = ('EXISTS (%s)' % initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to. It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if it is declared on the related model.
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join. we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column),
                                               rhs_id)
        info['values'] = ()
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship. join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @param query_set: query set over this manager's model.
        @param join_to_query: query over the related model (see above).
        @param alias: alias for the joined table.
        @param left_join: if true (the default), use a LEFT OUTER JOIN.
        @returns the new query set from add_join().
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type is self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type is self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        elif relationship_type is self.M2M_ON_THIS_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def key_on_joined_table(self, join_to_query):
        """Get a non-null column on the table joined for the given query.

        This analyzes the join that would be produced if join_to_query were
        passed to join_custom_field.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)
        if relationship_type is self.MANY_TO_ONE:
            return join_to_query.model._meta.pk.column
        return field.m2m_column_name() # any column on the M2M table will do


    def add_where(self, query_set, where, values=()):
        """Return a copy of query_set with a raw SQL condition AND-ed in.

        Unlike extra(where=...), the result can still be combined with & and |.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return the quoted "table"."field" reference."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """Return the quoted reference to key_field on this model's table.

        @param key_field: column name; defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Escape '%' so sql survives parameter substitution literally."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """Execute a custom select query.

        @param query_set: query set as returned by query_objects.
        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.

        @returns: Result of the query as returned by cursor.fetchall().
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        # Keep everything from FROM onward and substitute our own SELECT list.
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        # Choose the connection that's responsible for this type of object
        cursor = connections[query_set.db].cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Truthy if field is a relation whose target is exactly model_class."""
        return field.rel and field.rel.to is model_class


    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError: if no such relationship is found.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type is self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type is self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type is self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field, related_model):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; its
                keys are the IDs to include (must be non-empty, or the IN ()
                clause is invalid SQL)
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @param related_model the related model

        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id.iterkeys()))

        # Choose the connection that's responsible for this type of object
        # The databases for related_model and the current model will always
        # be the same, related_model is just easier to obtain here because
        # self is only a ExtendedManager, not the object.
        cursor = connections[related_model.objects.db].cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field,
                                           related_model)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        # Fetch all related objects with one query rather than one per pair.
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
476 getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000477
478
jamesrene3656232010-03-02 00:00:30 +0000479class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
480 """
481 QuerySet that handles delete() properly for models with an "invalid" bit
482 """
483 def delete(self):
484 for model in self:
485 model.delete()
486
487
488class ModelWithInvalidManager(ExtendedManager):
489 """
490 Manager for objects with an "invalid" bit
491 """
492 def get_query_set(self):
493 return ModelWithInvalidQuerySet(self.model)
494
495
496class ValidObjectsManager(ModelWithInvalidManager):
jadmanski0afbb632008-06-06 21:10:57 +0000497 """
498 Manager returning only objects with invalid=False.
499 """
500 def get_query_set(self):
501 queryset = super(ValidObjectsManager, self).get_query_set()
502 return queryset.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000503
504
Prashanth B489b91d2014-03-15 12:17:16 -0700505class ModelExtensions(rdb_model_extensions.ModelValidators):
jadmanski0afbb632008-06-06 21:10:57 +0000506 """\
Prashanth B489b91d2014-03-15 12:17:16 -0700507 Mixin with convenience functions for models, built on top of
508 the model validators in rdb_model_extensions.
jadmanski0afbb632008-06-06 21:10:57 +0000509 """
510 # TODO: at least some of these functions really belong in a custom
511 # Manager class
showard7c785282008-05-29 19:45:12 +0000512
Jakob Juelich3bb7c802014-09-02 16:31:11 -0700513
514 SERIALIZATION_LINKS_TO_FOLLOW = set()
515 """
516 To be able to send jobs and hosts to shards, it's necessary to find their
517 dependencies.
518 The most generic approach for this would be to traverse all relationships
519 to other objects recursively. This would list all objects that are related
520 in any way.
521 But this approach finds too many objects: If a host should be transferred,
522 all it's relationships would be traversed. This would find an acl group.
523 If then the acl group's relationships are traversed, the relationship
524 would be followed backwards and many other hosts would be found.
525
526 This mapping tells that algorithm which relations to follow explicitly.
527 """
528
Jakob Juelichf865d332014-09-29 10:47:49 -0700529
530 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
531 """
532 On deserializion, if the object to persist already exists, local fields
533 will only be updated, if their name is in this set.
534 """
535
536
jadmanski0afbb632008-06-06 21:10:57 +0000537 @classmethod
538 def convert_human_readable_values(cls, data, to_human_readable=False):
539 """\
540 Performs conversions on user-supplied field data, to make it
541 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000542
jadmanski0afbb632008-06-06 21:10:57 +0000543 For all fields that have choice sets, convert their values
544 from human-readable strings to enum values, if necessary. This
545 allows users to pass strings instead of the corresponding
546 integer values.
showard7c785282008-05-29 19:45:12 +0000547
jadmanski0afbb632008-06-06 21:10:57 +0000548 For all foreign key fields, call smart_get with the supplied
549 data. This allows the user to pass either an ID value or
550 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000551
jadmanski0afbb632008-06-06 21:10:57 +0000552 If to_human_readable=True, perform the inverse - i.e. convert
553 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000554
jadmanski0afbb632008-06-06 21:10:57 +0000555 This method modifies data in-place.
556 """
557 field_dict = cls.get_field_dict()
558 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000559 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000560 continue
561 field_obj = field_dict[field_name]
562 # convert enum values
563 if field_obj.choices:
564 for choice_data in field_obj.choices:
565 # choice_data is (value, name)
566 if to_human_readable:
567 from_val, to_val = choice_data
568 else:
569 to_val, from_val = choice_data
570 if from_val == data[field_name]:
571 data[field_name] = to_val
572 break
573 # convert foreign key values
574 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000575 dest_obj = field_obj.rel.to.smart_get(data[field_name],
576 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000577 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800578 # parameterized_jobs do not have a name_field
579 if (field_name != 'parameterized_job' and
580 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000581 data[field_name] = getattr(dest_obj,
582 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000583 else:
showardb0a73032009-03-27 18:35:41 +0000584 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000585
586
showard7c785282008-05-29 19:45:12 +0000587
588
Dale Curtis74a314b2011-06-23 14:55:46 -0700589 def _validate_unique(self):
jadmanski0afbb632008-06-06 21:10:57 +0000590 """\
591 Validate that unique fields are unique. Django manipulators do
592 this too, but they're a huge pain to use manually. Trust me.
593 """
594 errors = {}
595 cls = type(self)
596 field_dict = self.get_field_dict()
597 manager = cls.get_valid_manager()
598 for field_name, field_obj in field_dict.iteritems():
599 if not field_obj.unique:
600 continue
showard7c785282008-05-29 19:45:12 +0000601
jadmanski0afbb632008-06-06 21:10:57 +0000602 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000603 if value is None and field_obj.auto_created:
604 # don't bother checking autoincrement fields about to be
605 # generated
606 continue
607
jadmanski0afbb632008-06-06 21:10:57 +0000608 existing_objs = manager.filter(**{field_name : value})
609 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000610
jadmanski0afbb632008-06-06 21:10:57 +0000611 if num_existing == 0:
612 continue
613 if num_existing == 1 and existing_objs[0].id == self.id:
614 continue
615 errors[field_name] = (
616 'This value must be unique (%s)' % (value))
617 return errors
showard7c785282008-05-29 19:45:12 +0000618
619
showarda5288b42009-07-28 20:06:08 +0000620 def _validate(self):
621 """
622 First coerces all fields on this instance to their proper Python types.
623 Then runs validation on every field. Returns a dictionary of
624 field_name -> error_list.
625
626 Based on validate() from django.db.models.Model in Django 0.96, which
627 was removed in Django 1.0. It should reappear in a later version. See:
628 http://code.djangoproject.com/ticket/6845
629 """
630 error_dict = {}
631 for f in self._meta.fields:
632 try:
633 python_value = f.to_python(
634 getattr(self, f.attname, f.get_default()))
635 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000636 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000637 continue
638
639 if not f.blank and not python_value:
640 error_dict[f.name] = 'This field is required.'
641 continue
642
643 setattr(self, f.attname, python_value)
644
645 return error_dict
646
647
jadmanski0afbb632008-06-06 21:10:57 +0000648 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000649 errors = self._validate()
Dale Curtis74a314b2011-06-23 14:55:46 -0700650 unique_errors = self._validate_unique()
jadmanski0afbb632008-06-06 21:10:57 +0000651 for field_name, error in unique_errors.iteritems():
652 errors.setdefault(field_name, error)
653 if errors:
654 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000655
656
jadmanski0afbb632008-06-06 21:10:57 +0000657 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000658
jadmanski0afbb632008-06-06 21:10:57 +0000659 @classmethod
660 def add_object(cls, data={}, **kwargs):
661 """\
662 Returns a new object created with the given data (a dictionary
663 mapping field names to values). Merges any extra keyword args
664 into data.
665 """
Prashanth B489b91d2014-03-15 12:17:16 -0700666 data = dict(data)
667 data.update(kwargs)
668 data = cls.prepare_data_args(data)
669 cls.convert_human_readable_values(data)
jadmanski0afbb632008-06-06 21:10:57 +0000670 data = cls.provide_default_values(data)
Prashanth B489b91d2014-03-15 12:17:16 -0700671
jadmanski0afbb632008-06-06 21:10:57 +0000672 obj = cls(**data)
673 obj.do_validate()
674 obj.save()
675 return obj
showard7c785282008-05-29 19:45:12 +0000676
677
jadmanski0afbb632008-06-06 21:10:57 +0000678 def update_object(self, data={}, **kwargs):
679 """\
680 Updates the object with the given data (a dictionary mapping
681 field names to values). Merges any extra keyword args into
682 data.
683 """
Prashanth B489b91d2014-03-15 12:17:16 -0700684 data = dict(data)
685 data.update(kwargs)
686 data = self.prepare_data_args(data)
687 self.convert_human_readable_values(data)
jadmanski0afbb632008-06-06 21:10:57 +0000688 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000689 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000690 self.do_validate()
691 self.save()
showard7c785282008-05-29 19:45:12 +0000692
693
showard8bfb5cb2009-10-07 20:49:15 +0000694 # see query_objects()
695 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
696 'extra_args', 'extra_where', 'no_distinct')
697
698
jadmanski0afbb632008-06-06 21:10:57 +0000699 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000700 def _extract_special_params(cls, filter_data):
701 """
702 @returns a tuple of dicts (special_params, regular_filters), where
703 special_params contains the parameters we handle specially and
704 regular_filters is the remaining data to be handled by Django.
705 """
706 regular_filters = dict(filter_data)
707 special_params = {}
708 for key in cls._SPECIAL_FILTER_KEYS:
709 if key in regular_filters:
710 special_params[key] = regular_filters.pop(key)
711 return special_params, regular_filters
712
713
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply the presentation parameters of filter_data -- sorting and
    paging -- to query.

    @param query: A query object supporting extra() and slicing.
    @param filter_data: Filter dictionary as accepted by query_objects();
            only 'sort_by', 'query_start' and 'query_limit' are used.
            'query_limit' is a page size, not an end index.
    @returns new query with presentation applied
    @raises ValueError: if query_start is given without query_limit.
    """
    special_params = cls._extract_special_params(filter_data)[0]

    sort_by = special_params.get('sort_by')
    if sort_by:
        assert isinstance(sort_by, (list, tuple))
        query = query.extra(order_by=sort_by)

    start = special_params.get('query_start')
    limit = special_params.get('query_limit')
    if start is None:
        # No offset: the (possibly None) limit is already the slice end.
        return query[:limit]
    if limit is None:
        raise ValueError('Cannot pass query_start without query_limit')
    # query_limit is passed as a page size, so the slice ends at start+limit
    return query[start:start + limit]
showard8bfb5cb2009-10-07 20:49:15 +0000735
736
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """\
    Returns a QuerySet object for querying the given model_class
    with the given filter_data. Optional special arguments in
    filter_data include:
    -query_start: index of first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on. prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    filter_data is never modified.  An 'extra_args' dict and its 'where'
    list are copied before extra_where is appended, so the same filter_data
    can safely be reused for several queries.

    @param valid_only: when initial_query is None, start from
            get_valid_manager() if True, otherwise from cls.objects.
    @param initial_query: optional query/manager to start filtering from.
    @param apply_presentation: if True, apply sort_by / query_start /
            query_limit through apply_presentation().
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # Copy extra_args (and its 'where' list) so appending extra_where below
    # never leaks into the caller's filter_data.  Mutating it in place made
    # a second query built from the same filter_data (e.g. a count followed
    # by a listing) carry the WHERE clause twice.
    extra_args = dict(special_params.get('extra_args') or {})
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = (list(extra_args.get('where') or []) +
                               [extra_where])
    if extra_args:
        query = query.extra(**extra_args)
    # TODO: Use readonly connection for these queries.
    # This has been disabled, because it's not used anyway, as the
    # configured readonly user is the same as the real user anyway.

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000784
785
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but retrieve only the count of results.

    Paging parameters ('query_start', 'query_limit') are ignored so the
    count covers every matching row.  filter_data itself is not modified,
    so the caller can reuse it for a subsequent paged listing.
    """
    # Work on a copy: popping from the caller's dict would silently drop
    # its pagination for any later query built from the same filter_data.
    filter_data = dict(filter_data)
    filter_data.pop('query_start', None)
    filter_data.pop('query_limit', None)
    query = cls.query_objects(filter_data, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000795
796
jadmanski0afbb632008-06-06 21:10:57 +0000797 @classmethod
798 def clean_object_dicts(cls, field_dicts):
799 """\
800 Take a list of dicts corresponding to object (as returned by
801 query.values()) and clean the data to be more suitable for
802 returning to the user.
803 """
showarde732ee72008-09-23 19:15:43 +0000804 for field_dict in field_dicts:
805 cls.clean_foreign_keys(field_dict)
showard21baa452008-10-21 00:08:39 +0000806 cls._convert_booleans(field_dict)
showarde732ee72008-09-23 19:15:43 +0000807 cls.convert_human_readable_values(field_dict,
808 to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000809
810
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but return a list of dictionaries, one per matching
    object, built by get_object_dict().  Columns from the query's
    extra_select are included in every dictionary.
    """
    query = cls.query_objects(filter_data, initial_query=initial_query)
    extra_fields = query.query.extra_select.keys()
    result = []
    for model_object in query:
        result.append(model_object.get_object_dict(extra_fields=extra_fields))
    return result
showard7c785282008-05-29 19:45:12 +0000821
822
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    Fetch a single object by primary key or by name.

    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field (only for models that
                         define a name_field attribute)

    @param valid_only: If True (default) search get_valid_manager(),
            otherwise cls.objects.
    @raises ValueError: if id_or_name is neither an integer nor a string,
            or is a string but the model has no name_field.
    """
    manager = cls.get_valid_manager() if valid_only else cls.objects

    if isinstance(id_or_name, (int, long)):
        lookup = {'pk': id_or_name}
    elif isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
        lookup = {cls.name_field: id_or_name}
    else:
        raise ValueError(
                'Invalid positional argument: %s (%s)' % (id_or_name,
                                                          type(id_or_name)))
    return manager.get(**lookup)
showard7c785282008-05-29 19:45:12 +0000841
842
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """\
    Look up several objects at once using smart_get().

    @param id_or_name_list: Iterable of IDs and/or names, as accepted by
            smart_get().
    @returns List of the found objects, in input order.
    @raises cls.DoesNotExist: naming every input that matched nothing.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs, which str.join() rejects with a
        # TypeError.  Format each one explicitly; '%s' (rather than str())
        # also keeps non-ASCII unicode names working under Python 2.
        missing = ', '.join('%s' % (item,) for item in invalid_inputs)
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
857
858
def get_object_dict(self, extra_fields=None):
    """\
    Return a dictionary mapping field names to this object's values.

    @param extra_fields: Optional list of extra attribute names to include,
            in addition to the keys of get_field_dict().

    The dictionary is run through clean_object_dicts() and then
    _postprocess_object_dict() before being returned.
    """
    field_names = list(self.get_field_dict().keys())
    if extra_fields:
        field_names.extend(extra_fields)
    object_dict = {}
    for field_name in field_names:
        object_dict[field_name] = getattr(self, field_name)
    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000873
874
showardd3dc1992009-04-22 21:01:40 +0000875 def _postprocess_object_dict(self, object_dict):
876 """For subclasses to override."""
877 pass
878
879
@classmethod
def get_valid_manager(cls):
    """Return the manager used for "valid objects only" lookups.

    The base implementation applies no filtering and simply returns
    cls.objects; ModelWithInvalid overrides it.
    """
    manager = cls.objects
    return manager
showard7c785282008-05-29 19:45:12 +0000883
884
showard2bab8f42008-11-12 18:15:22 +0000885 def _record_attributes(self, attributes):
886 """
887 See on_attribute_changed.
888 """
889 assert not isinstance(attributes, basestring)
890 self._recorded_attributes = dict((attribute, getattr(self, attribute))
891 for attribute in attributes)
892
893
894 def _check_for_updated_attributes(self):
895 """
896 See on_attribute_changed.
897 """
898 for attribute, original_value in self._recorded_attributes.iteritems():
899 new_value = getattr(self, attribute)
900 if original_value != new_value:
901 self.on_attribute_changed(attribute, original_value)
902 self._record_attributes(self._recorded_attributes.keys())
903
904
def on_attribute_changed(self, attribute, old_value):
    """
    Hook called by _check_for_updated_attributes() for every recorded
    attribute whose value changed.  The base implementation does nothing;
    override it to react to changes.

    @param attribute: Name of the attribute that changed.
    @param old_value: The value recorded before the change; the new value
            is available as getattr(self, attribute).

    To use this hook a model must:
      * call _record_attributes() from __init__() (after the super call)
        with the list of attributes to watch, and
      * call _check_for_updated_attributes() from save().
    """
    return None
916
917
def serialize(self, include_dependencies=True):
    """Serializes the object, optionally together with its dependencies.

    Local fields of the concrete model that are not relations (field.rel is
    None) are always included.  The relations named in
    SERIALIZATION_LINKS_TO_FOLLOW are added, through _serialize_relation(),
    only when include_dependencies is true.

    @param include_dependencies: Whether or not to follow relations to
            objects this object depends on.  This parameter is used when
            uploading jobs from a shard to the master, as the master already
            has all the dependent objects.

    @returns: Dictionary representation of the object.
    """
    local_fields = self._meta.concrete_model._meta.local_fields
    serialized = dict((field.name, field._get_val_from_obj(self))
                      for field in local_fields if field.rel is None)
    if not include_dependencies:
        return serialized
    for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
        serialized[link] = self._serialize_relation(link)
    return serialized
943
944
945 def _serialize_relation(self, link):
946 """Serializes dependent objects given the name of the relation.
947
948 @param link: Name of the relation to take objects from.
949
950 @returns For To-Many relationships a list of the serialized related
951 objects, for To-One relationships the serialized related object.
952 """
953 try:
954 attr = getattr(self, link)
955 except AttributeError:
956 # One-To-One relationships that point to None may raise this
957 return None
958
959 if attr is None:
960 return None
961 if hasattr(attr, 'all'):
962 return [obj.serialize() for obj in attr.all()]
963 return attr.serialize()
964
965
Jakob Juelichf88fa932014-09-03 17:58:04 -0700966 @classmethod
Jakob Juelich116ff0f2014-09-17 18:25:16 -0700967 def _split_local_from_foreign_values(cls, data):
968 """This splits local from foreign values in a serialized object.
969
970 @param data: The serialized object.
971
972 @returns A tuple of two lists, both containing tuples in the form
973 (link_name, link_value). The first list contains all links
974 for local fields, the second one contains those for foreign
975 fields/objects.
976 """
977 links_to_local_values, links_to_related_values = [], []
978 for link, value in data.iteritems():
979 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
980 # It's a foreign key
981 links_to_related_values.append((link, value))
982 else:
983 # It's a local attribute
984 links_to_local_values.append((link, value))
985 return links_to_local_values, links_to_related_values
986
987
Jakob Juelichf865d332014-09-29 10:47:49 -0700988 @classmethod
989 def _filter_update_allowed_fields(cls, data):
990 """Filters data and returns only files that updates are allowed on.
991
992 This is i.e. needed for syncing aborted bits from the master to shards.
993
994 Local links are only allowed to be updated, if they are in
995 SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
996 Overwriting existing values is allowed in order to be able to sync i.e.
997 the aborted bit from the master to a shard.
998
999 The whitelisting mechanism is in place to prevent overwriting local
1000 status: If all fields were overwritten, jobs would be completely be
1001 set back to their original (unstarted) state.
1002
1003 @param data: List with tuples of the form (link_name, link_value), as
1004 returned by _split_local_from_foreign_values.
1005
1006 @returns List of the same format as data, but only containing data for
1007 fields that updates are allowed on.
1008 """
1009 return [pair for pair in data
1010 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1011
1012
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001013 def _deserialize_local(self, data):
1014 """Set local attributes from a list of tuples.
1015
1016 @param data: List of tuples like returned by
1017 _split_local_from_foreign_values.
1018 """
1019 for link, value in data:
1020 setattr(self, link, value)
1021 # Overwridden save() methods are prone to errors, so don't execute them.
1022 # This is because:
1023 # - the overwritten methods depend on ACL groups that don't yet exist
1024 # and don't handle errors
1025 # - the overwritten methods think this object already exists in the db
1026 # because the id is already set
1027 super(type(self), self).save()
1028
1029
1030 def _deserialize_relations(self, data):
1031 """Set foreign attributes from a list of tuples.
1032
1033 This deserialized the related objects using their own deserialize()
1034 function and then sets the relation.
1035
1036 @param data: List of tuples like returned by
1037 _split_local_from_foreign_values.
1038 """
1039 for link, value in data:
1040 self._deserialize_relation(link, value)
1041 # See comment in _deserialize_local
1042 super(type(self), self).save()
1043
1044
@classmethod
def deserialize(cls, data):
    """Recursively deserializes and saves an object with its dependencies.

    This takes the result of serialize() and creates objects in the
    database that are just like the original.

    If an object of the same type with the same id already exists, its
    local values are left untouched unless they are explicitly whitelisted
    in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.

    Deserialization always propagates recursively to all related objects
    present in data, e.g. to add users to an already existing acl-group.

    @param data: Representation of an object and its dependencies, as
            returned by serialize, or None.

    @returns: None if data is None; otherwise the object represented by
            data if it didn't exist before, or the pre-existing object with
            the same type and id.
    """
    if data is None:
        return None

    local, related = cls._split_local_from_foreign_values(data)

    try:
        instance = cls.objects.get(id=data['id'])
    except cls.DoesNotExist:
        # New object: every local value is applied.
        instance = cls()
    else:
        # Existing object: only whitelisted local fields may be overwritten.
        local = cls._filter_update_allowed_fields(local)

    instance._deserialize_local(local)
    instance._deserialize_relations(related)
    return instance
1082
1083
def sanity_check_update_from_shard(self, shard, updated_serialized,
                                   *args, **kwargs):
    """Check if an update sent from a shard is legitimate.

    This base implementation is abstract: subclasses that accept updates
    from shards must override it.

    @param shard: The shard the update came from.
    @param updated_serialized: The serialized update sent by the shard.

    @raises error.UnallowedRecordsSentToMaster if an update is not
            legitimate (in subclass implementations).
    @raises NotImplementedError: always, in this base implementation.
    """
    # Exactly one placeholder: the message is formatted with one argument.
    # (Two placeholders with a single argument raise TypeError instead of
    # the intended NotImplementedError.)
    raise NotImplementedError(
            'sanity_check_update_from_shard must be implemented by subclass '
            '%s' % type(self))
1094
1095
def update_from_serialized(self, serialized):
    """Updates local fields of an existing object from a serialized form.

    Unlike deserialize(), this does update local values, but it does not
    propagate recursively to related objects.

    The use case is updating job records on the master after the jobs have
    been executed on a slave, as the master is not interested in updates
    for users, labels, specialtasks, etc.

    @param serialized: Representation of an object and its dependencies, as
            returned by serialize.

    @raises ValueError: if serialized contains related objects, i.e. not
            only local fields.
    """
    local_values, related_values = self._split_local_from_foreign_values(
            serialized)
    if not related_values:
        self._deserialize_local(local_values)
        return
    raise ValueError('Serialized must not contain foreign '
                     'objects: %s' % related_values)
1120
1121
def custom_deserialize_relation(self, link, data):
    """Hook for deserializing relations that _deserialize_relation() does
    not handle itself (anything that is not a non-empty to-many relation).

    Subclasses override this to restore the relation named link from its
    serialized form data.  The base implementation always raises
    NotImplementedError.
    """
    message = ('custom_deserialize_relation must be implemented by subclass %s '
               'for relation %s' % (type(self), link))
    raise NotImplementedError(message)
1127
1128
1129 def _deserialize_relation(self, link, data):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001130 """Deserializes related objects and sets references on this object.
1131
1132 Relations that point to a list of objects are handled automatically.
1133 For many-to-one or one-to-one relations custom_deserialize_relation
1134 must be overridden by the subclass.
1135
1136 Related objects are deserialized using their deserialize() method.
1137 Thereby they and their dependencies are created if they don't exist
1138 and saved to the database.
1139
1140 @param link: Name of the relation.
1141 @param data: Serialized representation of the related object(s).
1142 This means a list of dictionaries for to-many relations,
1143 just a dictionary for to-one relations.
1144 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001145 field = getattr(self, link)
1146
1147 if field and hasattr(field, 'all'):
1148 self._deserialize_2m_relation(link, data, field.model)
1149 else:
1150 self.custom_deserialize_relation(link, data)
1151
1152
1153 def _deserialize_2m_relation(self, link, data, related_class):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001154 """Deserialize related objects for one to-many relationship.
1155
1156 @param link: Name of the relation.
1157 @param data: Serialized representation of the related objects.
1158 This is a list with of dictionaries.
1159 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001160 relation_set = getattr(self, link)
1161 for serialized in data:
1162 relation_set.add(related_class.deserialize(serialized))
1163
1164
class ModelWithInvalid(ModelExtensions):
    """
    Model base class whose save() and delete() implement invalidation in
    place of actual deletion.

    Subclasses must have a boolean "invalid" field, a "name_field" class
    attribute naming the field used to find a previously invalidated object,
    and a "valid_objects" manager.
    """

    def save(self, *args, **kwargs):
        """
        Save the object.  On its first save (id is None), if an invalid
        object with the same name exists, resurrect_object() is called with
        it first, so this save reuses the old object's id.
        """
        if self.id is None:
            # This might be re-adding an object that was earlier "deleted",
            # i.e. marked invalid.
            name = getattr(self, self.name_field)
            lookup = {self.name_field: name, 'invalid': True}
            try:
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                pass  # no previously invalidated object to resurrect

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously invalidated object.  Subclasses
        may override this to copy data from the old entry, but should
        normally also call this implementation, which adopts the old id.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Called after the object has been marked invalid.  Subclasses should
        override this to clean up relationships that should no longer exist
        once the object is considered deleted.
        """
        pass


    def delete(self):
        """Mark the (currently valid) object invalid instead of deleting
        it, save it, then call clean_object()."""
        # NOTE(review): this self-assignment looks like a no-op (possibly a
        # lint workaround); it is kept as-is -- confirm before removing.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Queries through this manager only see objects not marked
        invalid."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Make default manipulators look only at valid objects; otherwise they
        would match invalid objects when checking uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001230
1231
class ModelWithAttributes(object):
    """
    Mixin for models that have an associated attribute model.  The
    attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Must be overridden to return a tuple (attribute_model,
        keyword_args): attribute_model is a model class and keyword_args is
        a dict of args for attribute_model.objects.get() that selects the
        given attribute of this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create the attribute row if needed and store value in it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        row = model.objects.get_or_create(**lookup)[0]
        row.value = value
        row.save()


    def delete_attribute(self, attribute):
        """Delete the attribute row; do nothing if there is none."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            # Nothing stored for this attribute: deleting is a no-op.
            return


    def set_or_delete_attribute(self, attribute, value):
        """Set the attribute to value, or delete it when value is None."""
        if value is not None:
            self.set_attribute(attribute, value)
            return
        self.delete_attribute(attribute)
showard26b7ec72009-12-21 22:43:57 +00001270
1271
class ModelWithHashManager(dbmodels.Manager):
    """Manager for ModelWithHash models: rows must be obtained through
    get_or_create(), which fills in the_hash automatically."""

    def create(self, **kwargs):
        """Always raises; plain create() is not supported for these
        models."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """get_or_create() with the_hash set to the model's
        _compute_hash() of the given field values."""
        lookup = dict(kwargs)
        lookup['the_hash'] = self.model._compute_hash(**kwargs)
        return super(ModelWithHashManager, self).get_or_create(**lookup)
1283
1284
class ModelWithHash(dbmodels.Model):
    """Abstract superclass for models identified by a unique hash column.

    Instances are meant to be immutable: obtain them through
    model.objects.get_or_create() (see ModelWithHashManager), which fills in
    the_hash using _compute_hash().
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the hash for the given field values; subclasses must
        override this."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Refuse to save, except for forced inserts.

        These models should be immutable, so the generic save() operation
        is disabled; create instances with model.objects.get_or_create()
        instead.  save(force_insert=True) is still allowed because it
        creates a new row, although get_or_create() remains the preferred
        way.
        """
        if force_insert:
            # A forced insert creates a new row; if it duplicates an
            # existing hash, the unique constraint will catch it.
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')