Check multiplication validity in ParseContext
This improves separation of responsibilities in the code: ParseContext
should handle operand type validation, while TIntermBinary::promote
should ideally only determine the type of the node based on the
operation and operands.
BUG=angleproject:952
TEST=angle_unittests
Change-Id: I9a8d8ede21cdf35de631623a62194c0da5c604d2
Reviewed-on: https://chromium-review.googlesource.com/372622
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 83f8e00..2cd5dee 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -34,41 +34,6 @@
return left > right ? left : right;
}
-bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
-{
- switch (op)
- {
- case EOpMul:
- case EOpMulAssign:
- return left.getNominalSize() == right.getNominalSize() &&
- left.getSecondarySize() == right.getSecondarySize();
- case EOpVectorTimesScalar:
- case EOpVectorTimesScalarAssign:
- return true;
- case EOpVectorTimesMatrix:
- return left.getNominalSize() == right.getRows();
- case EOpVectorTimesMatrixAssign:
- return left.getNominalSize() == right.getRows() &&
- left.getNominalSize() == right.getCols();
- case EOpMatrixTimesVector:
- return left.getCols() == right.getNominalSize();
- case EOpMatrixTimesScalar:
- case EOpMatrixTimesScalarAssign:
- return true;
- case EOpMatrixTimesMatrix:
- return left.getCols() == right.getRows();
- case EOpMatrixTimesMatrixAssign:
- // We need to check two things:
- // 1. The matrix multiplication step is valid.
- // 2. The result will have the same number of columns as the lvalue.
- return left.getCols() == right.getRows() && left.getCols() == right.getCols();
-
- default:
- UNREACHABLE();
- return false;
- }
-}
-
TConstantUnion *Vectorize(const TConstantUnion &constant, size_t size)
{
TConstantUnion *constUnion = new TConstantUnion[size];
@@ -513,6 +478,94 @@
}
}
+TOperator TIntermBinary::GetMulOpBasedOnOperands(const TType &left, const TType &right)
+{
+ if (left.isMatrix())
+ {
+ if (right.isMatrix())
+ {
+ return EOpMatrixTimesMatrix;
+ }
+ else
+ {
+ if (right.isVector())
+ {
+ return EOpMatrixTimesVector;
+ }
+ else
+ {
+ return EOpMatrixTimesScalar;
+ }
+ }
+ }
+ else
+ {
+ if (right.isMatrix())
+ {
+ if (left.isVector())
+ {
+ return EOpVectorTimesMatrix;
+ }
+ else
+ {
+ return EOpMatrixTimesScalar;
+ }
+ }
+ else
+ {
+ // Neither operand is a matrix.
+ if (left.isVector() == right.isVector())
+ {
+ // Leave as component product.
+ return EOpMul;
+ }
+ else
+ {
+ return EOpVectorTimesScalar;
+ }
+ }
+ }
+}
+
+TOperator TIntermBinary::GetMulAssignOpBasedOnOperands(const TType &left, const TType &right)
+{
+ if (left.isMatrix())
+ {
+ if (right.isMatrix())
+ {
+ return EOpMatrixTimesMatrixAssign;
+ }
+ else
+ {
+ // right should be scalar, but this may not be validated yet.
+ return EOpMatrixTimesScalarAssign;
+ }
+ }
+ else
+ {
+ if (right.isMatrix())
+ {
+ // Left should be a vector, but this may not be validated yet.
+ return EOpVectorTimesMatrixAssign;
+ }
+ else
+ {
+ // Neither operand is a matrix.
+ if (left.isVector() == right.isVector())
+ {
+ // Leave as component product.
+ return EOpMulAssign;
+ }
+ else
+ {
+ // left should be vector and right should be scalar, but this may not be validated
+ // yet.
+ return EOpVectorTimesScalarAssign;
+ }
+ }
+ }
+}
+
//
// Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type.
@@ -570,6 +623,9 @@
{
ASSERT(mLeft->isArray() == mRight->isArray());
+ ASSERT(!isMultiplication() ||
+ mOp == GetMulOpBasedOnOperands(mLeft->getType(), mRight->getType()));
+
//
// Base assumption: just make the type the same as the left
// operand. Then only deviations from this need be coded.
@@ -633,204 +689,118 @@
// Can these two operands be combined?
//
TBasicType basicType = mLeft->getBasicType();
+
switch (mOp)
{
- case EOpMul:
- if (!mLeft->isMatrix() && mRight->isMatrix())
- {
- if (mLeft->isVector())
+ case EOpMul:
+ break;
+ case EOpMatrixTimesScalar:
+ if (mRight->isMatrix())
{
- mOp = EOpVectorTimesMatrix;
- setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(mRight->getCols()), 1));
- }
- else
- {
- mOp = EOpMatrixTimesScalar;
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mRight->getRows())));
}
- }
- else if (mLeft->isMatrix() && !mRight->isMatrix())
- {
- if (mRight->isVector())
- {
- mOp = EOpMatrixTimesVector;
- setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(mLeft->getRows()), 1));
- }
- else
- {
- mOp = EOpMatrixTimesScalar;
- }
- }
- else if (mLeft->isMatrix() && mRight->isMatrix())
- {
- mOp = EOpMatrixTimesMatrix;
+ break;
+ case EOpMatrixTimesVector:
+ setType(TType(basicType, higherPrecision, resultQualifier,
+ static_cast<unsigned char>(mLeft->getRows()), 1));
+ break;
+ case EOpMatrixTimesMatrix:
setType(TType(basicType, higherPrecision, resultQualifier,
static_cast<unsigned char>(mRight->getCols()),
static_cast<unsigned char>(mLeft->getRows())));
- }
- else if (!mLeft->isMatrix() && !mRight->isMatrix())
- {
- if (mLeft->isVector() && mRight->isVector())
- {
- // leave as component product
- }
- else if (mLeft->isVector() || mRight->isVector())
- {
- mOp = EOpVectorTimesScalar;
- setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(nominalSize), 1));
- }
- }
- else
- {
- UNREACHABLE();
- return false;
- }
-
- if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
- {
- return false;
- }
- break;
-
- case EOpMulAssign:
- if (!mLeft->isMatrix() && mRight->isMatrix())
- {
- if (mLeft->isVector())
- {
- mOp = EOpVectorTimesMatrixAssign;
- }
- else
- {
- return false;
- }
- }
- else if (mLeft->isMatrix() && !mRight->isMatrix())
- {
- if (mRight->isVector())
- {
- return false;
- }
- else
- {
- mOp = EOpMatrixTimesScalarAssign;
- }
- }
- else if (mLeft->isMatrix() && mRight->isMatrix())
- {
- mOp = EOpMatrixTimesMatrixAssign;
+ break;
+ case EOpVectorTimesScalar:
setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(mRight->getCols()),
- static_cast<unsigned char>(mLeft->getRows())));
- }
- else if (!mLeft->isMatrix() && !mRight->isMatrix())
- {
- if (mLeft->isVector() && mRight->isVector())
+ static_cast<unsigned char>(nominalSize), 1));
+ break;
+ case EOpVectorTimesMatrix:
+ setType(TType(basicType, higherPrecision, resultQualifier,
+ static_cast<unsigned char>(mRight->getCols()), 1));
+ break;
+ case EOpMulAssign:
+ case EOpVectorTimesScalarAssign:
+ case EOpVectorTimesMatrixAssign:
+ case EOpMatrixTimesScalarAssign:
+ case EOpMatrixTimesMatrixAssign:
+ ASSERT(mOp == GetMulAssignOpBasedOnOperands(mLeft->getType(), mRight->getType()));
+ break;
+ case EOpAssign:
+ case EOpInitialize:
+ // No more additional checks are needed.
+ ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
+ (mLeft->getSecondarySize() == mRight->getSecondarySize()));
+ break;
+ case EOpAdd:
+ case EOpSub:
+ case EOpDiv:
+ case EOpIMod:
+ case EOpBitShiftLeft:
+ case EOpBitShiftRight:
+ case EOpBitwiseAnd:
+ case EOpBitwiseXor:
+ case EOpBitwiseOr:
+ case EOpAddAssign:
+ case EOpSubAssign:
+ case EOpDivAssign:
+ case EOpIModAssign:
+ case EOpBitShiftLeftAssign:
+ case EOpBitShiftRightAssign:
+ case EOpBitwiseAndAssign:
+ case EOpBitwiseXorAssign:
+ case EOpBitwiseOrAssign:
+ if ((mLeft->isMatrix() && mRight->isVector()) ||
+ (mLeft->isVector() && mRight->isMatrix()))
{
- // leave as component product
+ return false;
}
- else if (mLeft->isVector() || mRight->isVector())
+
+ // Are the sizes compatible?
+ if (mLeft->getNominalSize() != mRight->getNominalSize() ||
+ mLeft->getSecondarySize() != mRight->getSecondarySize())
{
- if (!mLeft->isVector())
+ // If the nominal sizes of operands do not match:
+ // One of them must be a scalar.
+ if (!mLeft->isScalar() && !mRight->isScalar())
return false;
- mOp = EOpVectorTimesScalarAssign;
- setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(mLeft->getNominalSize()), 1));
+
+ // In the case of compound assignment other than multiply-assign,
+ // the right side needs to be a scalar. Otherwise a vector/matrix
+ // would be assigned to a scalar. A scalar can't be shifted by a
+ // vector either.
+ if (!mRight->isScalar() &&
+ (isAssignment() || mOp == EOpBitShiftLeft || mOp == EOpBitShiftRight))
+ return false;
}
- }
- else
- {
- UNREACHABLE();
- return false;
- }
- if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
- {
- return false;
- }
- break;
-
- case EOpAssign:
- case EOpInitialize:
- // No more additional checks are needed.
- ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
- (mLeft->getSecondarySize() == mRight->getSecondarySize()));
- break;
- case EOpAdd:
- case EOpSub:
- case EOpDiv:
- case EOpIMod:
- case EOpBitShiftLeft:
- case EOpBitShiftRight:
- case EOpBitwiseAnd:
- case EOpBitwiseXor:
- case EOpBitwiseOr:
- case EOpAddAssign:
- case EOpSubAssign:
- case EOpDivAssign:
- case EOpIModAssign:
- case EOpBitShiftLeftAssign:
- case EOpBitShiftRightAssign:
- case EOpBitwiseAndAssign:
- case EOpBitwiseXorAssign:
- case EOpBitwiseOrAssign:
- if ((mLeft->isMatrix() && mRight->isVector()) ||
- (mLeft->isVector() && mRight->isMatrix()))
- {
- return false;
- }
-
- // Are the sizes compatible?
- if (mLeft->getNominalSize() != mRight->getNominalSize() ||
- mLeft->getSecondarySize() != mRight->getSecondarySize())
- {
- // If the nominal sizes of operands do not match:
- // One of them must be a scalar.
- if (!mLeft->isScalar() && !mRight->isScalar())
- return false;
-
- // In the case of compound assignment other than multiply-assign,
- // the right side needs to be a scalar. Otherwise a vector/matrix
- // would be assigned to a scalar. A scalar can't be shifted by a
- // vector either.
- if (!mRight->isScalar() &&
- (isAssignment() ||
- mOp == EOpBitShiftLeft ||
- mOp == EOpBitShiftRight))
- return false;
- }
-
- {
- const int secondarySize = std::max(
- mLeft->getSecondarySize(), mRight->getSecondarySize());
- setType(TType(basicType, higherPrecision, resultQualifier,
- static_cast<unsigned char>(nominalSize),
- static_cast<unsigned char>(secondarySize)));
- if (mLeft->isArray())
{
- ASSERT(mLeft->getArraySize() == mRight->getArraySize());
- mType.setArraySize(mLeft->getArraySize());
+ const int secondarySize =
+ std::max(mLeft->getSecondarySize(), mRight->getSecondarySize());
+ setType(TType(basicType, higherPrecision, resultQualifier,
+ static_cast<unsigned char>(nominalSize),
+ static_cast<unsigned char>(secondarySize)));
+ if (mLeft->isArray())
+ {
+ ASSERT(mLeft->getArraySize() == mRight->getArraySize());
+ mType.setArraySize(mLeft->getArraySize());
+ }
}
- }
- break;
+ break;
- case EOpEqual:
- case EOpNotEqual:
- case EOpLessThan:
- case EOpGreaterThan:
- case EOpLessThanEqual:
- case EOpGreaterThanEqual:
- ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
- (mLeft->getSecondarySize() == mRight->getSecondarySize()));
- setType(TType(EbtBool, EbpUndefined));
- break;
+ case EOpEqual:
+ case EOpNotEqual:
+ case EOpLessThan:
+ case EOpGreaterThan:
+ case EOpLessThanEqual:
+ case EOpGreaterThanEqual:
+ ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
+ (mLeft->getSecondarySize() == mRight->getSecondarySize()));
+ setType(TType(EbtBool, EbpUndefined));
+ break;
- default:
- return false;
+ default:
+ return false;
}
return true;
}
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 7aae484..e9cfb79 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -417,6 +417,9 @@
TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
+ static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
+ static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
+
TIntermBinary *getAsBinaryNode() override { return this; };
void traverse(TIntermTraverser *it) override;
bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
diff --git a/src/compiler/translator/ParseContext.cpp b/src/compiler/translator/ParseContext.cpp
index eac0702..abe3712 100644
--- a/src/compiler/translator/ParseContext.cpp
+++ b/src/compiler/translator/ParseContext.cpp
@@ -3574,6 +3574,49 @@
return true;
}
+bool TParseContext::isMultiplicationTypeCombinationValid(TOperator op,
+ const TType &left,
+ const TType &right)
+{
+ switch (op)
+ {
+ case EOpMul:
+ case EOpMulAssign:
+ return left.getNominalSize() == right.getNominalSize() &&
+ left.getSecondarySize() == right.getSecondarySize();
+ case EOpVectorTimesScalar:
+ return true;
+ case EOpVectorTimesScalarAssign:
+ ASSERT(!left.isMatrix() && !right.isMatrix());
+ return left.isVector() && !right.isVector();
+ case EOpVectorTimesMatrix:
+ return left.getNominalSize() == right.getRows();
+ case EOpVectorTimesMatrixAssign:
+ ASSERT(!left.isMatrix() && right.isMatrix());
+ return left.isVector() && left.getNominalSize() == right.getRows() &&
+ left.getNominalSize() == right.getCols();
+ case EOpMatrixTimesVector:
+ return left.getCols() == right.getNominalSize();
+ case EOpMatrixTimesScalar:
+ return true;
+ case EOpMatrixTimesScalarAssign:
+ ASSERT(left.isMatrix() && !right.isMatrix());
+ return !right.isVector();
+ case EOpMatrixTimesMatrix:
+ return left.getCols() == right.getRows();
+ case EOpMatrixTimesMatrixAssign:
+ ASSERT(left.isMatrix() && right.isMatrix());
+ // We need to check two things:
+ // 1. The matrix multiplication step is valid.
+ // 2. The result will have the same number of columns as the lvalue.
+ return left.getCols() == right.getRows() && left.getCols() == right.getCols();
+
+ default:
+ UNREACHABLE();
+ return false;
+ }
+}
+
TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
TIntermTyped *left,
TIntermTyped *right,
@@ -3634,6 +3677,15 @@
break;
}
+ if (op == EOpMul)
+ {
+ op = TIntermBinary::GetMulOpBasedOnOperands(left->getType(), right->getType());
+ if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
+ {
+ return nullptr;
+ }
+ }
+
TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc);
@@ -3688,6 +3740,14 @@
{
if (binaryOpCommonCheck(op, left, right, loc))
{
+ if (op == EOpMulAssign)
+ {
+ op = TIntermBinary::GetMulAssignOpBasedOnOperands(left->getType(), right->getType());
+ if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
+ {
+ return nullptr;
+ }
+ }
TIntermBinary *node = new TIntermBinary(op, left, right);
node->setLine(loc);
diff --git a/src/compiler/translator/ParseContext.h b/src/compiler/translator/ParseContext.h
index 05ce2e7..b5425d9 100644
--- a/src/compiler/translator/ParseContext.h
+++ b/src/compiler/translator/ParseContext.h
@@ -389,6 +389,9 @@
bool checkIsValidTypeAndQualifierForArray(const TSourceLoc &indexLocation,
const TPublicType &elementType);
+ // Assumes op is already the specific multiplication op chosen from the operand types.
+ bool isMultiplicationTypeCombinationValid(TOperator op, const TType &left, const TType &right);
+
TIntermTyped *addBinaryMathInternal(
TOperator op, TIntermTyped *left, TIntermTyped *right, const TSourceLoc &loc);
TIntermTyped *createAssign(
diff --git a/src/compiler/translator/RemovePow.cpp b/src/compiler/translator/RemovePow.cpp
index 48fc1f9..db79ca5 100644
--- a/src/compiler/translator/RemovePow.cpp
+++ b/src/compiler/translator/RemovePow.cpp
@@ -60,7 +60,8 @@
log->setLine(node->getLine());
log->setType(x->getType());
- TIntermBinary *mul = new TIntermBinary(EOpMul, y, log);
+ TOperator op = TIntermBinary::GetMulOpBasedOnOperands(y->getType(), log->getType());
+ TIntermBinary *mul = new TIntermBinary(op, y, log);
mul->setLine(node->getLine());
bool valid = mul->promote();
UNUSED_ASSERTION_VARIABLE(valid);