showard | 7c78528 | 2008-05-29 19:45:12 +0000 | [diff] [blame] | 1 | """ |
| 2 | Extensions to Django's model logic. |
| 3 | """ |
| 4 | |
| 5 | from django.db import models as dbmodels, backend, connection |
| 6 | from django.utils import datastructures |
| 7 | |
| 8 | |
class ValidationError(Exception):
    """\
    Raised when the data supplied for adding or updating an object fails
    validation.  The exception's value is a dictionary that maps each
    offending field name to an error string.
    """
    pass
| 14 | |
| 15 | |
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering and bulk INSERT/DELETE.
    """

    class _RawSqlQ(dbmodels.Q):
        """\
        A Django "Q" object constructed with a raw SQL query.
        """
        def __init__(self, sql, params=None, joins=None):
            """
            sql: the SQL to go into the WHERE clause

            params: substitution params for the WHERE SQL (defaults to
                    no params)

            joins: a dict mapping alias to (table, join_type,
                   condition). This converts to the SQL:
                   "join_type table AS alias ON condition"
                   For example:
                     alias='host_hqe',
                     table='host_queue_entries',
                     join_type='INNER JOIN',
                     condition='host_hqe.host_id=hosts.id'
                   Defaults to no joins.
            """
            self._sql = sql
            # Store copies so that later changes to the caller's
            # containers cannot alter this filter.
            self._params = list(params or [])
            self._joins = datastructures.SortedDict(joins or {})


        def get_sql(self, opts):
            """
            Return the (joins, where clauses, params) triple for this
            filter.  opts is accepted but not used.
            """
            return (self._joins,
                    [self._sql],
                    self._params)


    @staticmethod
    def _get_quoted_field(table, field):
        'Return the backend-quoted "table.field" reference.'
        return (backend.quote_name(table) + '.' +
                backend.quote_name(field))


    @classmethod
    def _get_sql_string_for(cls, value):
        """
        Render value as bare SQL text: lists become comma-separated
        items, tuples become parenthesized lists.  No quoting or
        escaping is performed, so this must never be fed untrusted or
        string data destined for a SQL value position.

        >>> ExtendedManager._get_sql_string_for((1L, 2L))
        '(1,2)'
        >>> ExtendedManager._get_sql_string_for(['abc', 'def'])
        'abc,def'
        """
        if isinstance(value, list):
            return ','.join(cls._get_sql_string_for(item)
                            for item in value)
        if isinstance(value, tuple):
            return '(%s)' % cls._get_sql_string_for(list(value))
        if isinstance(value, long):
            return str(int(value))
        return str(value)


    @staticmethod
    def _get_sql_query_for(query_object, select_field):
        """
        Return (sql, params) for a parenthesized
        "SELECT DISTINCT <select_field> ..." built from query_object.
        """
        query_table = query_object.model._meta.db_table
        quoted_field = ExtendedManager._get_quoted_field(query_table,
                                                         select_field)
        _, where, params = query_object._get_sql_clause()
        # where includes the FROM clause
        return '(SELECT DISTINCT ' + quoted_field + where + ')', params


    def _get_key_on_this_table(self, key_field=None):
        """
        Return the quoted "table.column" for key_field on this
        manager's model table.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table,
                                      key_field)


    def _do_subquery_filter(self, subquery_key, subquery, subquery_alias,
                            this_table_key=None, not_in=False):
        """
        This method constructs SQL queries to accomplish IN/NOT IN
        subquery filtering using explicit joins.  It does this by
        LEFT JOINing onto the subquery and then checking to see if
        the joined column is NULL or not.

        We use explicit joins instead of the SQL IN operator because
        MySQL (at least some versions) considers all IN subqueries to be
        dependent, so using explicit joins can be MUCH faster.

        The query we're going for is:
          SELECT * FROM <this table>
            LEFT JOIN (<subquery>) AS <subquery_alias>
              ON <subquery_alias>.<subquery_key> =
                 <this table>.<this_table_key>
          WHERE <subquery_alias>.<subquery_key> IS [NOT] NULL
        """
        subselect, params = self._get_sql_query_for(subquery,
                                                    subquery_key)

        this_full_key = self._get_key_on_this_table(this_table_key)
        alias_full_key = self._get_quoted_field(subquery_alias,
                                                subquery_key)
        join_condition = alias_full_key + ' = ' + this_full_key
        joins = {subquery_alias : (subselect, # join table
                                   'LEFT JOIN', # join type
                                   join_condition)} # join on

        if not_in:
            where_sql = alias_full_key + ' IS NULL'
        else:
            where_sql = alias_full_key + ' IS NOT NULL'
        filter_obj = self._RawSqlQ(where_sql, params, joins)
        return self.complex_filter(filter_obj)


    def filter_in_subquery(self, subquery_key, subquery, subquery_alias,
                           this_table_key=None):
        """\
        Construct a filter to perform a subquery match, i.e.
        WHERE id IN (SELECT host_id FROM ... WHERE ...)
        -subquery_key - the field to select in the subquery (host_id
                        above)
        -subquery - a query object for the subquery
        -subquery_alias - a logical name for the query, to be used in
                          the SQL (i.e. 'valid_hosts')
        -this_table_key - the field to match (id above).  Defaults to
                          this table's primary key.
        """
        return self._do_subquery_filter(subquery_key, subquery,
                                        subquery_alias, this_table_key)


    def filter_not_in_subquery(self, subquery_key, subquery,
                               subquery_alias, this_table_key=None):
        'Like filter_in_subquery, but use NOT IN rather than IN.'
        return self._do_subquery_filter(subquery_key, subquery,
                                        subquery_alias, this_table_key,
                                        not_in=True)


    def create_in_bulk(self, fields, values):
        """
        Creates many objects with a single SQL query.
        fields - list of field names (model attributes, not actual DB
                 field names) for which values will be specified.
        values - list of tuples containing values.  Each tuple contains
                 the values for the specified fields for a single
                 object, in the same order as fields.
        Example: Host.objects.create_in_bulk(['hostname', 'status'],
                     [('host1', 'Ready'), ('host2', 'Running')])

        Values are handed to the database driver as query parameters,
        so pass them raw - do not quote or escape them yourself.
        Does nothing when values is empty.
        """
        if not values:
            return
        rows = list(values)
        field_dict = self.model.get_field_dict()
        columns = [backend.quote_name(field_dict[field].column)
                   for field in fields]
        # One "(%s,%s,...)" group per row; the driver substitutes and
        # escapes the actual values, so strings can no longer break (or
        # inject into) the statement.
        row_placeholder = '(%s)' % ','.join(['%s'] * len(columns))
        sql = 'INSERT INTO %s (%s) VALUES %s' % (
            backend.quote_name(self.model._meta.db_table),
            ','.join(columns),
            ','.join([row_placeholder] * len(rows)))
        params = [value for row in rows for value in row]
        cursor = connection.cursor()
        cursor.execute(sql, params)
        # raw cursor usage: commit through the connection's own
        # (private) _commit() right after executing
        connection._commit()


    def delete_in_bulk(self, ids):
        """
        Deletes many objects with a single SQL query.  ids should be a
        list of object ids to delete.  Nonexistent ids will be silently
        ignored.  Does nothing when ids is empty.

        The ids are handed to the database driver as query parameters
        rather than being spliced into the SQL text.
        """
        if not ids:
            return
        id_params = list(ids)
        sql = 'DELETE FROM %s WHERE id IN (%s)' % (
            backend.quote_name(self.model._meta.db_table),
            ','.join(['%s'] * len(id_params)))
        cursor = connection.cursor()
        cursor.execute(sql, id_params)
        connection._commit()
| 194 | |
| 195 | |
class ValidObjectsManager(ExtendedManager):
    """
    Manager whose base query set is restricted to rows with
    invalid=False.
    """
    def get_query_set(self):
        # start from the parent manager's query set, then filter on the
        # "invalid" field
        base_query = super(ValidObjectsManager, self).get_query_set()
        valid_only = base_query.filter(invalid=False)
        return valid_only
| 203 | |
| 204 | |
class ModelExtensions(object):
    """\
    Mixin with convenience functions for models, built on top of the
    default Django model functions.
    """
    # TODO: at least some of these functions really belong in a custom
    # Manager class

    # lazily built {field name: field object} cache; see get_field_dict()
    field_dict = None
    # subclasses should override if they want to support smart_get() by name
    name_field = None


    @classmethod
    def get_field_dict(cls):
        """\
        Return a dictionary mapping field names to field objects for
        this class, building and caching it on first use.
        """
        # Look only in this class's own namespace: plain attribute
        # lookup would find a cache built for a parent class and hand
        # the subclass the wrong set of fields.
        if cls.__dict__.get('field_dict') is None:
            cls.field_dict = dict((field.name, field)
                                  for field in cls._meta.fields)
        return cls.field_dict


    @classmethod
    def clean_foreign_keys(cls, data):
        """\
        -Convert foreign key fields in data from <field>_id to just
         <field>.
        -replace foreign key objects with their IDs
        Foreign key fields that are absent from data are skipped.
        This method modifies data in-place.
        """
        for field in cls._meta.fields:
            if not field.rel:
                continue
            if (field.attname != field.name and
                field.attname in data):
                data[field.name] = data.pop(field.attname)
            if field.name not in data:
                # partial dict that doesn't carry this relation
                continue
            value = data[field.name]
            if isinstance(value, dbmodels.Model):
                data[field.name] = value.id


    # TODO(showard) - is there a way to not have to do this?
    @classmethod
    def provide_default_values(cls, data):
        """\
        Return a copy of data in which every field that is missing (or
        None) and has a default gets that default.

        CharField and TextField fields with no default are filled with
        an empty string instead.
        NOTE(review): this used to be documented as applying only to
        fields with "blank=True", but the code does not check blank -
        confirm which behavior is intended.
        """
        new_data = dict(data)
        for name, field_obj in cls.get_field_dict().items():
            if data.get(name) is not None:
                continue
            if field_obj.default is not dbmodels.fields.NOT_PROVIDED:
                new_data[name] = field_obj.default
            elif isinstance(field_obj, (dbmodels.CharField,
                                        dbmodels.TextField)):
                new_data[name] = ''
        return new_data


    @classmethod
    def convert_human_readable_values(cls, data, to_human_readable=False):
        """\
        Performs conversions on user-supplied field data, to make it
        easier for users to pass human-readable data.

        For all fields that have choice sets, convert their values
        from human-readable strings to enum values, if necessary.  This
        allows users to pass strings instead of the corresponding
        integer values.

        For all foreign key fields, call smart_get with the supplied
        data.  This allows the user to pass either an ID value or
        the name of the object as a string.

        If to_human_readable=True, perform the inverse - i.e. convert
        numeric values to human readable values.

        Entries whose value is None are left untouched.
        This method modifies data in-place.
        """
        field_dict = cls.get_field_dict()
        # only existing keys are reassigned below, so iterating over
        # data while updating it is safe
        for field_name in data:
            if data[field_name] is None:
                continue
            field_obj = field_dict[field_name]
            # convert enum values
            if field_obj.choices:
                for choice_data in field_obj.choices:
                    # choice_data is (value, name)
                    if to_human_readable:
                        from_val, to_val = choice_data
                    else:
                        to_val, from_val = choice_data
                    if from_val == data[field_name]:
                        data[field_name] = to_val
                        break
            # convert foreign key values
            elif field_obj.rel:
                dest_obj = field_obj.rel.to.smart_get(
                    data[field_name])
                if (to_human_readable and
                    dest_obj.name_field is not None):
                    data[field_name] = getattr(dest_obj,
                                               dest_obj.name_field)
                else:
                    data[field_name] = dest_obj.id


    @classmethod
    def validate_field_names(cls, data):
        """\
        Checks for extraneous fields in data.  Returns a dictionary
        mapping each unknown field name to an error string (empty if
        every name is a known field).
        """
        field_dict = cls.get_field_dict()
        return dict((field_name, 'No field of this name')
                    for field_name in data
                    if field_name not in field_dict)


    @classmethod
    def prepare_data_args(cls, data, kwargs):
        """\
        Common preparation for add_object and update_object: merge
        kwargs into a copy of data, reject unknown field names with a
        ValidationError, and convert human-readable values.
        """
        data = dict(data) # don't modify the default keyword arg
        data.update(kwargs)
        # must check for extraneous field names here, while we have the
        # data in a dict
        errors = cls.validate_field_names(data)
        if errors:
            raise ValidationError(errors)
        cls.convert_human_readable_values(data)
        return data


    def validate_unique(self):
        """\
        Validate that unique fields are unique.  Django manipulators do
        this too, but they're a huge pain to use manually.  Trust me.

        Returns a dictionary mapping field names to error strings.
        """
        errors = {}
        manager = type(self).get_valid_manager()
        for field_name, field_obj in self.get_field_dict().items():
            if not field_obj.unique:
                continue

            value = getattr(self, field_name)
            existing_objs = manager.filter(**{field_name : value})
            num_existing = existing_objs.count()

            if num_existing == 0:
                continue
            # the only match being this very object is not a conflict
            if num_existing == 1 and existing_objs[0].id == self.id:
                continue
            errors[field_name] = (
                'This value must be unique (%s)' % (value))
        return errors


    def do_validate(self):
        """\
        Run validate() and validate_unique() and raise ValidationError
        if either reports errors.  An error from validate() takes
        precedence over a uniqueness error on the same field.
        """
        errors = self.validate()
        for field_name, error in self.validate_unique().items():
            errors.setdefault(field_name, error)
        if errors:
            raise ValidationError(errors)


    # actually (externally) useful methods follow

    @classmethod
    def add_object(cls, data={}, **kwargs):
        """\
        Returns a new object created with the given data (a dictionary
        mapping field names to values).  Merges any extra keyword args
        into data.  Raises ValidationError on invalid data.
        """
        # the shared default dict is never modified: prepare_data_args
        # works on a copy
        data = cls.prepare_data_args(data, kwargs)
        data = cls.provide_default_values(data)
        obj = cls(**data)
        obj.do_validate()
        obj.save()
        return obj


    def update_object(self, data={}, **kwargs):
        """\
        Updates the object with the given data (a dictionary mapping
        field names to values).  Merges any extra keyword args into
        data.  None values are ignored.  Raises ValidationError on
        invalid data.
        """
        data = self.prepare_data_args(data, kwargs)
        for field_name, value in data.items():
            if value is not None:
                setattr(self, field_name, value)
        self.do_validate()
        self.save()


    @classmethod
    def query_objects(cls, filter_data, valid_only=True):
        """\
        Returns a QuerySet object for querying the given model_class
        with the given filter_data.  Optional special arguments in
        filter_data include:
        -query_start: index of first return to return
        -query_limit: maximum number of results to return
        -sort_by: list of fields to sort on.  prefixing a '-' onto a
         field name changes the sort to descending order.
        -extra_args: keyword args to pass to query.extra() (see Django
         DB layer documentation)
        -extra_where: extra WHERE clause to append

        All remaining entries are passed to filter() as keyword
        arguments.  filter_data itself is left unmodified.
        Raises ValueError if query_start is given without query_limit.
        """
        # Work on copies so the special keys popped below (and the
        # 'where' list we may extend) don't change the caller's objects.
        filter_data = dict(filter_data)
        query_start = filter_data.pop('query_start', None)
        query_limit = filter_data.pop('query_limit', None)
        if query_start and not query_limit:
            raise ValueError('Cannot pass query_start without '
                             'query_limit')
        sort_by = filter_data.pop('sort_by', [])
        extra_args = dict(filter_data.pop('extra_args', {}))
        extra_where = filter_data.pop('extra_where', None)
        if extra_where:
            extra_args['where'] = (list(extra_args.get('where', [])) +
                                   [extra_where])

        # filters
        if valid_only:
            manager = cls.get_valid_manager()
        else:
            manager = cls.objects
        query = manager.filter(**filter_data).distinct()

        # other arguments
        if extra_args:
            query = query.extra(**extra_args)

        # sorting + paging
        assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
        query = query.order_by(*sort_by)
        if query_start is not None and query_limit is not None:
            # slices take an end index, not a count
            query_limit += query_start
        return query[query_start:query_limit]


    @classmethod
    def query_count(cls, filter_data):
        """\
        Like query_objects, but retrieve only the count of results.
        Any query_start/query_limit in filter_data is ignored, and
        filter_data itself is left unmodified.
        """
        filter_data = dict(filter_data)
        filter_data.pop('query_start', None)
        filter_data.pop('query_limit', None)
        return cls.query_objects(filter_data).count()


    @classmethod
    def clean_object_dicts(cls, field_dicts):
        """\
        Take a list of dicts corresponding to object (as returned by
        query.values()) and clean the data to be more suitable for
        returning to the user.  The dicts are modified in-place.
        """
        for field_dict in field_dicts:
            cls.clean_foreign_keys(field_dict)
            cls.convert_human_readable_values(
                field_dict, to_human_readable=True)


    @classmethod
    def list_objects(cls, filter_data):
        """\
        Like query_objects, but return a list of dictionaries.
        """
        query = cls.query_objects(filter_data)
        field_dicts = list(query.values())
        cls.clean_object_dicts(field_dicts)
        return field_dicts


    @classmethod
    def smart_get(cls, *args, **kwargs):
        """\
        smart_get(integer) -> get object by ID
        smart_get(string) -> get object by name_field
        smart_get(keyword args) -> normal ModelClass.objects.get()

        Exactly one positional argument or a set of keyword arguments
        must be given, not both.  Raises ValueError for a positional
        argument of any other type, or for a string when this class
        does not define name_field.
        """
        assert bool(args) ^ bool(kwargs)
        if kwargs:
            return cls.objects.get(**kwargs)

        assert len(args) == 1
        arg = args[0]
        if isinstance(arg, (int, long)):
            return cls.objects.get(id=arg)
        if isinstance(arg, basestring):
            if cls.name_field is None:
                # without this check the lookup below fails with an
                # unhelpful "keywords must be strings" TypeError
                raise ValueError(
                    '%s does not support lookup by name (%s)' % (
                    cls.__name__, arg))
            return cls.objects.get(
                **{cls.name_field : arg})
        raise ValueError(
            'Invalid positional argument: %s (%s)' % (
            str(arg), type(arg)))


    def get_object_dict(self):
        """\
        Return a dictionary mapping fields to this object's values,
        cleaned up in the same way as clean_object_dicts() does.
        """
        object_dict = dict((field_name, getattr(self, field_name))
                           for field_name in self.get_field_dict())
        self.clean_object_dicts([object_dict])
        return object_dict


    @classmethod
    def get_valid_manager(cls):
        'Return the manager used for validity-aware queries.'
        return cls.objects
| 528 | |
| 529 | |
class ModelWithInvalid(ModelExtensions):
    """
    Replaces actual deletion with invalidation: delete() marks the
    object invalid instead of removing it, and save() reuses the row of
    a previously invalidated object that has the same name.  Subclasses
    must have a boolean "invalid" field.
    """

    def save(self):
        # look for an object of the same name that was added earlier and
        # has since been invalidated
        current_name = getattr(self, self.name_field)
        lookup = {self.name_field : current_name}
        lookup['invalid'] = True
        try:
            previous = type(self).objects.get(**lookup)
        except self.DoesNotExist:
            # no existing object - save under our own id
            pass
        else:
            # take over the invalidated object's id before saving
            self.id = previous.id
        super(ModelWithInvalid, self).save()


    def clean_object(self):
        """
        Hook called by delete() after the object has been marked
        invalid.  Subclasses should override this to clean up
        relationships that should no longer exist if the object were
        deleted.  The default does nothing.
        """
        pass


    def delete(self):
        # mark invalid and save rather than removing the row, then let
        # the subclass hook tidy up
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        # use the class's valid_objects manager instead of objects
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            # run the inherited preparation, then point the manipulator
            # at the model's valid_objects manager
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects