Check multiplication validity in ParseContext

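This improves separation of responsibilities in the code: ParseContext
should handle operand type validation, while TIntermBinary::promote
should ideally only determine the type of the node based on the
operation and operands. To achieve this, the specific multiplication
operator (for example EOpVectorTimesMatrix) is now chosen before the
node is created, using the new TIntermBinary::GetMulOpBasedOnOperands
and GetMulAssignOpBasedOnOperands helpers.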

BUG=angleproject:952
TEST=angle_unittests

Change-Id: I9a8d8ede21cdf35de631623a62194c0da5c604d2
Reviewed-on: https://chromium-review.googlesource.com/372622
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 83f8e00..2cd5dee 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -34,41 +34,6 @@
     return left > right ? left : right;
 }
 
-bool ValidateMultiplication(TOperator op, const TType &left, const TType &right)
-{
-    switch (op)
-    {
-      case EOpMul:
-      case EOpMulAssign:
-        return left.getNominalSize() == right.getNominalSize() &&
-               left.getSecondarySize() == right.getSecondarySize();
-      case EOpVectorTimesScalar:
-      case EOpVectorTimesScalarAssign:
-        return true;
-      case EOpVectorTimesMatrix:
-        return left.getNominalSize() == right.getRows();
-      case EOpVectorTimesMatrixAssign:
-        return left.getNominalSize() == right.getRows() &&
-               left.getNominalSize() == right.getCols();
-      case EOpMatrixTimesVector:
-        return left.getCols() == right.getNominalSize();
-      case EOpMatrixTimesScalar:
-      case EOpMatrixTimesScalarAssign:
-        return true;
-      case EOpMatrixTimesMatrix:
-        return left.getCols() == right.getRows();
-      case EOpMatrixTimesMatrixAssign:
-          // We need to check two things:
-          // 1. The matrix multiplication step is valid.
-          // 2. The result will have the same number of columns as the lvalue.
-          return left.getCols() == right.getRows() && left.getCols() == right.getCols();
-
-      default:
-        UNREACHABLE();
-        return false;
-    }
-}
-
 TConstantUnion *Vectorize(const TConstantUnion &constant, size_t size)
 {
     TConstantUnion *constUnion = new TConstantUnion[size];
@@ -513,6 +478,79 @@
     }
 }
 
+// Picks the specialized operator for a non-assigning '*' from the shapes of its operands
+// ("scalar" below means an operand that is neither a matrix nor a vector):
+//   matrix * matrix                  -> EOpMatrixTimesMatrix
+//   matrix * vector                  -> EOpMatrixTimesVector
+//   matrix * scalar, scalar * matrix -> EOpMatrixTimesScalar (the matrix may be on either side)
+//   vector * matrix                  -> EOpVectorTimesMatrix
+//   vector * scalar, scalar * vector -> EOpVectorTimesScalar (the vector may be on either side)
+//   vector * vector, scalar * scalar -> EOpMul (component-wise product)
+// Only the operand shapes are inspected; size compatibility is not checked here.
+TOperator TIntermBinary::GetMulOpBasedOnOperands(const TType &left, const TType &right)
+{
+    const bool leftIsMatrix  = left.isMatrix();
+    const bool rightIsMatrix = right.isMatrix();
+
+    if (leftIsMatrix && rightIsMatrix)
+    {
+        return EOpMatrixTimesMatrix;
+    }
+    if (leftIsMatrix)
+    {
+        return right.isVector() ? EOpMatrixTimesVector : EOpMatrixTimesScalar;
+    }
+    if (rightIsMatrix)
+    {
+        return left.isVector() ? EOpVectorTimesMatrix : EOpMatrixTimesScalar;
+    }
+
+    // Neither operand is a matrix: mixing a vector with a scalar is a vector-times-scalar
+    // operation, anything else is left as a component-wise product.
+    if (left.isVector() != right.isVector())
+    {
+        return EOpVectorTimesScalar;
+    }
+    return EOpMul;
+}
+
+// Counterpart of GetMulOpBasedOnOperands for the compound assignment '*=':
+//   matrix *= matrix                   -> EOpMatrixTimesMatrixAssign
+//   matrix *= non-matrix               -> EOpMatrixTimesScalarAssign
+//   non-matrix *= matrix               -> EOpVectorTimesMatrixAssign
+//   vector *= scalar, scalar *= vector -> EOpVectorTimesScalarAssign
+//   vector *= vector, scalar *= scalar -> EOpMulAssign
+// Only the operand shapes are inspected. Some of these combinations are not valid '*='
+// operands (e.g. matrix *= vector, scalar *= matrix, scalar *= vector); rejecting them is left
+// to the caller.
+TOperator TIntermBinary::GetMulAssignOpBasedOnOperands(const TType &left, const TType &right)
+{
+    const bool leftIsMatrix  = left.isMatrix();
+    const bool rightIsMatrix = right.isMatrix();
+
+    if (leftIsMatrix)
+    {
+        // Unless it is also a matrix, the right operand should be a scalar, but this may not
+        // be validated yet.
+        return rightIsMatrix ? EOpMatrixTimesMatrixAssign : EOpMatrixTimesScalarAssign;
+    }
+    if (rightIsMatrix)
+    {
+        // The left operand should be a vector, but this may not be validated yet.
+        return EOpVectorTimesMatrixAssign;
+    }
+
+    // Neither operand is a matrix.
+    if (left.isVector() != right.isVector())
+    {
+        // The left operand should be the vector and the right one the scalar, but this may
+        // not be validated yet.
+        return EOpVectorTimesScalarAssign;
+    }
+    // Leave as component product.
+    return EOpMulAssign;
+}
+
 //
 // Make sure the type of a unary operator is appropriate for its
 // combination of operation and operand type.
@@ -570,6 +623,9 @@
 {
     ASSERT(mLeft->isArray() == mRight->isArray());
 
+    ASSERT(!isMultiplication() ||
+           mOp == GetMulOpBasedOnOperands(mLeft->getType(), mRight->getType()));
+
     //
     // Base assumption:  just make the type the same as the left
     // operand.  Then only deviations from this need be coded.
@@ -633,204 +689,118 @@
     // Can these two operands be combined?
     //
     TBasicType basicType = mLeft->getBasicType();
+
     switch (mOp)
     {
-      case EOpMul:
-        if (!mLeft->isMatrix() && mRight->isMatrix())
-        {
-            if (mLeft->isVector())
+        case EOpMul:
+            break;
+        case EOpMatrixTimesScalar:
+            if (mRight->isMatrix())
             {
-                mOp = EOpVectorTimesMatrix;
-                setType(TType(basicType, higherPrecision, resultQualifier,
-                              static_cast<unsigned char>(mRight->getCols()), 1));
-            }
-            else
-            {
-                mOp = EOpMatrixTimesScalar;
                 setType(TType(basicType, higherPrecision, resultQualifier,
                               static_cast<unsigned char>(mRight->getCols()),
                               static_cast<unsigned char>(mRight->getRows())));
             }
-        }
-        else if (mLeft->isMatrix() && !mRight->isMatrix())
-        {
-            if (mRight->isVector())
-            {
-                mOp = EOpMatrixTimesVector;
-                setType(TType(basicType, higherPrecision, resultQualifier,
-                              static_cast<unsigned char>(mLeft->getRows()), 1));
-            }
-            else
-            {
-                mOp = EOpMatrixTimesScalar;
-            }
-        }
-        else if (mLeft->isMatrix() && mRight->isMatrix())
-        {
-            mOp = EOpMatrixTimesMatrix;
+            break;
+        case EOpMatrixTimesVector:
+            setType(TType(basicType, higherPrecision, resultQualifier,
+                          static_cast<unsigned char>(mLeft->getRows()), 1));
+            break;
+        case EOpMatrixTimesMatrix:
             setType(TType(basicType, higherPrecision, resultQualifier,
                           static_cast<unsigned char>(mRight->getCols()),
                           static_cast<unsigned char>(mLeft->getRows())));
-        }
-        else if (!mLeft->isMatrix() && !mRight->isMatrix())
-        {
-            if (mLeft->isVector() && mRight->isVector())
-            {
-                // leave as component product
-            }
-            else if (mLeft->isVector() || mRight->isVector())
-            {
-                mOp = EOpVectorTimesScalar;
-                setType(TType(basicType, higherPrecision, resultQualifier,
-                              static_cast<unsigned char>(nominalSize), 1));
-            }
-        }
-        else
-        {
-            UNREACHABLE();
-            return false;
-        }
-
-        if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
-        {
-            return false;
-        }
-        break;
-
-      case EOpMulAssign:
-        if (!mLeft->isMatrix() && mRight->isMatrix())
-        {
-            if (mLeft->isVector())
-            {
-                mOp = EOpVectorTimesMatrixAssign;
-            }
-            else
-            {
-                return false;
-            }
-        }
-        else if (mLeft->isMatrix() && !mRight->isMatrix())
-        {
-            if (mRight->isVector())
-            {
-                return false;
-            }
-            else
-            {
-                mOp = EOpMatrixTimesScalarAssign;
-            }
-        }
-        else if (mLeft->isMatrix() && mRight->isMatrix())
-        {
-            mOp = EOpMatrixTimesMatrixAssign;
+            break;
+        case EOpVectorTimesScalar:
             setType(TType(basicType, higherPrecision, resultQualifier,
-                          static_cast<unsigned char>(mRight->getCols()),
-                          static_cast<unsigned char>(mLeft->getRows())));
-        }
-        else if (!mLeft->isMatrix() && !mRight->isMatrix())
-        {
-            if (mLeft->isVector() && mRight->isVector())
+                          static_cast<unsigned char>(nominalSize), 1));
+            break;
+        case EOpVectorTimesMatrix:
+            setType(TType(basicType, higherPrecision, resultQualifier,
+                          static_cast<unsigned char>(mRight->getCols()), 1));
+            break;
+        case EOpMulAssign:
+        case EOpVectorTimesScalarAssign:
+        case EOpVectorTimesMatrixAssign:
+        case EOpMatrixTimesScalarAssign:
+        case EOpMatrixTimesMatrixAssign:
+            ASSERT(mOp == GetMulAssignOpBasedOnOperands(mLeft->getType(), mRight->getType()));
+            break;
+        case EOpAssign:
+        case EOpInitialize:
+            // No more additional checks are needed.
+            ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
+                   (mLeft->getSecondarySize() == mRight->getSecondarySize()));
+            break;
+        case EOpAdd:
+        case EOpSub:
+        case EOpDiv:
+        case EOpIMod:
+        case EOpBitShiftLeft:
+        case EOpBitShiftRight:
+        case EOpBitwiseAnd:
+        case EOpBitwiseXor:
+        case EOpBitwiseOr:
+        case EOpAddAssign:
+        case EOpSubAssign:
+        case EOpDivAssign:
+        case EOpIModAssign:
+        case EOpBitShiftLeftAssign:
+        case EOpBitShiftRightAssign:
+        case EOpBitwiseAndAssign:
+        case EOpBitwiseXorAssign:
+        case EOpBitwiseOrAssign:
+            if ((mLeft->isMatrix() && mRight->isVector()) ||
+                (mLeft->isVector() && mRight->isMatrix()))
             {
-                // leave as component product
+                return false;
             }
-            else if (mLeft->isVector() || mRight->isVector())
+
+            // Are the sizes compatible?
+            if (mLeft->getNominalSize() != mRight->getNominalSize() ||
+                mLeft->getSecondarySize() != mRight->getSecondarySize())
             {
-                if (!mLeft->isVector())
+                // If the nominal sizes of operands do not match:
+                // One of them must be a scalar.
+                if (!mLeft->isScalar() && !mRight->isScalar())
                     return false;
-                mOp = EOpVectorTimesScalarAssign;
-                setType(TType(basicType, higherPrecision, resultQualifier,
-                              static_cast<unsigned char>(mLeft->getNominalSize()), 1));
+
+                // In the case of compound assignment other than multiply-assign,
+                // the right side needs to be a scalar. Otherwise a vector/matrix
+                // would be assigned to a scalar. A scalar can't be shifted by a
+                // vector either.
+                if (!mRight->isScalar() &&
+                    (isAssignment() || mOp == EOpBitShiftLeft || mOp == EOpBitShiftRight))
+                    return false;
             }
-        }
-        else
-        {
-            UNREACHABLE();
-            return false;
-        }
 
-        if (!ValidateMultiplication(mOp, mLeft->getType(), mRight->getType()))
-        {
-            return false;
-        }
-        break;
-
-      case EOpAssign:
-      case EOpInitialize:
-        // No more additional checks are needed.
-        ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
-            (mLeft->getSecondarySize() == mRight->getSecondarySize()));
-        break;
-      case EOpAdd:
-      case EOpSub:
-      case EOpDiv:
-      case EOpIMod:
-      case EOpBitShiftLeft:
-      case EOpBitShiftRight:
-      case EOpBitwiseAnd:
-      case EOpBitwiseXor:
-      case EOpBitwiseOr:
-      case EOpAddAssign:
-      case EOpSubAssign:
-      case EOpDivAssign:
-      case EOpIModAssign:
-      case EOpBitShiftLeftAssign:
-      case EOpBitShiftRightAssign:
-      case EOpBitwiseAndAssign:
-      case EOpBitwiseXorAssign:
-      case EOpBitwiseOrAssign:
-        if ((mLeft->isMatrix() && mRight->isVector()) ||
-            (mLeft->isVector() && mRight->isMatrix()))
-        {
-            return false;
-        }
-
-        // Are the sizes compatible?
-        if (mLeft->getNominalSize() != mRight->getNominalSize() ||
-            mLeft->getSecondarySize() != mRight->getSecondarySize())
-        {
-            // If the nominal sizes of operands do not match:
-            // One of them must be a scalar.
-            if (!mLeft->isScalar() && !mRight->isScalar())
-                return false;
-
-            // In the case of compound assignment other than multiply-assign,
-            // the right side needs to be a scalar. Otherwise a vector/matrix
-            // would be assigned to a scalar. A scalar can't be shifted by a
-            // vector either.
-            if (!mRight->isScalar() &&
-                (isAssignment() ||
-                mOp == EOpBitShiftLeft ||
-                mOp == EOpBitShiftRight))
-                return false;
-        }
-
-        {
-            const int secondarySize = std::max(
-                mLeft->getSecondarySize(), mRight->getSecondarySize());
-            setType(TType(basicType, higherPrecision, resultQualifier,
-                          static_cast<unsigned char>(nominalSize),
-                          static_cast<unsigned char>(secondarySize)));
-            if (mLeft->isArray())
             {
-                ASSERT(mLeft->getArraySize() == mRight->getArraySize());
-                mType.setArraySize(mLeft->getArraySize());
+                const int secondarySize =
+                    std::max(mLeft->getSecondarySize(), mRight->getSecondarySize());
+                setType(TType(basicType, higherPrecision, resultQualifier,
+                              static_cast<unsigned char>(nominalSize),
+                              static_cast<unsigned char>(secondarySize)));
+                if (mLeft->isArray())
+                {
+                    ASSERT(mLeft->getArraySize() == mRight->getArraySize());
+                    mType.setArraySize(mLeft->getArraySize());
+                }
             }
-        }
-        break;
+            break;
 
-      case EOpEqual:
-      case EOpNotEqual:
-      case EOpLessThan:
-      case EOpGreaterThan:
-      case EOpLessThanEqual:
-      case EOpGreaterThanEqual:
-        ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
-            (mLeft->getSecondarySize() == mRight->getSecondarySize()));
-        setType(TType(EbtBool, EbpUndefined));
-        break;
+        case EOpEqual:
+        case EOpNotEqual:
+        case EOpLessThan:
+        case EOpGreaterThan:
+        case EOpLessThanEqual:
+        case EOpGreaterThanEqual:
+            ASSERT((mLeft->getNominalSize() == mRight->getNominalSize()) &&
+                   (mLeft->getSecondarySize() == mRight->getSecondarySize()));
+            setType(TType(EbtBool, EbpUndefined));
+            break;
 
-      default:
-        return false;
+        default:
+            return false;
     }
     return true;
 }
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 7aae484..e9cfb79 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -417,6 +417,9 @@
 
     TIntermTyped *deepCopy() const override { return new TIntermBinary(*this); }
 
+    static TOperator GetMulOpBasedOnOperands(const TType &left, const TType &right);
+    static TOperator GetMulAssignOpBasedOnOperands(const TType &left, const TType &right);
+
     TIntermBinary *getAsBinaryNode() override { return this; };
     void traverse(TIntermTraverser *it) override;
     bool replaceChildNode(TIntermNode *original, TIntermNode *replacement) override;
diff --git a/src/compiler/translator/ParseContext.cpp b/src/compiler/translator/ParseContext.cpp
index eac0702..abe3712 100644
--- a/src/compiler/translator/ParseContext.cpp
+++ b/src/compiler/translator/ParseContext.cpp
@@ -3574,6 +3574,65 @@
     return true;
 }
 
+// Checks whether a multiplication with the given operand types is valid.
+// |op| must already be the specialized multiplication operator chosen from the operand shapes
+// (TIntermBinary::GetMulOpBasedOnOperands or GetMulAssignOpBasedOnOperands). For the
+// non-assigning operators only the operand sizes are checked; for the compound assignments the
+// operand shapes that the operator selection could not rule out are checked as well.
+bool TParseContext::isMultiplicationTypeCombinationValid(TOperator op,
+                                                         const TType &left,
+                                                         const TType &right)
+{
+    switch (op)
+    {
+        // Scaling a vector or a matrix by a scalar is accepted for any vector or matrix size.
+        case EOpVectorTimesScalar:
+        case EOpMatrixTimesScalar:
+            return true;
+
+        // Component-wise product: both operands must have exactly the same dimensions.
+        case EOpMul:
+        case EOpMulAssign:
+            return left.getNominalSize() == right.getNominalSize() &&
+                   left.getSecondarySize() == right.getSecondarySize();
+
+        // Linear algebra products: the inner dimensions must agree.
+        case EOpVectorTimesMatrix:
+            return left.getNominalSize() == right.getRows();
+        case EOpMatrixTimesVector:
+            return left.getCols() == right.getNominalSize();
+        case EOpMatrixTimesMatrix:
+            return left.getCols() == right.getRows();
+
+        // Compound assignments: the product is stored back into the left operand, so it must
+        // have the left operand's shape.
+        case EOpVectorTimesScalarAssign:
+            ASSERT(!left.isMatrix() && !right.isMatrix());
+            // vector *= scalar is allowed, scalar *= vector is not.
+            return left.isVector() && !right.isVector();
+        case EOpMatrixTimesScalarAssign:
+            ASSERT(left.isMatrix() && !right.isMatrix());
+            // matrix *= scalar is allowed, matrix *= vector is not.
+            return !right.isVector();
+        case EOpVectorTimesMatrixAssign:
+            ASSERT(!left.isMatrix() && right.isMatrix());
+            // The left operand must be a vector, and the matrix must be square with the vector's
+            // size so that the product has the same size as the vector.
+            return left.isVector() && left.getNominalSize() == right.getRows() &&
+                   left.getNominalSize() == right.getCols();
+        case EOpMatrixTimesMatrixAssign:
+            ASSERT(left.isMatrix() && right.isMatrix());
+            // We need to check two things:
+            // 1. The matrix multiplication step is valid.
+            // 2. The result will have the same number of columns as the lvalue.
+            return left.getCols() == right.getRows() && left.getCols() == right.getCols();
+
+        default:
+            UNREACHABLE();
+            return false;
+    }
+}
+
 TIntermTyped *TParseContext::addBinaryMathInternal(TOperator op,
                                                    TIntermTyped *left,
                                                    TIntermTyped *right,
@@ -3634,6 +3677,15 @@
             break;
     }
 
+    if (op == EOpMul)
+    {
+        op = TIntermBinary::GetMulOpBasedOnOperands(left->getType(), right->getType());
+        if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
+        {
+            return nullptr;
+        }
+    }
+
     TIntermBinary *node = new TIntermBinary(op, left, right);
     node->setLine(loc);
 
@@ -3688,6 +3740,14 @@
 {
     if (binaryOpCommonCheck(op, left, right, loc))
     {
+        if (op == EOpMulAssign)
+        {
+            op = TIntermBinary::GetMulAssignOpBasedOnOperands(left->getType(), right->getType());
+            if (!isMultiplicationTypeCombinationValid(op, left->getType(), right->getType()))
+            {
+                return nullptr;
+            }
+        }
         TIntermBinary *node = new TIntermBinary(op, left, right);
         node->setLine(loc);
 
diff --git a/src/compiler/translator/ParseContext.h b/src/compiler/translator/ParseContext.h
index 05ce2e7..b5425d9 100644
--- a/src/compiler/translator/ParseContext.h
+++ b/src/compiler/translator/ParseContext.h
@@ -389,6 +389,9 @@
     bool checkIsValidTypeAndQualifierForArray(const TSourceLoc &indexLocation,
                                               const TPublicType &elementType);
 
+    // Assumes that multiplication op has already been set based on the types.
+    bool isMultiplicationTypeCombinationValid(TOperator op, const TType &left, const TType &right);
+
     TIntermTyped *addBinaryMathInternal(
         TOperator op, TIntermTyped *left, TIntermTyped *right, const TSourceLoc &loc);
     TIntermTyped *createAssign(
diff --git a/src/compiler/translator/RemovePow.cpp b/src/compiler/translator/RemovePow.cpp
index 48fc1f9..db79ca5 100644
--- a/src/compiler/translator/RemovePow.cpp
+++ b/src/compiler/translator/RemovePow.cpp
@@ -60,7 +60,8 @@
         log->setLine(node->getLine());
         log->setType(x->getType());
 
-        TIntermBinary *mul = new TIntermBinary(EOpMul, y, log);
+        TOperator op       = TIntermBinary::GetMulOpBasedOnOperands(y->getType(), log->getType());
+        TIntermBinary *mul = new TIntermBinary(op, y, log);
         mul->setLine(node->getLine());
         bool valid = mul->promote();
         UNUSED_ASSERTION_VARIABLE(valid);