blob: a0923baae18b456768efebce29de6339aec21185 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
showarda5288b42009-07-28 20:06:08 +00005import django.core.exceptions
Prashanth Balasubramanian75be1d32014-11-25 18:03:09 -08006from django.db import backend
7from django.db import connection
8from django.db import connections
9from django.db import models as dbmodels
10from django.db import transaction
showarda5288b42009-07-28 20:06:08 +000011from django.db.models.sql import query
showard7e67b432010-01-20 01:13:04 +000012import django.db.models.sql.where
showard7c785282008-05-29 19:45:12 +000013from django.utils import datastructures
Prashanth B489b91d2014-03-15 12:17:16 -070014from autotest_lib.frontend.afe import rdb_model_extensions
showard7c785282008-05-29 19:45:12 +000015
Prashanth B489b91d2014-03-15 12:17:16 -070016
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Raised when data supplied for creating or updating an object does not
    pass validation. The exception's value is a dict that maps each
    offending field name to its error message.
    """
showard7c785282008-05-29 19:45:12 +000022
def _quote_name(name):
    """Quote `name` as an SQL identifier.

    Thin wrapper around connection.ops.quote_name() of the module-level
    django.db ``connection``.
    """
    quote = connection.ops.quote_name
    return quote(name)
26
27
class LeasedHostManager(dbmodels.Manager):
    """Manager whose default query set contains only hosts that are neither
    leased nor locked (leased=0 and locked=0).
    """
    def get_query_set(self):
        base_query_set = super(LeasedHostManager, self).get_query_set()
        return base_query_set.filter(leased=0, locked=0)
34
35
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Adds helpers for attaching hand-written joins and WHERE fragments to
    query sets, for joining to a related model to build a custom field, and
    for bulk-populating lists of related objects on model instances.
    """

    class CustomQuery(query.Query):
        """Query that additionally records a list of custom joins.

        Each entry is a dict built by add_custom_join(). This class stores
        the entries and propagates them through clone() and combine();
        presumably they are rendered into SQL by code outside this view --
        TODO confirm.
        """
        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clone this query, giving the clone its own list of joins.

            Keyword arguments are forwarded to the base Query.clone() (they
            used to be silently dropped here).
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                 **kwargs)
            # Copy, don't share: later add_custom_join() calls on either
            # query must not leak into the other.
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combine with rhs (as done for & and |), merging its joins."""
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """Record a join to `table` using `condition` as its ON clause.

            @param table: name of the table to join to.
            @param condition: SQL for the ON clause; may contain %s
                    placeholders to be filled from condition_values.
            @param join_type: join type constant (e.g. Query.INNER or
                    Query.LOUTER, as used by add_join()).
            @param condition_values: values for the placeholders.
            @param alias: alias for the joined table; defaults to `table`.
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @returns a copy of query_set; the argument is left untouched.
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(
                    klass=ExtendedManager.CustomQuery,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            # The clause is emitted verbatim; qn/connection are accepted only
            # to satisfy the caller's signature.
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            # Raw SQL can't be relabeled, so alias changes are ignored.
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from; defaults to
                the primary key's column.
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
                LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
                instead of an INNER JOIN regardless of other options
        @returns a new query set with the join applied
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            # This is spliced into SQL, so it must be the column name, not
            # the Python field name (they differ e.g. for a ForeignKey pk).
            join_from_key = self.model._meta.pk.column
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # With a LEFT JOIN, rows that found no match have NULL join keys.
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        Gather join parameters for a many-to-one relationship.

        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        rhs_where = join_to_query.query.where
        # NOTE(review): this relabels join_to_query's own WHERE node
        # (presumably in place), so the caller's query is modified too.
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        # NOTE(review): compiler.as_sql() renders the whole query (see
        # _custom_select_query, which slices the FROM part off such SQL),
        # not just its WHERE clause -- confirm this is what the ON clause
        # expects.
        initial_clause, values = compiler.as_sql()
        all_clauses = (initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        Gather join parameters for a many-to-many relationship.

        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to. It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if it is declared on the related model.
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join. we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = list(join_to_query.values_list(rhs_id_field.attname,
                                                 flat=True))
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        # Bind the id as a parameter instead of splicing it into the SQL
        # text; this also works for non-integer primary keys.
        info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column))
        info['values'] = (rhs_id,)
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship. join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @returns a new query set with the join applied (see add_join()).
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        elif relationship_type == self.M2M_ON_THIS_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def key_on_joined_table(self, join_to_query):
        """Get a non-null column on the table joined for the given query.

        This analyzes the join that would be produced if join_to_query were
        passed to join_custom_field.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)
        if relationship_type == self.MANY_TO_ONE:
            return join_to_query.model._meta.pk.column
        return field.m2m_column_name() # any column on the M2M table will do


    def add_where(self, query_set, where, values=()):
        """Return a copy of query_set with a raw SQL clause ANDed into WHERE.

        @param where: SQL fragment; may contain placeholders for `values`.
        @param values: parameters for the fragment's placeholders.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return `table`.`field` with both identifiers quoted."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """Return the quoted, table-qualified key column of this model.

        @param key_field: column name; defaults to the primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' so user SQL survives parameter interpolation."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """Execute a custom select query.

        @param query_set: query set as returned by query_objects.
        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.

        @returns: Result of the query as returned by cursor.fetchall().
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        # Keep everything from the first ' FROM' on and swap in our own
        # select list. NOTE(review): assumes no selected expression contains
        # ' FROM' itself.
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        # Choose the connection that's responsible for this type of object
        cursor = connections[query_set.db].cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Return a truthy value if `field` is a relation to model_class."""
        return field.rel and field.rel.to is model_class


    # Sentinels returned by determine_relationship().
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no such relationship exists.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
                their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
                related_object) for each relationship between a base object and a
                related object. all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object. we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field, related_model):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; only
                its keys (the IDs to include) are used
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
                self.model
        @param pivot_to_field a field name on pivot_table referencing the
                related model.
        @param related_model the related model

        @returns pivot list of IDs (base_id, related_id)
        """
        base_ids = tuple(base_objects_by_id)
        if not base_ids:
            # "IN ()" is invalid SQL, and there is nothing to look up.
            return ()

        # Table/column names are internal identifiers and are spliced in
        # directly; the IDs themselves are passed as bound parameters.
        query = """
            SELECT %(from_field)s, %(to_field)s
            FROM %(table)s
            WHERE %(from_field)s IN (%(placeholders)s)
            """ % dict(from_field=pivot_from_field,
                       to_field=pivot_to_field,
                       table=pivot_table,
                       placeholders=','.join(['%s'] * len(base_ids)))

        # Choose the connection that's responsible for this type of object
        # The databases for related_model and the current model will always
        # be the same, related_model is just easier to obtain here because
        # self is only a ExtendedManager, not the object.
        cursor = connections[related_model.objects.db].cursor()
        cursor.execute(query, base_ids)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field,
                                           related_model)

        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        # One query for all related objects instead of one per pivot row.
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
                object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
478 getattr(base_object, related_list_name).append(related_object)
showard0957a842009-05-11 19:25:08 +0000479
480
jamesrene3656232010-03-02 00:00:30 +0000481class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
482 """
483 QuerySet that handles delete() properly for models with an "invalid" bit
484 """
485 def delete(self):
486 for model in self:
487 model.delete()
488
489
490class ModelWithInvalidManager(ExtendedManager):
491 """
492 Manager for objects with an "invalid" bit
493 """
494 def get_query_set(self):
495 return ModelWithInvalidQuerySet(self.model)
496
497
498class ValidObjectsManager(ModelWithInvalidManager):
jadmanski0afbb632008-06-06 21:10:57 +0000499 """
500 Manager returning only objects with invalid=False.
501 """
502 def get_query_set(self):
503 queryset = super(ValidObjectsManager, self).get_query_set()
504 return queryset.filter(invalid=False)
showard7c785282008-05-29 19:45:12 +0000505
506
Prashanth B489b91d2014-03-15 12:17:16 -0700507class ModelExtensions(rdb_model_extensions.ModelValidators):
jadmanski0afbb632008-06-06 21:10:57 +0000508 """\
Prashanth B489b91d2014-03-15 12:17:16 -0700509 Mixin with convenience functions for models, built on top of
510 the model validators in rdb_model_extensions.
jadmanski0afbb632008-06-06 21:10:57 +0000511 """
512 # TODO: at least some of these functions really belong in a custom
513 # Manager class
showard7c785282008-05-29 19:45:12 +0000514
Jakob Juelich3bb7c802014-09-02 16:31:11 -0700515
516 SERIALIZATION_LINKS_TO_FOLLOW = set()
517 """
518 To be able to send jobs and hosts to shards, it's necessary to find their
519 dependencies.
520 The most generic approach for this would be to traverse all relationships
521 to other objects recursively. This would list all objects that are related
522 in any way.
523 But this approach finds too many objects: If a host should be transferred,
524 all it's relationships would be traversed. This would find an acl group.
525 If then the acl group's relationships are traversed, the relationship
526 would be followed backwards and many other hosts would be found.
527
528 This mapping tells that algorithm which relations to follow explicitly.
529 """
530
Jakob Juelichf865d332014-09-29 10:47:49 -0700531
Fang Deng86248502014-12-18 16:38:00 -0800532 SERIALIZATION_LINKS_TO_KEEP = set()
533 """This set stores foreign keys which we don't want to follow, but
534 still want to include in the serialized dictionary. For
535 example, we follow the relationship `Host.hostattribute_set`,
536 but we do not want to follow `HostAttributes.host_id` back to
537 to Host, which would otherwise lead to a circle. However, we still
538 like to serialize HostAttribute.`host_id`."""
539
Jakob Juelichf865d332014-09-29 10:47:49 -0700540 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
541 """
542 On deserializion, if the object to persist already exists, local fields
543 will only be updated, if their name is in this set.
544 """
545
546
jadmanski0afbb632008-06-06 21:10:57 +0000547 @classmethod
548 def convert_human_readable_values(cls, data, to_human_readable=False):
549 """\
550 Performs conversions on user-supplied field data, to make it
551 easier for users to pass human-readable data.
showard7c785282008-05-29 19:45:12 +0000552
jadmanski0afbb632008-06-06 21:10:57 +0000553 For all fields that have choice sets, convert their values
554 from human-readable strings to enum values, if necessary. This
555 allows users to pass strings instead of the corresponding
556 integer values.
showard7c785282008-05-29 19:45:12 +0000557
jadmanski0afbb632008-06-06 21:10:57 +0000558 For all foreign key fields, call smart_get with the supplied
559 data. This allows the user to pass either an ID value or
560 the name of the object as a string.
showard7c785282008-05-29 19:45:12 +0000561
jadmanski0afbb632008-06-06 21:10:57 +0000562 If to_human_readable=True, perform the inverse - i.e. convert
563 numeric values to human readable values.
showard7c785282008-05-29 19:45:12 +0000564
jadmanski0afbb632008-06-06 21:10:57 +0000565 This method modifies data in-place.
566 """
567 field_dict = cls.get_field_dict()
568 for field_name in data:
showarde732ee72008-09-23 19:15:43 +0000569 if field_name not in field_dict or data[field_name] is None:
jadmanski0afbb632008-06-06 21:10:57 +0000570 continue
571 field_obj = field_dict[field_name]
572 # convert enum values
573 if field_obj.choices:
574 for choice_data in field_obj.choices:
575 # choice_data is (value, name)
576 if to_human_readable:
577 from_val, to_val = choice_data
578 else:
579 to_val, from_val = choice_data
580 if from_val == data[field_name]:
581 data[field_name] = to_val
582 break
583 # convert foreign key values
584 elif field_obj.rel:
showarda4ea5742009-02-17 20:56:23 +0000585 dest_obj = field_obj.rel.to.smart_get(data[field_name],
586 valid_only=False)
showardf8b19042009-05-12 17:22:49 +0000587 if to_human_readable:
Paul Pendlebury5a8c6ad2011-02-01 07:20:17 -0800588 # parameterized_jobs do not have a name_field
589 if (field_name != 'parameterized_job' and
590 dest_obj.name_field is not None):
showardf8b19042009-05-12 17:22:49 +0000591 data[field_name] = getattr(dest_obj,
592 dest_obj.name_field)
jadmanski0afbb632008-06-06 21:10:57 +0000593 else:
showardb0a73032009-03-27 18:35:41 +0000594 data[field_name] = dest_obj
showard7c785282008-05-29 19:45:12 +0000595
596
showard7c785282008-05-29 19:45:12 +0000597
598
Dale Curtis74a314b2011-06-23 14:55:46 -0700599 def _validate_unique(self):
jadmanski0afbb632008-06-06 21:10:57 +0000600 """\
601 Validate that unique fields are unique. Django manipulators do
602 this too, but they're a huge pain to use manually. Trust me.
603 """
604 errors = {}
605 cls = type(self)
606 field_dict = self.get_field_dict()
607 manager = cls.get_valid_manager()
608 for field_name, field_obj in field_dict.iteritems():
609 if not field_obj.unique:
610 continue
showard7c785282008-05-29 19:45:12 +0000611
jadmanski0afbb632008-06-06 21:10:57 +0000612 value = getattr(self, field_name)
showardbd18ab72009-09-18 21:20:27 +0000613 if value is None and field_obj.auto_created:
614 # don't bother checking autoincrement fields about to be
615 # generated
616 continue
617
jadmanski0afbb632008-06-06 21:10:57 +0000618 existing_objs = manager.filter(**{field_name : value})
619 num_existing = existing_objs.count()
showard7c785282008-05-29 19:45:12 +0000620
jadmanski0afbb632008-06-06 21:10:57 +0000621 if num_existing == 0:
622 continue
623 if num_existing == 1 and existing_objs[0].id == self.id:
624 continue
625 errors[field_name] = (
626 'This value must be unique (%s)' % (value))
627 return errors
showard7c785282008-05-29 19:45:12 +0000628
629
showarda5288b42009-07-28 20:06:08 +0000630 def _validate(self):
631 """
632 First coerces all fields on this instance to their proper Python types.
633 Then runs validation on every field. Returns a dictionary of
634 field_name -> error_list.
635
636 Based on validate() from django.db.models.Model in Django 0.96, which
637 was removed in Django 1.0. It should reappear in a later version. See:
638 http://code.djangoproject.com/ticket/6845
639 """
640 error_dict = {}
641 for f in self._meta.fields:
642 try:
643 python_value = f.to_python(
644 getattr(self, f.attname, f.get_default()))
645 except django.core.exceptions.ValidationError, e:
jamesren1e0a4ce2010-04-21 17:45:11 +0000646 error_dict[f.name] = str(e)
showarda5288b42009-07-28 20:06:08 +0000647 continue
648
649 if not f.blank and not python_value:
650 error_dict[f.name] = 'This field is required.'
651 continue
652
653 setattr(self, f.attname, python_value)
654
655 return error_dict
656
657
jadmanski0afbb632008-06-06 21:10:57 +0000658 def do_validate(self):
showarda5288b42009-07-28 20:06:08 +0000659 errors = self._validate()
Dale Curtis74a314b2011-06-23 14:55:46 -0700660 unique_errors = self._validate_unique()
jadmanski0afbb632008-06-06 21:10:57 +0000661 for field_name, error in unique_errors.iteritems():
662 errors.setdefault(field_name, error)
663 if errors:
664 raise ValidationError(errors)
showard7c785282008-05-29 19:45:12 +0000665
666
jadmanski0afbb632008-06-06 21:10:57 +0000667 # actually (externally) useful methods follow
showard7c785282008-05-29 19:45:12 +0000668
jadmanski0afbb632008-06-06 21:10:57 +0000669 @classmethod
670 def add_object(cls, data={}, **kwargs):
671 """\
672 Returns a new object created with the given data (a dictionary
673 mapping field names to values). Merges any extra keyword args
674 into data.
675 """
Prashanth B489b91d2014-03-15 12:17:16 -0700676 data = dict(data)
677 data.update(kwargs)
678 data = cls.prepare_data_args(data)
679 cls.convert_human_readable_values(data)
jadmanski0afbb632008-06-06 21:10:57 +0000680 data = cls.provide_default_values(data)
Prashanth B489b91d2014-03-15 12:17:16 -0700681
jadmanski0afbb632008-06-06 21:10:57 +0000682 obj = cls(**data)
683 obj.do_validate()
684 obj.save()
685 return obj
showard7c785282008-05-29 19:45:12 +0000686
687
jadmanski0afbb632008-06-06 21:10:57 +0000688 def update_object(self, data={}, **kwargs):
689 """\
690 Updates the object with the given data (a dictionary mapping
691 field names to values). Merges any extra keyword args into
692 data.
693 """
Prashanth B489b91d2014-03-15 12:17:16 -0700694 data = dict(data)
695 data.update(kwargs)
696 data = self.prepare_data_args(data)
697 self.convert_human_readable_values(data)
jadmanski0afbb632008-06-06 21:10:57 +0000698 for field_name, value in data.iteritems():
showardb0a73032009-03-27 18:35:41 +0000699 setattr(self, field_name, value)
jadmanski0afbb632008-06-06 21:10:57 +0000700 self.do_validate()
701 self.save()
showard7c785282008-05-29 19:45:12 +0000702
703
showard8bfb5cb2009-10-07 20:49:15 +0000704 # see query_objects()
705 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
706 'extra_args', 'extra_where', 'no_distinct')
707
708
jadmanski0afbb632008-06-06 21:10:57 +0000709 @classmethod
showard8bfb5cb2009-10-07 20:49:15 +0000710 def _extract_special_params(cls, filter_data):
711 """
712 @returns a tuple of dicts (special_params, regular_filters), where
713 special_params contains the parameters we handle specially and
714 regular_filters is the remaining data to be handled by Django.
715 """
716 regular_filters = dict(filter_data)
717 special_params = {}
718 for key in cls._SPECIAL_FILTER_KEYS:
719 if key in regular_filters:
720 special_params[key] = regular_filters.pop(key)
721 return special_params, regular_filters
722
723
@classmethod
def apply_presentation(cls, query, filter_data):
    """
    Apply presentation parameters -- sorting and paging -- to a query.

    @param query: the query to decorate.
    @param filter_data: filter dict; only its 'sort_by', 'query_start' and
            'query_limit' entries are used here.

    @returns new query with presentation applied.

    @raises ValueError: if query_start is given without query_limit.
    """
    special_params = cls._extract_special_params(filter_data)[0]

    sort_by = special_params.get('sort_by')
    if sort_by:
        assert isinstance(sort_by, (list, tuple))
        query = query.extra(order_by=sort_by)

    start = special_params.get('query_start')
    stop = special_params.get('query_limit')
    if start is not None:
        if stop is None:
            raise ValueError('Cannot pass query_start without query_limit')
        # query_limit is a page size; turn it into the slice's end index.
        stop = stop + start
    return query[start:stop]
showard8bfb5cb2009-10-07 20:49:15 +0000745
746
@classmethod
def query_objects(cls, filter_data, valid_only=True, initial_query=None,
                  apply_presentation=True):
    """\
    Returns a QuerySet object for querying the given model_class
    with the given filter_data. Optional special arguments in
    filter_data include:
    -query_start: index of the first result to return
    -query_limit: maximum number of results to return
    -sort_by: list of fields to sort on. prefixing a '-' onto a
     field name changes the sort to descending order.
    -extra_args: keyword args to pass to query.extra() (see Django
     DB layer documentation)
    -extra_where: extra WHERE clause to append
    -no_distinct: if True, a DISTINCT will not be added to the SELECT

    @param filter_data: dict of filters plus the special keys above. It
            is not modified.
    @param valid_only: when no initial_query is given, start from
            get_valid_manager() instead of cls.objects.
    @param initial_query: optional manager/QuerySet to filter instead of
            the default one.
    @param apply_presentation: if True, also apply sorting and paging
            (see apply_presentation()).

    @returns the resulting QuerySet.
    """
    special_params, regular_filters = cls._extract_special_params(
            filter_data)

    if initial_query is None:
        if valid_only:
            initial_query = cls.get_valid_manager()
        else:
            initial_query = cls.objects

    query = initial_query.filter(**regular_filters)

    use_distinct = not special_params.get('no_distinct', False)
    if use_distinct:
        query = query.distinct()

    # Copy extra_args (and below, its 'where' list) so that appending
    # extra_where never mutates the caller's filter_data. Otherwise, calling
    # this twice with the same dict would accumulate duplicate WHERE clauses.
    extra_args = dict(special_params.get('extra_args') or {})
    extra_where = special_params.get('extra_where', None)
    if extra_where:
        # escape %'s
        extra_where = cls.objects.escape_user_sql(extra_where)
        extra_args['where'] = list(extra_args.get('where', [])) + [extra_where]
    if extra_args:
        query = query.extra(**extra_args)
    # TODO: Use readonly connection for these queries.
    # This has been disabled, because it's not used anyway, as the
    # configured readonly user is the same as the real user anyway.

    if apply_presentation:
        query = cls.apply_presentation(query, filter_data)

    return query
showard7c785282008-05-29 19:45:12 +0000794
795
@classmethod
def query_count(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but retrieve only the count of results.

    Paging keys (query_start / query_limit) are ignored, so the count
    covers every matching row.

    @param filter_data: filter dict as accepted by query_objects(). It is
            copied and never modified.
    @param initial_query: optional manager/QuerySet to start from.

    @returns the number of matching objects.
    """
    # Work on a copy: popping from the caller's dict would silently strip
    # paging from any later query_objects()/list_objects() call with it.
    filter_data = dict(filter_data)
    filter_data.pop('query_start', None)
    filter_data.pop('query_limit', None)
    query = cls.query_objects(filter_data, initial_query=initial_query)
    return query.count()
showard7c785282008-05-29 19:45:12 +0000805
806
@classmethod
def clean_object_dicts(cls, field_dicts):
    """\
    Tidy object dicts in place before they are returned to the user.

    Each dict (as produced by query.values() or get_object_dict()) is run,
    in order, through clean_foreign_keys(), _convert_booleans() and
    convert_human_readable_values(..., to_human_readable=True).

    @param field_dicts: iterable of dicts mapping field names to values;
            each one is modified in place.
    """
    for object_dict in field_dicts:
        cls.clean_foreign_keys(object_dict)
        cls._convert_booleans(object_dict)
        cls.convert_human_readable_values(object_dict, to_human_readable=True)
showard7c785282008-05-29 19:45:12 +0000819
820
@classmethod
def list_objects(cls, filter_data, initial_query=None):
    """\
    Like query_objects, but return a list of dictionaries.

    Each dict comes from get_object_dict() and also contains any extra
    SELECT columns the query carries (query.query.extra_select).
    """
    query = cls.query_objects(filter_data, initial_query=initial_query)
    # Columns added via query.extra(select=...) are not model fields, so
    # they have to be requested from get_object_dict() explicitly.
    extra_select_names = query.query.extra_select.keys()
    return [row.get_object_dict(extra_fields=extra_select_names)
            for row in query]
showard7c785282008-05-29 19:45:12 +0000831
832
@classmethod
def smart_get(cls, id_or_name, valid_only=True):
    """\
    Fetch one object by primary key or by name.

    smart_get(integer) -> get object by ID
    smart_get(string) -> get object by name_field (only for models that
                         define name_field)

    @param id_or_name: an int/long primary key or a string name.
    @param valid_only: if True, look through get_valid_manager();
            otherwise through cls.objects.

    @returns the matching object; errors from the manager's get() (e.g.
            DoesNotExist) propagate.

    @raises ValueError: if id_or_name is not an integer, and is not a
            string usable with this model's name_field.
    """
    manager = cls.get_valid_manager() if valid_only else cls.objects

    if isinstance(id_or_name, (int, long)):
        return manager.get(pk=id_or_name)
    if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
        return manager.get(**{cls.name_field: id_or_name})
    raise ValueError(
            'Invalid positional argument: %s (%s)' % (id_or_name,
                                                    type(id_or_name)))
showard7c785282008-05-29 19:45:12 +0000851
852
@classmethod
def smart_get_bulk(cls, id_or_name_list):
    """\
    Look up several objects with smart_get().

    @param id_or_name_list: iterable of IDs and/or names, in any mix.

    @returns list of the matching objects, in input order.

    @raises cls.DoesNotExist: if any input matched nothing. The message
            lists every missing input, not just the first one.
    """
    invalid_inputs = []
    result_objects = []
    for id_or_name in id_or_name_list:
        try:
            result_objects.append(cls.smart_get(id_or_name))
        except cls.DoesNotExist:
            invalid_inputs.append(id_or_name)
    if invalid_inputs:
        # Inputs may be integer IDs; str.join only accepts strings, so
        # format each one first (otherwise a TypeError masks DoesNotExist).
        missing = ', '.join('%s' % (item,) for item in invalid_inputs)
        raise cls.DoesNotExist('The following %ss do not exist: %s'
                               % (cls.__name__.lower(), missing))
    return result_objects
867
868
def get_object_dict(self, extra_fields=None):
    """\
    Return a dictionary mapping fields to this object's values.

    @param extra_fields: optional iterable of extra attribute names to
            include, in addition to the fields defined on this object.

    @returns the dict, after it has been passed through
            clean_object_dicts() and _postprocess_object_dict().
    """
    # Build an explicit list: keys() may be a view rather than a list
    # (Python 3), and views do not support +=.
    fields = list(self.get_field_dict().keys())
    if extra_fields:
        fields.extend(extra_fields)
    object_dict = dict((field_name, getattr(self, field_name))
                       for field_name in fields)
    self.clean_object_dicts([object_dict])
    self._postprocess_object_dict(object_dict)
    return object_dict
showard7c785282008-05-29 19:45:12 +0000883
884
showardd3dc1992009-04-22 21:01:40 +0000885 def _postprocess_object_dict(self, object_dict):
886 """For subclasses to override."""
887 pass
888
889
@classmethod
def get_valid_manager(cls):
    """Return the manager used when only valid objects are wanted.

    The base implementation does no filtering and hands back cls.objects.
    Subclasses such as ModelWithInvalid override it.
    """
    default_manager = cls.objects
    return default_manager
showard7c785282008-05-29 19:45:12 +0000893
894
showard2bab8f42008-11-12 18:15:22 +0000895 def _record_attributes(self, attributes):
896 """
897 See on_attribute_changed.
898 """
899 assert not isinstance(attributes, basestring)
900 self._recorded_attributes = dict((attribute, getattr(self, attribute))
901 for attribute in attributes)
902
903
904 def _check_for_updated_attributes(self):
905 """
906 See on_attribute_changed.
907 """
908 for attribute, original_value in self._recorded_attributes.iteritems():
909 new_value = getattr(self, attribute)
910 if original_value != new_value:
911 self.on_attribute_changed(attribute, original_value)
912 self._record_attributes(self._recorded_attributes.keys())
913
914
def on_attribute_changed(self, attribute, old_value):
    """
    Hook invoked when a recorded attribute's value has changed.

    The base implementation does nothing. To use it, a subclass must:
    * call _record_attributes() from __init__() (after the super call)
      with the attribute names it wants to be notified about;
    * call _check_for_updated_attributes() from save().

    @param attribute: name of the attribute that changed.
    @param old_value: the value it had when last recorded.
    """
926
927
def serialize(self, include_dependencies=True):
    """Serializes the object with dependencies.

    Non-relation local fields are always included, under their field name.
    Relation fields listed in SERIALIZATION_LINKS_TO_KEEP are included under
    their attname. Every relation listed in SERIALIZATION_LINKS_TO_FOLLOW is
    serialized recursively. Both of these happen only when
    include_dependencies is True.

    @param include_dependencies: Whether or not to follow relations to
            objects this object depends on. This is False when uploading
            jobs from a shard to the master, as the master already has all
            the dependent objects.

    @returns: Dictionary representation of the object.
    """
    result = {}
    for field in self._meta.concrete_model._meta.local_fields:
        if field.rel is None:
            result[field.name] = field._get_val_from_obj(self)
            continue
        if include_dependencies and (
                field.name in self.SERIALIZATION_LINKS_TO_KEEP):
            # attname carries the "_id" suffix for foreign keys (e.g.
            # HostAttribute.host is stored as 'host_id'), which makes
            # deserializing it a plain attribute assignment.
            result[field.attname] = field._get_val_from_obj(self)

    if include_dependencies:
        result.update((link, self._serialize_relation(link))
                      for link in self.SERIALIZATION_LINKS_TO_FOLLOW)
    return result
959
960
961 def _serialize_relation(self, link):
962 """Serializes dependent objects given the name of the relation.
963
964 @param link: Name of the relation to take objects from.
965
966 @returns For To-Many relationships a list of the serialized related
967 objects, for To-One relationships the serialized related object.
968 """
969 try:
970 attr = getattr(self, link)
971 except AttributeError:
972 # One-To-One relationships that point to None may raise this
973 return None
974
975 if attr is None:
976 return None
977 if hasattr(attr, 'all'):
978 return [obj.serialize() for obj in attr.all()]
979 return attr.serialize()
980
981
Jakob Juelichf88fa932014-09-03 17:58:04 -0700982 @classmethod
Jakob Juelich116ff0f2014-09-17 18:25:16 -0700983 def _split_local_from_foreign_values(cls, data):
984 """This splits local from foreign values in a serialized object.
985
986 @param data: The serialized object.
987
988 @returns A tuple of two lists, both containing tuples in the form
989 (link_name, link_value). The first list contains all links
990 for local fields, the second one contains those for foreign
991 fields/objects.
992 """
993 links_to_local_values, links_to_related_values = [], []
994 for link, value in data.iteritems():
995 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
996 # It's a foreign key
997 links_to_related_values.append((link, value))
998 else:
Fang Deng86248502014-12-18 16:38:00 -0800999 # It's a local attribute or a foreign key
1000 # we don't want to follow.
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001001 links_to_local_values.append((link, value))
1002 return links_to_local_values, links_to_related_values
1003
1004
Jakob Juelichf865d332014-09-29 10:47:49 -07001005 @classmethod
1006 def _filter_update_allowed_fields(cls, data):
1007 """Filters data and returns only files that updates are allowed on.
1008
1009 This is i.e. needed for syncing aborted bits from the master to shards.
1010
1011 Local links are only allowed to be updated, if they are in
1012 SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1013 Overwriting existing values is allowed in order to be able to sync i.e.
1014 the aborted bit from the master to a shard.
1015
1016 The whitelisting mechanism is in place to prevent overwriting local
1017 status: If all fields were overwritten, jobs would be completely be
1018 set back to their original (unstarted) state.
1019
1020 @param data: List with tuples of the form (link_name, link_value), as
1021 returned by _split_local_from_foreign_values.
1022
1023 @returns List of the same format as data, but only containing data for
1024 fields that updates are allowed on.
1025 """
1026 return [pair for pair in data
1027 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1028
1029
@classmethod
def delete_matching_record(cls, **filter_args):
    """Delete the record matching the filter, if there is one.

    A missing record is not an error. Other errors from objects.get() are
    propagated.

    @param filter_args: Arguments for the django filter
            used to locate the record to delete.
    """
    try:
        doomed = cls.objects.get(**filter_args)
    except cls.DoesNotExist:
        pass  # Nothing matches; nothing to delete.
    else:
        doomed.delete()
1042
1043
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001044 def _deserialize_local(self, data):
1045 """Set local attributes from a list of tuples.
1046
1047 @param data: List of tuples like returned by
1048 _split_local_from_foreign_values.
1049 """
Prashanth Balasubramanianaf516642014-12-12 18:16:32 -08001050 if not data:
1051 return
1052
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001053 for link, value in data:
1054 setattr(self, link, value)
1055 # Overwridden save() methods are prone to errors, so don't execute them.
1056 # This is because:
1057 # - the overwritten methods depend on ACL groups that don't yet exist
1058 # and don't handle errors
1059 # - the overwritten methods think this object already exists in the db
1060 # because the id is already set
1061 super(type(self), self).save()
1062
1063
1064 def _deserialize_relations(self, data):
1065 """Set foreign attributes from a list of tuples.
1066
1067 This deserialized the related objects using their own deserialize()
1068 function and then sets the relation.
1069
1070 @param data: List of tuples like returned by
1071 _split_local_from_foreign_values.
1072 """
1073 for link, value in data:
1074 self._deserialize_relation(link, value)
1075 # See comment in _deserialize_local
1076 super(type(self), self).save()
1077
1078
@classmethod
def get_record(cls, data):
    """Retrieve the record identified by the given serialized data.

    The base implementation looks the record up by data['id']. Child
    models with different uniqueness constraints should override this.

    @param data: A dictionary containing the information to use in a query
            for data.

    @return: An object with matching data.

    @raises DoesNotExist: If a record with the given data doesn't exist.
    """
    record_id = data['id']
    return cls.objects.get(id=record_id)
1092
1093
@classmethod
def deserialize(cls, data):
    """Recursively deserializes and saves an object with its dependencies.

    This takes the result of the serialize method and creates objects
    in the database that are just like the original.

    If an object of the same type with the same id already exists, its
    local values are left untouched, unless they are explicitly
    whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.

    Deserialize always propagates recursively to all related objects
    present in data, though. This is necessary, e.g., to add users to an
    already existing acl-group.

    @param data: Representation of an object and its dependencies, as
            returned by serialize.

    @returns: The object represented by data if it didn't exist before,
            otherwise the object that existed before and has the same type
            and id as the one described by data. None if data is None.
    """
    if data is None:
        return None

    local, related = cls._split_local_from_foreign_values(data)
    try:
        instance = cls.get_record(data)
    except cls.DoesNotExist:
        instance = cls()
    else:
        # Existing record: only whitelisted local fields may be overwritten.
        local = cls._filter_update_allowed_fields(local)

    instance._deserialize_local(local)
    instance._deserialize_relations(related)
    return instance
1130
1131
def sanity_check_update_from_shard(self, shard, updated_serialized,
                                   *args, **kwargs):
    """Check if an update sent from a shard is legitimate.

    Subclasses that accept updates from shards must override this; the
    base implementation always raises NotImplementedError.

    @param shard: the shard the update came from.
    @param updated_serialized: the serialized update to check.

    @raises error.UnallowedRecordsSentToMaster if an update is not
            legitimate (in overriding implementations).
    @raises NotImplementedError: always, in this base implementation.
    """
    # One placeholder, one argument. The previous message had two %s but
    # only one value, so the formatting itself raised TypeError and hid the
    # intended NotImplementedError.
    raise NotImplementedError(
            'sanity_check_update_from_shard must be implemented by subclass %s'
            % (type(self),))
1142
1143
@transaction.commit_on_success
def update_from_serialized(self, serialized):
    """Updates local fields of an existing object from a serialized form.

    Unlike deserialize(), this does overwrite local values, but it does not
    propagate to related objects, so serialized may contain local fields
    only. It is used to update job records on the master after the jobs
    have run on a shard, since the master is not interested in updates to
    users, labels, special tasks, etc. Runs under
    transaction.commit_on_success.

    @param serialized: Representation of an object, as returned by
            serialize, without related objects.

    @raises ValueError: if serialized contains related objects, i.e. not
            only local fields.
    """
    local_values, related_values = self._split_local_from_foreign_values(
            serialized)
    if related_values:
        raise ValueError('Serialized must not contain foreign '
                         'objects: %s' % related_values)
    self._deserialize_local(local_values)
1169
1170
def custom_deserialize_relation(self, link, data):
    """Hook for subclasses: deserialize a to-one relation.

    _deserialize_relation() calls this for relations that do not expose
    all(). The base implementation always raises NotImplementedError.

    @param link: Name of the relation.
    @param data: Serialized representation of the related object.
    """
    owner_type = type(self)
    raise NotImplementedError(
            'custom_deserialize_relation must be implemented by subclass %s '
            'for relation %s' % (owner_type, link))
1176
1177
1178 def _deserialize_relation(self, link, data):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001179 """Deserializes related objects and sets references on this object.
1180
1181 Relations that point to a list of objects are handled automatically.
1182 For many-to-one or one-to-one relations custom_deserialize_relation
1183 must be overridden by the subclass.
1184
1185 Related objects are deserialized using their deserialize() method.
1186 Thereby they and their dependencies are created if they don't exist
1187 and saved to the database.
1188
1189 @param link: Name of the relation.
1190 @param data: Serialized representation of the related object(s).
1191 This means a list of dictionaries for to-many relations,
1192 just a dictionary for to-one relations.
1193 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001194 field = getattr(self, link)
1195
1196 if field and hasattr(field, 'all'):
1197 self._deserialize_2m_relation(link, data, field.model)
1198 else:
1199 self.custom_deserialize_relation(link, data)
1200
1201
1202 def _deserialize_2m_relation(self, link, data, related_class):
Jakob Juelich116ff0f2014-09-17 18:25:16 -07001203 """Deserialize related objects for one to-many relationship.
1204
1205 @param link: Name of the relation.
1206 @param data: Serialized representation of the related objects.
1207 This is a list with of dictionaries.
1208 """
Jakob Juelichf88fa932014-09-03 17:58:04 -07001209 relation_set = getattr(self, link)
1210 for serialized in data:
1211 relation_set.add(related_class.deserialize(serialized))
1212
1213
class ModelWithInvalid(ModelExtensions):
    """
    Model base class that invalidates objects instead of deleting them.

    delete() only sets the boolean "invalid" field. When a new object is
    saved with the same name as an invalidated one, save() revives the old
    row instead of creating a new one. Subclasses must have a boolean
    "invalid" field, a name_field attribute naming their name column, and
    a valid_objects manager.
    """

    def save(self, *args, **kwargs):
        if self.id is None:
            # First save: reuse the row of a previously invalidated object
            # with the same name, if there is one.
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                previous = self.__class__.objects.get(**lookup)
                self.resurrect_object(previous)
            except self.DoesNotExist:
                pass  # no such object; this is a genuinely new row

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Adopt the identity of a previously invalidated object.

        Called when self is about to be saved for the first time but is
        really "undeleting" old_object. Subclasses may override this to
        copy more data from the deleted entry, but must normally call this
        implementation, which takes over old_object's id.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Hook run after the object has been marked invalid.

        Subclasses should override this to remove relationships that should
        no longer exist once the object is deleted. Does nothing by default.
        """


    def delete(self):
        """Mark this object invalid (instead of deleting its row), save it,
        then call clean_object()."""
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        # NOTE(review): this derives from object, which defines no
        # _prepare(), so the super() call below would raise AttributeError
        # if this were ever invoked. It looks like a leftover from old
        # Django manipulators -- confirm it is unused before relying on it.
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
showardf8b19042009-05-12 17:22:49 +00001279
1280
class ModelWithAttributes(object):
    """
    Mixin for models whose key/value attributes live in a companion model.

    The companion (attribute) model is assumed to have its value field
    named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Return (attribute_model, keyword_args) for the given attribute.

        Subclasses must override this. keyword_args are passed to
        attribute_model.objects.get() / get_or_create() to find this
        object's row for that attribute.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create the attribute row if needed and store value in it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        row, _created = model.objects.get_or_create(**lookup)
        row.value = value
        row.save()


    def delete_attribute(self, attribute):
        """Delete the attribute row; a missing row is silently ignored."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            pass  # nothing stored for this attribute


    def set_or_delete_attribute(self, attribute, value):
        """Store value for attribute, or delete the attribute if value is
        None."""
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
showard26b7ec72009-12-21 22:43:57 +00001319
1320
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class.

    Rows must be obtained through get_or_create(), which fills in the_hash
    from the model's _compute_hash(); plain create() is refused.
    """

    def create(self, **kwargs):
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """Like Manager.get_or_create(), additionally keyed on the_hash
        computed from kwargs."""
        computed = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1332
1333
class ModelWithHash(dbmodels.Model):
    """Abstract superclass for immutable models identified by a hash column.

    Subclasses implement _compute_hash(); instances are normally obtained
    through objects.get_or_create() (see ModelWithHashManager).
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the value for the_hash for the given field values.
        Subclasses must override this."""
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases.

        These models are meant to be immutable, so the generic save()
        operation is refused. Instances should be created through
        model.objects.get_or_create() instead.

        The exception is save(force_insert=True), which is allowed because
        it creates a new row. get_or_create() is still the preferred way to
        make instances of these models.
        """
        if force_insert:
            # A forced insert only ever adds a row; if it duplicates an
            # existing hash, the unique constraint on the_hash rejects it.
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')