Refactoring: Make creating temporary symbols in AST traversal reusable

Temporary symbols will also be needed to store temporary arrays when complex
array expressions are unfolded.

Also clear the pending tree-update lists (insertions and replacements) at the
end of updateTree(), so that the traverser can more easily be reused for
several rounds of replacement, and remove the unnecessary InVisit step from
UnfoldShortCircuitToIf.

BUG=angleproject:971
TEST=angle_end2end_tests, WebGL conformance tests

Change-Id: Iecdd3008d43f01b02fe344ccde8614f70e6c0c65
Reviewed-on: https://chromium-review.googlesource.com/272121
Reviewed-by: Zhenyao Mo <zmo@chromium.org>
Tested-by: Olli Etuaho <oetuaho@nvidia.com>
diff --git a/src/compiler/translator/IntermNode.cpp b/src/compiler/translator/IntermNode.cpp
index 2bff86c..72e3b42 100644
--- a/src/compiler/translator/IntermNode.cpp
+++ b/src/compiler/translator/IntermNode.cpp
@@ -1879,4 +1879,8 @@
         ASSERT(replaced);
         UNUSED_ASSERTION_VARIABLE(replaced);
     }
+
+    mInsertions.clear();
+    mReplacements.clear();
+    mMultiReplacements.clear();
 }
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index c5ddfd7..8b91f51 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -594,7 +594,10 @@
           postVisit(postVisit),
           rightToLeft(rightToLeft),
           mDepth(0),
-          mMaxDepth(0) {}
+          mMaxDepth(0),
+          mTemporaryIndex(nullptr)
+    {
+    }
     virtual ~TIntermTraverser() {}
 
     virtual void visitSymbol(TIntermSymbol *) {}
@@ -647,6 +650,9 @@
     // this function after traversal to perform them.
     void updateTree();
 
+    // Name temporary symbols from the counter at *temporaryIndex (not owned; must outlive its use here).
+    void useTemporaryIndex(unsigned int *temporaryIndex);
+
   protected:
     int mDepth;
     int mMaxDepth;
@@ -716,6 +722,15 @@
     // supported.
     void insertStatementsInParentBlock(const TIntermSequence &insertions);
 
+    // Create a symbol node for the current temporary (named from the index; qualifier EvqTemporary).
+    TIntermSymbol *createTempSymbol(const TType &type);
+    // Create a declaration of the current temporary, typed after and initialized with initializer.
+    TIntermAggregate *createTempInitDeclaration(TIntermTyped *initializer);
+    // Create a node that assigns rightNode to the current temporary symbol.
+    TIntermBinary *createTempAssignment(TIntermTyped *rightNode);
+    // Advance the temporary index so that later helper calls refer to a new temporary symbol.
+    void nextTemporaryIndex();
+
   private:
     struct ParentBlock
     {
@@ -730,6 +745,8 @@
     };
     // All the code blocks from the root to the current node's parent during traversal.
     std::vector<ParentBlock> mParentBlockStack;
+    // Counter used to name temporary symbols; not owned, null until useTemporaryIndex() is called.
+    unsigned int *mTemporaryIndex;
 };
 
 //
diff --git a/src/compiler/translator/IntermTraverse.cpp b/src/compiler/translator/IntermTraverse.cpp
index 2f401da..9d5d870 100644
--- a/src/compiler/translator/IntermTraverse.cpp
+++ b/src/compiler/translator/IntermTraverse.cpp
@@ -5,6 +5,7 @@
 //
 
 #include "compiler/translator/IntermNode.h"
+#include "compiler/translator/InfoSink.h"
 
 void TIntermTraverser::pushParentBlock(TIntermAggregate *node)
 {
@@ -35,6 +36,56 @@
     mInsertions.push_back(insert);
 }
 
+TIntermSymbol *TIntermTraverser::createTempSymbol(const TType &type)
+{
+    // Symbols created between nextTemporaryIndex() calls share the name "s<index>", i.e. one temporary.
+    TInfoSinkBase symbolNameOut;
+    ASSERT(mTemporaryIndex != nullptr);
+    symbolNameOut << "s" << (*mTemporaryIndex);
+    TString symbolName = symbolNameOut.c_str();
+    // The node's qualifier is forced to EvqTemporary whatever qualifier the incoming type carries.
+    TIntermSymbol *node = new TIntermSymbol(0, symbolName, type);
+    node->setInternal(true);
+    node->getTypePointer()->setQualifier(EvqTemporary);
+    return node;
+}
+
+// Declare the current temporary, typed after initializer, and initialize it with initializer.
+TIntermAggregate *TIntermTraverser::createTempInitDeclaration(TIntermTyped *initializer)
+{
+    ASSERT(initializer != nullptr);
+    TIntermSymbol *tempSymbol = createTempSymbol(initializer->getType());
+    TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
+    TIntermBinary *tempInit = new TIntermBinary(EOpInitialize);
+    tempInit->setLeft(tempSymbol);
+    tempInit->setRight(initializer);
+    tempInit->setType(tempSymbol->getType());  // the symbol's EvqTemporary type, not the initializer's qualifier
+    tempDeclaration->getSequence()->push_back(tempInit);
+    return tempDeclaration;
+}
+
+TIntermBinary *TIntermTraverser::createTempAssignment(TIntermTyped *rightNode)
+{
+    ASSERT(rightNode != nullptr);
+    TIntermSymbol *tempSymbol = createTempSymbol(rightNode->getType());  // the current temporary, typed after rightNode
+    TIntermBinary *assignment = new TIntermBinary(EOpAssign);
+    assignment->setLeft(tempSymbol);
+    assignment->setRight(rightNode);
+    assignment->setType(tempSymbol->getType());
+    return assignment;
+}
+
+void TIntermTraverser::useTemporaryIndex(unsigned int *temporaryIndex)
+{
+    mTemporaryIndex = temporaryIndex;  // not owned; must stay valid while temporaries are being created
+}
+
+void TIntermTraverser::nextTemporaryIndex()
+{
+    ASSERT(mTemporaryIndex != nullptr);
+    ++(*mTemporaryIndex);  // subsequent createTemp* calls refer to a new temporary name
+}
+
 //
 // Traverse the intermediate representation tree, and
 // call a node type specific function for each node.
diff --git a/src/compiler/translator/TranslatorHLSL.cpp b/src/compiler/translator/TranslatorHLSL.cpp
index 2c4224c..1fffa4c 100644
--- a/src/compiler/translator/TranslatorHLSL.cpp
+++ b/src/compiler/translator/TranslatorHLSL.cpp
@@ -25,8 +25,10 @@
 
     SeparateDeclarations(root);
 
+    unsigned int temporaryIndex = 0;
+
     // Note that SeparateDeclarations needs to be run before UnfoldShortCircuitToIf.
-    UnfoldShortCircuitToIf(root);
+    UnfoldShortCircuitToIf(root, &temporaryIndex);
 
     // Note that SeparateDeclarations needs to be run before SeparateArrayInitialization.
     SeparateArrayInitialization(root);
diff --git a/src/compiler/translator/UnfoldShortCircuitToIf.cpp b/src/compiler/translator/UnfoldShortCircuitToIf.cpp
index cd83210..dfaf388 100644
--- a/src/compiler/translator/UnfoldShortCircuitToIf.cpp
+++ b/src/compiler/translator/UnfoldShortCircuitToIf.cpp
@@ -10,7 +10,6 @@
 
 #include "compiler/translator/UnfoldShortCircuitToIf.h"
 
-#include "compiler/translator/InfoSink.h"
 #include "compiler/translator/IntermNode.h"
 
 namespace
@@ -31,60 +30,17 @@
     bool foundShortCircuit() const { return mFoundShortCircuit; }
 
   protected:
-    int mTemporaryIndex;
-
     // Marked to true once an operation that needs to be unfolded has been found.
     // After that, no more unfolding is performed on that traversal.
     bool mFoundShortCircuit;
-
-    TIntermSymbol *createTempSymbol(const TType &type);
-    TIntermAggregate *createTempInitDeclaration(const TType &type, TIntermTyped *initializer);
-    TIntermBinary *createTempAssignment(const TType &type, TIntermTyped *rightNode);
 };
 
 UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
-    : TIntermTraverser(true, true, true),
-      mTemporaryIndex(0),
+    : TIntermTraverser(true, false, true),  // pre- and post-visit only; no InVisit step is needed
       mFoundShortCircuit(false)
 {
 }
 
-TIntermSymbol *UnfoldShortCircuitTraverser::createTempSymbol(const TType &type)
-{
-    // Each traversal uses at most one temporary variable, so the index stays the same within a single traversal.
-    TInfoSinkBase symbolNameOut;
-    symbolNameOut << "s" << mTemporaryIndex;
-    TString symbolName = symbolNameOut.c_str();
-
-    TIntermSymbol *node = new TIntermSymbol(0, symbolName, type);
-    node->setInternal(true);
-    return node;
-}
-
-TIntermAggregate *UnfoldShortCircuitTraverser::createTempInitDeclaration(const TType &type, TIntermTyped *initializer)
-{
-    ASSERT(initializer != nullptr);
-    TIntermSymbol *tempSymbol = createTempSymbol(type);
-    TIntermAggregate *tempDeclaration = new TIntermAggregate(EOpDeclaration);
-    TIntermBinary *tempInit = new TIntermBinary(EOpInitialize);
-    tempInit->setLeft(tempSymbol);
-    tempInit->setRight(initializer);
-    tempInit->setType(type);
-    tempDeclaration->getSequence()->push_back(tempInit);
-    return tempDeclaration;
-}
-
-TIntermBinary *UnfoldShortCircuitTraverser::createTempAssignment(const TType &type, TIntermTyped *rightNode)
-{
-    ASSERT(rightNode != nullptr);
-    TIntermSymbol *tempSymbol = createTempSymbol(type);
-    TIntermBinary *assignment = new TIntermBinary(EOpAssign);
-    assignment->setLeft(tempSymbol);
-    assignment->setRight(rightNode);
-    assignment->setType(type);
-    return assignment;
-}
-
 bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
 {
     if (mFoundShortCircuit)
@@ -108,10 +64,12 @@
             TIntermSequence insertions;
             TType boolType(EbtBool, EbpUndefined, EvqTemporary);
 
-            insertions.push_back(createTempInitDeclaration(boolType, node->getLeft()));
+            ASSERT(node->getLeft()->getType() == boolType);
+            insertions.push_back(createTempInitDeclaration(node->getLeft()));
 
             TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
-            assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight()));
+            ASSERT(node->getRight()->getType() == boolType);
+            assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
 
             TIntermUnary *notTempSymbol = new TIntermUnary(EOpLogicalNot, boolType);
             notTempSymbol->setOperand(createTempSymbol(boolType));
@@ -133,10 +91,12 @@
             TIntermSequence insertions;
             TType boolType(EbtBool, EbpUndefined, EvqTemporary);
 
-            insertions.push_back(createTempInitDeclaration(boolType, node->getLeft()));
+            ASSERT(node->getLeft()->getType() == boolType);
+            insertions.push_back(createTempInitDeclaration(node->getLeft()));
 
             TIntermAggregate *assignRightBlock = new TIntermAggregate(EOpSequence);
-            assignRightBlock->getSequence()->push_back(createTempAssignment(boolType, node->getRight()));
+            ASSERT(node->getRight()->getType() == boolType);
+            assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
 
             TIntermSelection *ifNode = new TIntermSelection(createTempSymbol(boolType), assignRightBlock, nullptr);
             insertions.push_back(ifNode);
@@ -169,11 +129,11 @@
         insertions.push_back(tempDeclaration);
 
         TIntermAggregate *trueBlock = new TIntermAggregate(EOpSequence);
-        TIntermBinary *trueAssignment = createTempAssignment(node->getType(), node->getTrueBlock()->getAsTyped());
+        TIntermBinary *trueAssignment = createTempAssignment(node->getTrueBlock()->getAsTyped());
         trueBlock->getSequence()->push_back(trueAssignment);
 
         TIntermAggregate *falseBlock = new TIntermAggregate(EOpSequence);
-        TIntermBinary *falseAssignment = createTempAssignment(node->getType(), node->getFalseBlock()->getAsTyped());
+        TIntermBinary *falseAssignment = createTempAssignment(node->getFalseBlock()->getAsTyped());
         falseBlock->getSequence()->push_back(falseAssignment);
 
         TIntermSelection *ifNode = new TIntermSelection(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
@@ -235,17 +195,16 @@
 void UnfoldShortCircuitTraverser::nextIteration()
 {
     mFoundShortCircuit = false;
-    mTemporaryIndex++;
-    mReplacements.clear();
-    mMultiReplacements.clear();
-    mInsertions.clear();
+    nextTemporaryIndex();  // next unfolding gets a fresh temporary; updateTree() now clears queued edits
 }
 
 } // namespace
 
-void UnfoldShortCircuitToIf(TIntermNode *root)
+void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex)
 {
     UnfoldShortCircuitTraverser traverser;
+    ASSERT(temporaryIndex != nullptr);
+    traverser.useTemporaryIndex(temporaryIndex);
     // Unfold one operator at a time, and reset the traverser between iterations.
     do
     {
diff --git a/src/compiler/translator/UnfoldShortCircuitToIf.h b/src/compiler/translator/UnfoldShortCircuitToIf.h
index 1d67b00..0fe37b7 100644
--- a/src/compiler/translator/UnfoldShortCircuitToIf.h
+++ b/src/compiler/translator/UnfoldShortCircuitToIf.h
@@ -13,6 +13,6 @@
 
 class TIntermNode;
 
-void UnfoldShortCircuitToIf(TIntermNode *root);
+void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex);  // temporaryIndex: non-null temp-name counter
 
 #endif   // COMPILER_TRANSLATOR_UNFOLDSHORTCIRCUIT_H_