New code for performing explicit joins with custom join conditions.
* added ExtendedManager.join_custom_field(), which uses the introspection logic from populate_relationships() (now factored out into determine_relationship()) to infer the type of relationship between two models and construct the correct join. join_custom_field() offers a much simpler, more Django-like interface for this kind of join -- compare it with add_join(), defined just above it.
* changed TKO custom fields code to use join_custom_field()
* added some cases to AFE rpc_interface_unittest to ensure populate_relationships() usage didn't break
* simplified _CustomQuery and got rid of _CustomSqlQ. _CustomQuery can do the work itself, and it's cleaner this way.
* added add_where(), an alternative to extra(where=...) that fits better into Django's normal representation of WHERE clauses, so the resulting querysets can still be combined later with the & and | operators
Signed-off-by: Steve Howard <showard@google.com>
git-svn-id: http://test.kernel.org/svn/autotest/trunk@4155 592f7852-d20e-0410-864c-8624ca9c26a4
diff --git a/frontend/afe/model_logic.py b/frontend/afe/model_logic.py
index 3b5c70e..7fbdb76 100644
--- a/frontend/afe/model_logic.py
+++ b/frontend/afe/model_logic.py
@@ -6,6 +6,7 @@
import django.core.exceptions
from django.db import models as dbmodels, backend, connection
from django.db.models.sql import query
+import django.db.models.sql.where
from django.utils import datastructures
from autotest_lib.frontend.afe import readonly_connection
@@ -94,77 +95,97 @@
"""
class _CustomQuery(query.Query):
+ def __init__(self, *args, **kwargs):
+ super(ExtendedManager._CustomQuery, self).__init__(*args, **kwargs)
+ self._custom_joins = []
+
+ # clone() forwards **kwargs to Query.clone() instead of dropping them
def clone(self, klass=None, **kwargs):
- obj = super(ExtendedManager._CustomQuery, self).clone(
- klass, _customSqlQ=self._customSqlQ)
-
- customQ = kwargs.get('_customSqlQ', None)
- if customQ is not None:
- obj._customSqlQ._joins.update(customQ._joins)
- obj._customSqlQ._where.extend(customQ._where)
- obj._customSqlQ._params.extend(customQ._params)
-
+ obj = super(ExtendedManager._CustomQuery, self).clone(klass, **kwargs)
+ obj._custom_joins = list(self._custom_joins) # copy; never shared
return obj
+
+ def combine(self, rhs, connector):
+ super(ExtendedManager._CustomQuery, self).combine(rhs, connector)
+ # skip joins we already have: joining the same alias twice is invalid SQL
+ rhs_joins = getattr(rhs, '_custom_joins', [])
+ self._custom_joins.extend([join for join in rhs_joins if join not in self._custom_joins])
+
+ def add_custom_join(self, table, condition, join_type,
+ condition_values=(), alias=None):
+ if alias is None:
+ alias = table
+ join_dict = dict(table=table,
+ condition=condition,
+ condition_values=condition_values,
+ join_type=join_type,
+ alias=alias)
+ self._custom_joins.append(join_dict)
+
+
def get_from_clause(self):
- from_, params = super(
- ExtendedManager._CustomQuery, self).get_from_clause()
+ from_, params = (super(ExtendedManager._CustomQuery, self)
+ .get_from_clause())
- join_clause = ''
- for join_alias, join in self._customSqlQ._joins.iteritems():
- join_table, join_type, condition = join
- join_clause += ' %s %s AS %s ON (%s)' % (
- join_type, _quote_name(join_table),
- _quote_name(join_alias), condition)
-
- if join_clause:
- from_.append(join_clause)
+ for join_dict in self._custom_joins:
+ from_.append('%s %s AS %s ON (%s)'
+ % (join_dict['join_type'],
+ _quote_name(join_dict['table']),
+ _quote_name(join_dict['alias']),
+ join_dict['condition']))
+ params.extend(join_dict['condition_values'])
return from_, params
- class _CustomSqlQ(dbmodels.Q):
- def __init__(self):
- self._joins = datastructures.SortedDict()
- self._where, self._params = [], []
+ @classmethod
+ def convert_query(cls, query_set):
+ """Return a copy of query_set whose "query" attribute is a _CustomQuery.
+ query_set itself is not modified; custom joins it already has are kept.
+ """
+ # Make a copy of the query set
+ query_set = query_set.all()
+ query_set.query = query_set.query.clone(
+ klass=cls,
+ _custom_joins=[])
+ return query_set
- def add_join(self, table, condition, join_type, alias=None):
- if alias is None:
- alias = table
- self._joins[alias] = (table, join_type, condition)
+ class _WhereClause(object):
+ """Object allowing us to inject arbitrary SQL into Django queries.
-
- def add_where(self, where, params=[]):
- self._where.append(where)
- self._params.extend(params)
-
-
- def add_to_query(self, query, aliases):
- if self._where:
- where = ' AND '.join(self._where)
- query.add_extra(None, None, (where,), self._params, None, None)
-
-
- def _add_customSqlQ(self, query_set, filter_object):
- """\
- Add a _CustomSqlQ to the query set.
+ By using this instead of extra(where=...), we can still freely combine
+ queries with & and |.
"""
- # Make a copy of the query set
- query_set = query_set.all()
+ def __init__(self, clause, values=()):
+ self._clause = clause
+ self._values = values
- query_set.query = query_set.query.clone(
- ExtendedManager._CustomQuery, _customSqlQ=filter_object)
- return query_set.filter(filter_object)
+
+ def as_sql(self, qn=None):
+ return self._clause, self._values
+
+
+ def relabel_aliases(self, change_map):
+ return # the clause is opaque SQL; aliases inside it are not rewritten
def add_join(self, query_set, join_table, join_key, join_condition='',
- alias=None, suffix='', exclude=False, force_left_join=False):
- """
- Add a join to query_set.
+ join_condition_values=(), join_from_key=None, alias=None,
+ suffix='', exclude=False, force_left_join=False):
+ """Add a join to query_set.
+
+ Join looks like this:
+ (INNER|LEFT) JOIN <join_table> AS <alias>
+ ON (<this table>.<join_from_key> = <join_table>.<join_key>
+ and <join_condition>)
+
@param join_table table to join to
@param join_key field referencing back to this model to use for the join
@param join_condition extra condition for the ON clause of the join
+ @param join_condition_values values to substitute into join_condition
+ @param join_from_key column on this model to join from (default: pk column).
@param alias alias to use for for join
@param suffix suffix to add to join_table for the join alias, if no
alias is provided
@@ -173,15 +194,15 @@
@param force_left_join - if true, a LEFT OUTER JOIN will be used
instead of an INNER JOIN regardless of other options
"""
- join_from_table = _quote_name(self.model._meta.db_table)
- join_from_key = _quote_name(self.model._meta.pk.name)
- if alias:
- join_alias = alias
- else:
- join_alias = join_table + suffix
- full_join_key = _quote_name(join_alias) + '.' + _quote_name(join_key)
- full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table,
- join_from_key)
+ join_from_table = query_set.model._meta.db_table
+ if join_from_key is None:
+ join_from_key = self.model._meta.pk.column # SQL column, not attr name
+ if alias is None:
+ alias = join_table + suffix
+ full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
+ full_join_condition = '%s = %s.%s' % (full_join_key,
+ _quote_name(join_from_table),
+ _quote_name(join_from_key))
if join_condition:
full_join_condition += ' AND (' + join_condition + ')'
if exclude or force_left_join:
@@ -189,15 +210,128 @@
else:
join_type = query_set.query.INNER
- filter_object = self._CustomSqlQ()
- filter_object.add_join(join_table,
- full_join_condition,
- join_type,
- alias=join_alias)
- if exclude:
- filter_object.add_where(full_join_key + ' IS NULL')
+ query_set = self._CustomQuery.convert_query(query_set)
+ query_set.query.add_custom_join(join_table,
+ full_join_condition,
+ join_type,
+ condition_values=join_condition_values,
+ alias=alias)
- query_set = self._add_customSqlQ(query_set, filter_object)
+ if exclude:
+ query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
+
+ return query_set
+
+
+ def _info_for_many_to_one_join(self, field, join_to_query, alias):
+ """
+ @param field: the ForeignKey field on the related model
+ @param join_to_query: the query over the related model that we're
+ joining to (left unmodified; its WHERE is copied, then relabeled)
+ @param alias: alias of joined table
+ """
+ info = {}
+ rhs_table = join_to_query.model._meta.db_table
+ info['rhs_table'] = rhs_table
+ info['rhs_column'] = field.column
+ info['lhs_column'] = field.rel.get_related_field().column
+ rhs_where = join_to_query.query.clone().where # copy; don't mutate caller
+ rhs_where.relabel_aliases({rhs_table: alias})
+ initial_clause, values = rhs_where.as_sql()
+ all_clauses = [clause for clause in (initial_clause,) +
+ tuple(join_to_query.query.extra_where) if clause] # skip empty
+ info['where_clause'] = ' AND '.join('(%s)' % clause
+ for clause in all_clauses)
+ info['values'] = list(values) + list(join_to_query.query.extra_params)
+ return info
+
+
+ def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
+ m2m_is_on_this_model):
+ """
+ @param m2m_field: a Django field representing the M2M relationship.
+ It uses a pivot table with the following structure:
+ this model table <---> M2M pivot table <---> joined model table
+ @param join_to_query: query over the related model we're joining to.
+ @param alias: alias of joined (pivot) table
+ @param m2m_is_on_this_model: True if m2m_field is declared on this model
+ """
+ if m2m_is_on_this_model:
+ # referenced field on this model
+ lhs_id_field = self.model._meta.pk
+ # foreign key on the pivot table referencing lhs_id_field
+ m2m_lhs_column = m2m_field.m2m_column_name()
+ # foreign key on the pivot table referencing rhs_id_field
+ m2m_rhs_column = m2m_field.m2m_reverse_name()
+ # referenced field on related model
+ rhs_id_field = m2m_field.rel.get_related_field()
+ else:
+ lhs_id_field = m2m_field.rel.get_related_field()
+ m2m_lhs_column = m2m_field.m2m_reverse_name()
+ m2m_rhs_column = m2m_field.m2m_column_name()
+ rhs_id_field = join_to_query.model._meta.pk
+
+ info = {}
+ info['rhs_table'] = m2m_field.m2m_db_table()
+ info['rhs_column'] = m2m_lhs_column
+ info['lhs_column'] = lhs_id_field.column
+
+ # select the ID of related models relevant to this join. we can only do
+ # a single join, so we need to gather this information up front and
+ # include it in the join condition.
+ rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
+ assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
+ 'match a single related object.')
+ rhs_id = rhs_ids[0]
+
+ # pass rhs_id as a bound parameter instead of formatting it into the SQL
+ info['where_clause'] = '%s.%s = %%s' % (_quote_name(alias),
+ _quote_name(m2m_rhs_column))
+ info['values'] = (rhs_id,)
+ return info
+
+
+ def join_custom_field(self, query_set, join_to_query, alias,
+ left_join=True):
+ """Join to a related model to create a custom field in the given query.
+
+ This method is used to construct a custom field on the given query based
+ on a many-valued relationship. join_to_query should be a simple query
+ (no joins) on the related model which returns at most one related row
+ per instance of this model.
+
+ For many-to-one relationships, the joined table contains the matching
+ row from the related model if one is related, NULL otherwise.
+
+ For many-to-many relationships, join_to_query must match exactly one
+ related object; the joined pivot-table row is NULL if it isn't related.
+ """
+ relationship_type, field = self.determine_relationship(
+ join_to_query.model)
+
+ if relationship_type == self.MANY_TO_ONE:
+ info = self._info_for_many_to_one_join(field, join_to_query, alias)
+ elif relationship_type == self.M2M_ON_RELATED_MODEL:
+ info = self._info_for_many_to_many_join(
+ m2m_field=field, join_to_query=join_to_query, alias=alias,
+ m2m_is_on_this_model=False)
+ elif relationship_type ==self.M2M_ON_THIS_MODEL:
+ info = self._info_for_many_to_many_join(
+ m2m_field=field, join_to_query=join_to_query, alias=alias,
+ m2m_is_on_this_model=True)
+
+ return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
+ join_from_key=info['lhs_column'],
+ join_condition=info['where_clause'],
+ join_condition_values=info['values'],
+ alias=alias,
+ force_left_join=left_join)
+
+ # add_where: return a copy of query_set with raw SQL `where` ANDed in
+ def add_where(self, query_set, where, values=()):
+ query_set = query_set.all()
+ query_set.query.where.add(self._WhereClause(where, values),
+ django.db.models.sql.where.AND)
return query_set
@@ -235,6 +369,39 @@
return field.rel and field.rel.to is model_class
+ MANY_TO_ONE = object() # related_model has a foreign key to this model
+ M2M_ON_RELATED_MODEL = object() # related_model declares the M2M field
+ M2M_ON_THIS_MODEL = object() # this model declares the M2M field
+
+ def determine_relationship(self, related_model):
+ """
+ Determine the relationship between this model and related_model.
+
+ related_model must have a many-valued relationship to this manager's
+ model; otherwise ValueError is raised.
+ @returns (relationship_type, field), where relationship_type is one of
+ MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
+ is the Django field object for the relationship.
+ """
+ # look for a foreign key field on related_model relating to this model
+ for field in related_model._meta.fields:
+ if self._is_relation_to(field, self.model):
+ return self.MANY_TO_ONE, field
+
+ # look for an M2M field on related_model relating to this model
+ for field in related_model._meta.many_to_many:
+ if self._is_relation_to(field, self.model):
+ return self.M2M_ON_RELATED_MODEL, field
+
+ # maybe this model has the many-to-many field
+ for field in self.model._meta.many_to_many:
+ if self._is_relation_to(field, related_model):
+ return self.M2M_ON_THIS_MODEL, field
+
+ raise ValueError('%s has no relation to %s' %
+ (related_model, self.model))
+
+
def _get_pivot_iterator(self, base_objects_by_id, related_model):
"""
Determine the relationship between this model and related_model, and
@@ -244,33 +411,22 @@
@returns a pivot iterator, which yields a tuple (base_object,
related_object) for each relationship between a base object and a
related object. all base_object instances come from base_objects_by_id.
- Note -- this depends on Django model internals and will likely need to
- be updated when we move to Django 1.x.
+ Note -- this depends on Django model internals (m2m_db_table() etc.).
"""
- # look for a field on related_model relating to this model
- for field in related_model._meta.fields:
- if self._is_relation_to(field, self.model):
- # many-to-one
- return self._many_to_one_pivot(base_objects_by_id,
- related_model, field)
-
- for field in related_model._meta.many_to_many:
- if self._is_relation_to(field, self.model):
- # many-to-many
- return self._many_to_many_pivot(
+ relationship_type, field = self.determine_relationship(related_model)
+ if relationship_type == self.MANY_TO_ONE:
+ return self._many_to_one_pivot(base_objects_by_id,
+ related_model, field)
+ elif relationship_type == self.M2M_ON_RELATED_MODEL:
+ return self._many_to_many_pivot(
base_objects_by_id, related_model, field.m2m_db_table(),
field.m2m_reverse_name(), field.m2m_column_name())
-
- # maybe this model has the many-to-many field
- for field in self.model._meta.many_to_many:
- if self._is_relation_to(field, related_model):
- return self._many_to_many_pivot(
+ else:
+ assert relationship_type == self.M2M_ON_THIS_MODEL # else it raised
+ return self._many_to_many_pivot(
base_objects_by_id, related_model, field.m2m_db_table(),
field.m2m_column_name(), field.m2m_reverse_name())
- raise ValueError('%s has no relation to %s' %
- (related_model, self.model))
-
def _many_to_one_pivot(self, base_objects_by_id, related_model,
foreign_key_field):