blob: f1197103de3f9e666c085c24054fb8c0c98df672 [file] [log] [blame]
showard7c785282008-05-29 19:45:12 +00001"""
2Extensions to Django's model logic.
3"""
4
5from django.db import models as dbmodels, backend, connection
6from django.utils import datastructures
7
8
class ValidationError(Exception):
    """\
    Raised when the data supplied for creating or updating an object
    fails validation. The exception's single argument is a dict whose
    keys are field names and whose values are error strings.
    """
    pass
14
15
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering and bulk
    insert/delete operations.
    """

    class _RawSqlQ(dbmodels.Q):
        """\
        A Django "Q" object constructed with a raw SQL query.
        """
        def __init__(self, sql, params=None, joins=None):
            """
            sql: the SQL to go into the WHERE clause

            params: substitution params for the WHERE SQL (defaults to
            no params)

            joins: a dict mapping alias to (table, join_type,
            condition). This converts to the SQL:
            "join_type table AS alias ON condition"
            For example:
            alias='host_hqe',
            table='host_queue_entries',
            join_type='INNER JOIN',
            condition='host_hqe.host_id=hosts.id'
            (defaults to no joins)
            """
            # None instead of mutable defaults; callers passing nothing
            # still get empty params/joins.
            if params is None:
                params = []
            if joins is None:
                joins = {}
            self._sql = sql
            # copy the caller's containers so later mutation on either
            # side cannot leak into the other
            self._params = params[:]
            self._joins = datastructures.SortedDict(joins)


        def get_sql(self, opts):
            """\
            Return the joins, a one-element list holding the WHERE SQL,
            and the substitution params.
            """
            return (self._joins,
                    [self._sql],
                    self._params)


    @staticmethod
    def _get_quoted_field(table, field):
        """Return "table.field" with both parts quoted by the backend."""
        return (backend.quote_name(table) + '.' +
                backend.quote_name(field))


    @classmethod
    def _get_sql_string_for(cls, value):
        """
        Render a (possibly nested) list/tuple of values as SQL text.
        Lists become comma-separated, tuples are parenthesized, longs lose
        their 'L' suffix. Other values are rendered with str() and are
        NOT quoted or escaped, so this must not be used for string or
        untrusted values.

        >>> ExtendedManager._get_sql_string_for((1L, 2L))
        '(1,2)'
        >>> ExtendedManager._get_sql_string_for(['abc', 'def'])
        'abc,def'
        """
        if isinstance(value, list):
            return ','.join(cls._get_sql_string_for(item)
                            for item in value)
        if isinstance(value, tuple):
            return '(%s)' % cls._get_sql_string_for(list(value))
        if isinstance(value, long):
            return str(int(value))
        return str(value)


    @staticmethod
    def _get_sql_query_for(query_object, select_field):
        """\
        Return (sql, params) for
        "(SELECT DISTINCT <table>.<select_field> <FROM/WHERE of query_object>)".
        """
        query_table = query_object.model._meta.db_table
        quoted_field = ExtendedManager._get_quoted_field(query_table,
                                                         select_field)
        _, where, params = query_object._get_sql_clause()
        # where includes the FROM clause
        return '(SELECT DISTINCT ' + quoted_field + where + ')', params


    def _get_key_on_this_table(self, key_field=None):
        """\
        Return the quoted "<this table>.<key_field>" column reference;
        key_field defaults to this model's primary key column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table,
                                      key_field)


    def _do_subquery_filter(self, subquery_key, subquery, subquery_alias,
                            this_table_key=None, not_in=False):
        """
        This method constructs SQL queries to accomplish IN/NOT IN
        subquery filtering using explicit joins. It does this by
        LEFT JOINing onto the subquery and then checking to see if
        the joined column is NULL or not.

        We use explicit joins instead of the SQL IN operator because
        MySQL (at least some versions) considers all IN subqueries to be
        dependent, so using explicit joins can be MUCH faster.

        The query we're going for is:
        SELECT * FROM <this table>
          LEFT JOIN (<subquery>) AS <subquery_alias>
            ON <subquery_alias>.<subquery_key> =
               <this table>.<this_table_key>
        WHERE <subquery_alias>.<subquery_key> IS [NOT] NULL
        """
        subselect, params = self._get_sql_query_for(subquery,
                                                    subquery_key)

        this_full_key = self._get_key_on_this_table(this_table_key)
        alias_full_key = self._get_quoted_field(subquery_alias,
                                                subquery_key)
        join_condition = alias_full_key + ' = ' + this_full_key
        joins = {subquery_alias : (subselect, # join table
                                   'LEFT JOIN', # join type
                                   join_condition)} # join on

        if not_in:
            where_sql = alias_full_key + ' IS NULL'
        else:
            where_sql = alias_full_key + ' IS NOT NULL'
        # NOTE(review): the subquery's params substitute into the JOIN
        # text but are handed over as WHERE params; verify they end up in
        # the right position when combined with other parameterized
        # filters on the same query.
        filter_obj = self._RawSqlQ(where_sql, params, joins)
        return self.complex_filter(filter_obj)


    def filter_in_subquery(self, subquery_key, subquery, subquery_alias,
                           this_table_key=None):
        """\
        Construct a filter to perform a subquery match, i.e.
        WHERE id IN (SELECT host_id FROM ... WHERE ...)
        -subquery_key - the field to select in the subquery (host_id
                        above)
        -subquery - a query object for the subquery
        -subquery_alias - a logical name for the query, to be used in
                          the SQL (i.e. 'valid_hosts')
        -this_table_key - the field to match (id above). Defaults to
                          this table's primary key.
        """
        return self._do_subquery_filter(subquery_key, subquery,
                                        subquery_alias, this_table_key)


    def filter_not_in_subquery(self, subquery_key, subquery,
                               subquery_alias, this_table_key=None):
        'Like filter_in_subquery, but use NOT IN rather than IN.'
        return self._do_subquery_filter(subquery_key, subquery,
                                        subquery_alias, this_table_key,
                                        not_in=True)


    def create_in_bulk(self, fields, values):
        """
        Creates many objects with a single SQL query.
        fields - list of field names (model attributes, not actual DB
                 field names) for which values will be specified.
        values - list of tuples containing values. Each tuple contains
                 the values for the specified fields for a single
                 object.
        Example: Host.objects.create_in_bulk(['hostname', 'status'],
                     [('host1', 'Ready'), ('host2', 'Running')])

        Values are sent as query parameters, so the DB driver quotes and
        escapes them (string values work and cannot inject SQL).

        Raises ValueError if a tuple does not have one value per field.
        """
        values = list(values)
        if not values:
            return
        field_dict = self.model.get_field_dict()
        column_names = [backend.quote_name(field_dict[field].column)
                        for field in fields]

        params = []
        for row in values:
            if len(row) != len(fields):
                raise ValueError('Expected %d values per row, got %d: %r'
                                 % (len(fields), len(row), row))
            params.extend(row)

        # one "(%s,%s,...)" group per row; the driver fills them in
        row_placeholder = '(%s)' % ','.join(['%s'] * len(fields))
        sql = 'INSERT INTO %s (%s) VALUES %s' % (
            backend.quote_name(self.model._meta.db_table),
            ','.join(column_names),
            ','.join([row_placeholder] * len(values)))
        cursor = connection.cursor()
        cursor.execute(sql, tuple(params))
        connection._commit()


    def delete_in_bulk(self, ids):
        """
        Deletes many objects with a single SQL query. ids should be a
        list of object ids (primary key values) to delete. Nonexistent
        ids will be silently ignored.
        """
        ids = list(ids)
        if not ids:
            return
        placeholders = ','.join(['%s'] * len(ids))
        sql = 'DELETE FROM %s WHERE %s IN (%s)' % (
            backend.quote_name(self.model._meta.db_table),
            self._get_key_on_this_table(),
            placeholders)
        cursor = connection.cursor()
        cursor.execute(sql, tuple(ids))
        connection._commit()
194
195
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose default query set leaves out every object whose
    "invalid" field is True.
    """
    def get_query_set(self):
        base_queryset = super(ValidObjectsManager, self).get_query_set()
        valid_queryset = base_queryset.filter(invalid=False)
        return valid_queryset
203
204
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # per-class cache of {field name: field object}, filled lazily by
    # get_field_dict()
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """\
        Return a dict mapping each field name of this model to its field
        object. The dict is built once and cached on the class.
        """
        # Check cls.__dict__ rather than cls.field_dict: plain attribute
        # lookup would find a dict cached on a parent class and hand a
        # subclass its parent's fields.
        if cls.__dict__.get('field_dict') is None:
            cls.field_dict = dict((field.name, field)
                                  for field in cls._meta.fields)
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        -Convert foreign key fields in data from <field>_id to just
        <field>.
        -replace foreign key objects with their IDs
        Foreign key fields missing from data are left alone.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data.pop(field.attname)
            if field.name not in data:
                # nothing was supplied for this foreign key
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value.id


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Provide default values for fields with default values which have
        nothing passed in.

        For CharField and TextField fields with "blank=True", if nothing
        is passed, we fill in an empty string value, even if there's no
        default set.

        Callable defaults are called to produce the value. Returns a new
        dict; data itself is not modified.
        """
        new_data = dict(data)
        field_dict = cls.get_field_dict()
        for name, obj in field_dict.iteritems():
            if data.get(name) is not None:
                continue
            if obj.default is not dbmodels.fields.NOT_PROVIDED:
                default = obj.default
                # a callable default (e.g. a timestamp function) must be
                # called, not stored as the field value
                if callable(default):
                    default = default()
                new_data[name] = default
            elif isinstance(obj, (dbmodels.CharField, dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary. This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data. This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        for field_name in data:
            if data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(data[field_name])
                if (to_human_readable and
                    dest_obj.name_field is not None):
                    data[field_name] = getattr(dest_obj,
                                               dest_obj.name_field)
                else:
                    data[field_name] = dest_obj.id


    @classmethod
    def validate_field_names(cls, data):
        'Checks for extraneous fields in data.'
        errors = {}
        field_dict = cls.get_field_dict()
        for field_name in data:
            if field_name not in field_dict:
                errors[field_name] = 'No field of this name'
        return errors


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """\
        Common preparation for add_object and update_object: merge
        kwargs into a copy of data (None is treated as empty), reject
        unknown field names with ValidationError, and convert
        human-readable values.
        """
        data = dict(data or {}) # never modify the caller's dict
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique. Django manipulators do
        this too, but they're a huge pain to use manually. Trust me.

        Only objects returned by get_valid_manager() are considered.
        Returns a dict mapping field names to error strings.
        """
        errors = {}
        cls = type(self)
        field_dict = self.get_field_dict()
        manager = cls.get_valid_manager()
        for field_name, field_obj in field_dict.iteritems():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value))
        return errors


    def do_validate(self):
        """\
        Run self.validate() and validate_unique() and raise
        ValidationError with the combined errors, if any. For a field
        reported by both, the validate() message is kept.
        """
        errors = self.validate()
        unique_errors = self.validate_unique()
        for field_name, error in unique_errors.iteritems():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data=None, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values). Merges any extra keyword args
        into data.
        """
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data=None, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values). Merges any extra keyword args into
        data. Fields whose value is None are left unchanged.
        """
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.iteritems():
            if value is not None:
                setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data. Optional special arguments in
        filter_data include:
        -query_start: index of first return to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on. prefixing a '-' onto a
                  field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
                     DB layer documentation)
        -extra_where: extra WHERE clause to append
        All other keys are passed to filter() as field lookups.
        filter_data (and any extra_args in it) is not modified.
        """
        filter_data = dict(filter_data) # don't modify the caller's dict
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        # copy extra_args (and its 'where' list) so appending
        # extra_where doesn't alter the caller's objects
        extra_args = dict(filter_data.pop('extra_args', None) or {})
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            extra_args['where'] = (list(extra_args.get('where', [])) +
                                   [extra_where])

        # filters: everything left in filter_data is a field lookup
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects
        query = manager.filter(**filter_data).distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)

        # sorting + paging
        assert isinstance(sort_by, (list, tuple))
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # convert the limit into an end index for slicing
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data):
        """\
        Like query_objects, but retrieve only the count of results.
        Paging arguments are ignored; filter_data is not modified.
        """
        filter_data = dict(filter_data)
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        return cls.query_objects(filter_data).count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user. The dicts are modified in-place.
        """
        for object_dict in field_dicts:
            cls.clean_foreign_keys(object_dict)
            cls.convert_human_readable_values(
                object_dict, to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data):
        """\
        Like query_objects, but return a list of dictionaries.
        """
        query = cls.query_objects(filter_data)
        field_dicts = list(query.values())
        cls.clean_object_dicts(field_dicts)
        return field_dicts


    @classmethod
    def smart_get(cls, *args, **kwargs):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        smart_get(keyword args) -> normal ModelClass.objects.get()
        Exactly one of the positional or keyword forms must be used.
        """
        assert bool(args) ^ bool(kwargs)
        if args:
            assert len(args) == 1
            arg = args[0]
            if isinstance(arg, (int, long)):
                return cls.objects.get(id=arg)
            if isinstance(arg, basestring):
                return cls.objects.get(
                    **{cls.name_field : arg})
            raise ValueError(
                'Invalid positional argument: %s (%s)' % (
                    str(arg), type(arg)))
        return cls.objects.get(**kwargs)


    def get_object_dict(self):
        """\
        Return a dictionary mapping fields to this object's values,
        cleaned the same way as clean_object_dicts().
        """
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name
                           in self.get_field_dict().iterkeys())
        self.clean_object_dicts([object_dict])
        return object_dict


    @classmethod
    def get_valid_manager(cls):
        """Return the manager used for validity-aware lookups."""
        return cls.objects
528
529
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion. Subclasses must have a boolean "invalid"
    field.
    """

    def save(self):
        """\
        Save this object. If it has not been saved yet (no id) and an
        invalidated object with the same name_field value exists, reuse
        that object's row instead of inserting a new one.
        """
        # Only a brand-new object may adopt an invalidated row's id. For
        # an already-saved object, overwriting self.id would redirect the
        # write to a different row and leave this object's row untouched.
        if self.id is None:
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
            except self.DoesNotExist:
                pass # no previously invalidated object; insert a new row
            else:
                self.id = old_object.id
        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted."""
        pass


    def delete(self):
        """\
        Mark this object invalid (instead of deleting its row), save it,
        then call clean_object().
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that only sees non-invalidated objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            # super() here resolves through whatever classes this mixin is
            # combined with at runtime (object itself has no _prepare)
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects