Fix scalarizing vec and mat constructor args

Scalarizing vec and mat constructor args can generate new statements
in the parent block of the constructor. To preserve the correct
execution order of expressions, scalarized vector and matrix
constructors need to be first moved out from inside loop conditions
and sequence operators. This is done whenever the compiler flag to
scalarize args is on.

BUG=chromium:772653
TEST=angle_unittests

Change-Id: Id40f8d848a9d087e186ef2e680c8e4cd440221d9
Reviewed-on: https://chromium-review.googlesource.com/790412
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
diff --git a/src/compiler/translator/Compiler.cpp b/src/compiler/translator/Compiler.cpp
index 2f411cb..e0fd0cb 100644
--- a/src/compiler/translator/Compiler.cpp
+++ b/src/compiler/translator/Compiler.cpp
@@ -516,20 +516,24 @@
                                     &symbolTable, shaderVersion);
     }
 
+    int simplifyScalarized = (compileOptions & SH_SCALARIZE_VEC_AND_MAT_CONSTRUCTOR_ARGS)
+                                 ? IntermNodePatternMatcher::kScalarizedVecOrMatConstructor
+                                 : 0;
+
     // Split multi declarations and remove calls to array length().
     // Note that SimplifyLoopConditions needs to be run before any other AST transformations
     // that may need to generate new statements from loop conditions or loop expressions.
-    SimplifyLoopConditions(
-        root,
-        IntermNodePatternMatcher::kMultiDeclaration | IntermNodePatternMatcher::kArrayLengthMethod,
-        &getSymbolTable(), getShaderVersion());
+    SimplifyLoopConditions(root,
+                           IntermNodePatternMatcher::kMultiDeclaration |
+                               IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
+                           &getSymbolTable(), getShaderVersion());
 
     // Note that separate declarations need to be run before other AST transformations that
     // generate new statements from expressions.
     SeparateDeclarations(root);
 
-    SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod, &getSymbolTable(),
-                          getShaderVersion());
+    SplitSequenceOperator(root, IntermNodePatternMatcher::kArrayLengthMethod | simplifyScalarized,
+                          &getSymbolTable(), getShaderVersion());
 
     RemoveArrayLengthMethod(root);
 
diff --git a/src/compiler/translator/IntermNodePatternMatcher.cpp b/src/compiler/translator/IntermNodePatternMatcher.cpp
index 567e8f7..ce5605b 100644
--- a/src/compiler/translator/IntermNodePatternMatcher.cpp
+++ b/src/compiler/translator/IntermNodePatternMatcher.cpp
@@ -15,6 +15,33 @@
 namespace sh
 {
 
+namespace
+{
+
+bool ContainsMatrixNode(const TIntermSequence &sequence)
+{
+    for (size_t ii = 0; ii < sequence.size(); ++ii)
+    {
+        TIntermTyped *node = sequence[ii]->getAsTyped();
+        if (node && node->isMatrix())
+            return true;
+    }
+    return false;
+}
+
+bool ContainsVectorNode(const TIntermSequence &sequence)
+{
+    for (size_t ii = 0; ii < sequence.size(); ++ii)
+    {
+        TIntermTyped *node = sequence[ii]->getAsTyped();
+        if (node && node->isVector())
+            return true;
+    }
+    return false;
+}
+
+}  // anonymous namespace
+
 IntermNodePatternMatcher::IntermNodePatternMatcher(const unsigned int mask) : mMask(mask)
 {
 }
@@ -105,6 +132,20 @@
             }
         }
     }
+    if ((mMask & kScalarizedVecOrMatConstructor) != 0)
+    {
+        if (node->getOp() == EOpConstruct)
+        {
+            if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence())))
+            {
+                return true;
+            }
+            else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence())))
+            {
+                return true;
+            }
+        }
+    }
     return false;
 }
 
diff --git a/src/compiler/translator/IntermNodePatternMatcher.h b/src/compiler/translator/IntermNodePatternMatcher.h
index 997fc2e..b83ec67 100644
--- a/src/compiler/translator/IntermNodePatternMatcher.h
+++ b/src/compiler/translator/IntermNodePatternMatcher.h
@@ -48,7 +48,11 @@
         kNamelessStructDeclaration = 0x0001 << 5,
 
         // Matches array length() method.
-        kArrayLengthMethod = 0x0001 << 6
+        kArrayLengthMethod = 0x0001 << 6,
+
+        // Matches a vector or matrix constructor whose arguments are scalarized by the
+        // SH_SCALARIZE_VEC_AND_MAT_CONSTRUCTOR_ARGS workaround.
+        kScalarizedVecOrMatConstructor = 0x0001 << 7
     };
     IntermNodePatternMatcher(const unsigned int mask);
 
diff --git a/src/compiler/translator/ScalarizeVecAndMatConstructorArgs.cpp b/src/compiler/translator/ScalarizeVecAndMatConstructorArgs.cpp
index 746c16e..b79c56b 100644
--- a/src/compiler/translator/ScalarizeVecAndMatConstructorArgs.cpp
+++ b/src/compiler/translator/ScalarizeVecAndMatConstructorArgs.cpp
@@ -15,6 +15,7 @@
 
 #include "angle_gl.h"
 #include "common/angleutils.h"
+#include "compiler/translator/IntermNodePatternMatcher.h"
 #include "compiler/translator/IntermNode_util.h"
 #include "compiler/translator/IntermTraverse.h"
 
@@ -24,28 +25,6 @@
 namespace
 {
 
-bool ContainsMatrixNode(const TIntermSequence &sequence)
-{
-    for (size_t ii = 0; ii < sequence.size(); ++ii)
-    {
-        TIntermTyped *node = sequence[ii]->getAsTyped();
-        if (node && node->isMatrix())
-            return true;
-    }
-    return false;
-}
-
-bool ContainsVectorNode(const TIntermSequence &sequence)
-{
-    for (size_t ii = 0; ii < sequence.size(); ++ii)
-    {
-        TIntermTyped *node = sequence[ii]->getAsTyped();
-        if (node && node->isVector())
-            return true;
-    }
-    return false;
-}
-
 TIntermBinary *ConstructVectorIndexBinaryNode(TIntermSymbol *symbolNode, int index)
 {
     return new TIntermBinary(EOpIndexDirect, symbolNode, CreateIndexNode(index));
@@ -66,7 +45,8 @@
                            TSymbolTable *symbolTable)
         : TIntermTraverser(true, false, false, symbolTable),
           mShaderType(shaderType),
-          mFragmentPrecisionHigh(fragmentPrecisionHigh)
+          mFragmentPrecisionHigh(fragmentPrecisionHigh),
+          mNodesToScalarize(IntermNodePatternMatcher::kScalarizedVecOrMatConstructor)
     {
     }
 
@@ -92,16 +72,24 @@
 
     sh::GLenum mShaderType;
     bool mFragmentPrecisionHigh;
+
+    IntermNodePatternMatcher mNodesToScalarize;
 };
 
 bool ScalarizeArgsTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
 {
-    if (visit == PreVisit && node->getOp() == EOpConstruct)
+    ASSERT(visit == PreVisit);
+    if (mNodesToScalarize.match(node, getParentNode()))
     {
-        if (node->getType().isVector() && ContainsMatrixNode(*(node->getSequence())))
+        if (node->getType().isVector())
+        {
             scalarizeArgs(node, false, true);
-        else if (node->getType().isMatrix() && ContainsVectorNode(*(node->getSequence())))
+        }
+        else
+        {
+            ASSERT(node->getType().isMatrix());
             scalarizeArgs(node, true, false);
+        }
     }
     return true;
 }
@@ -134,55 +122,55 @@
     ASSERT(!aggregate->isArray());
     int size                  = static_cast<int>(aggregate->getType().getObjectSize());
     TIntermSequence *sequence = aggregate->getSequence();
-    TIntermSequence original(*sequence);
+    TIntermSequence originalArgs(*sequence);
     sequence->clear();
-    for (size_t ii = 0; ii < original.size(); ++ii)
+    for (TIntermNode *originalArgNode : originalArgs)
     {
         ASSERT(size > 0);
-        TIntermTyped *node = original[ii]->getAsTyped();
-        ASSERT(node);
-        createTempVariable(node);
-        if (node->isScalar())
+        TIntermTyped *originalArg = originalArgNode->getAsTyped();
+        ASSERT(originalArg);
+        createTempVariable(originalArg);
+        if (originalArg->isScalar())
         {
-            sequence->push_back(createTempSymbol(node->getType()));
+            sequence->push_back(createTempSymbol(originalArg->getType()));
             size--;
         }
-        else if (node->isVector())
+        else if (originalArg->isVector())
         {
             if (scalarizeVector)
             {
-                int repeat = std::min(size, node->getNominalSize());
+                int repeat = std::min(size, originalArg->getNominalSize());
                 size -= repeat;
                 for (int index = 0; index < repeat; ++index)
                 {
-                    TIntermSymbol *symbolNode = createTempSymbol(node->getType());
+                    TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
                     TIntermBinary *newNode    = ConstructVectorIndexBinaryNode(symbolNode, index);
                     sequence->push_back(newNode);
                 }
             }
             else
             {
-                TIntermSymbol *symbolNode = createTempSymbol(node->getType());
+                TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
                 sequence->push_back(symbolNode);
-                size -= node->getNominalSize();
+                size -= originalArg->getNominalSize();
             }
         }
         else
         {
-            ASSERT(node->isMatrix());
+            ASSERT(originalArg->isMatrix());
             if (scalarizeMatrix)
             {
                 int colIndex = 0, rowIndex = 0;
-                int repeat = std::min(size, node->getCols() * node->getRows());
+                int repeat = std::min(size, originalArg->getCols() * originalArg->getRows());
                 size -= repeat;
                 while (repeat > 0)
                 {
-                    TIntermSymbol *symbolNode = createTempSymbol(node->getType());
+                    TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
                     TIntermBinary *newNode =
                         ConstructMatrixIndexBinaryNode(symbolNode, colIndex, rowIndex);
                     sequence->push_back(newNode);
                     rowIndex++;
-                    if (rowIndex >= node->getRows())
+                    if (rowIndex >= originalArg->getRows())
                     {
                         rowIndex = 0;
                         colIndex++;
@@ -192,9 +180,9 @@
             }
             else
             {
-                TIntermSymbol *symbolNode = createTempSymbol(node->getType());
+                TIntermSymbol *symbolNode = createTempSymbol(originalArg->getType());
                 sequence->push_back(symbolNode);
-                size -= node->getCols() * node->getRows();
+                size -= originalArg->getCols() * originalArg->getRows();
             }
         }
     }