Make AST path always include the current node being traversed

AST traversers sometimes call traverse() functions manually during
PreVisit. Change TIntermTraverser so that even when this happens, every
node is automatically added to the traversal path, instead of each
individual AST traverser having to add nodes to the path manually.

This also makes getParentNode() return the correct node during
InVisit.

In some cases where nodes are traversed repeatedly, as in OutputHLSL,
this causes the same node to be added to the traversal path twice, but
this should not have adverse side effects. The more common case is that
traverse() is called on the children of the node currently being
traversed.

This fixes a bug in OVR_multiview validation, which previously did not
call incrementDepth and decrementDepth where it should have.

BUG=angleproject:1725
TEST=angle_unittests, angle_end2end_tests

Change-Id: I6ae762eef760509ebe853eefa37dac28c16e7a9b
Reviewed-on: https://chromium-review.googlesource.com/430732
Commit-Queue: Olli Etuaho <oetuaho@nvidia.com>
Reviewed-by: Jamie Madill <jmadill@chromium.org>
diff --git a/src/compiler/translator/IntermNode.h b/src/compiler/translator/IntermNode.h
index 3f935b6..454a4a3 100644
--- a/src/compiler/translator/IntermNode.h
+++ b/src/compiler/translator/IntermNode.h
@@ -962,6 +962,7 @@
     void useTemporaryIndex(unsigned int *temporaryIndex);
 
   protected:
+    // Should only be called from traverse*() functions
     void incrementDepth(TIntermNode *current)
     {
         mDepth++;
@@ -969,20 +970,35 @@
         mPath.push_back(current);
     }
 
+    // Should only be called from traverse*() functions
     void decrementDepth()
     {
         mDepth--;
         mPath.pop_back();
     }
 
-    TIntermNode *getParentNode() { return mPath.size() == 0 ? NULL : mPath.back(); }
+    // RAII helper for incrementDepth/decrementDepth
+    class ScopedNodeInTraversalPath
+    {
+      public:
+        ScopedNodeInTraversalPath(TIntermTraverser *traverser, TIntermNode *current)
+            : mTraverser(traverser)
+        {
+            mTraverser->incrementDepth(current);
+        }
+        ~ScopedNodeInTraversalPath() { mTraverser->decrementDepth(); }
+      private:
+        TIntermTraverser *mTraverser;
+    };
+
+    TIntermNode *getParentNode() { return mPath.size() <= 1 ? nullptr : mPath[mPath.size() - 2u]; }
 
     // Return the nth ancestor of the node being traversed. getAncestorNode(0) == getParentNode()
     TIntermNode *getAncestorNode(unsigned int n)
     {
-        if (mPath.size() > n)
+        if (mPath.size() > n + 1u)
         {
-            return mPath[mPath.size() - n - 1u];
+            return mPath[mPath.size() - n - 2u];
         }
         return nullptr;
     }
@@ -991,11 +1007,6 @@
     void incrementParentBlockPos();
     void popParentBlock();
 
-    bool parentNodeIsBlock()
-    {
-        return !mParentBlockStack.empty() && getParentNode() == mParentBlockStack.back().node;
-    }
-
     // To replace a single node with multiple nodes on the parent aggregate node
     struct NodeReplaceWithMultipleEntry
     {
@@ -1086,9 +1097,6 @@
     int mDepth;
     int mMaxDepth;
 
-    // All the nodes from root to the current node's parent during traversing.
-    TVector<TIntermNode *> mPath;
-
     bool mInGlobalScope;
 
     // During traversing, save all the changes that need to happen into
@@ -1131,6 +1139,9 @@
 
     std::vector<NodeUpdateEntry> mReplacements;
 
+    // All the nodes from root to the current node during traversing.
+    TVector<TIntermNode *> mPath;
+
     // All the code blocks from the root to the current node's parent during traversal.
     std::vector<ParentBlock> mParentBlockStack;
 
diff --git a/src/compiler/translator/IntermTraverse.cpp b/src/compiler/translator/IntermTraverse.cpp
index 36f8f5a..b5d42cc 100644
--- a/src/compiler/translator/IntermTraverse.cpp
+++ b/src/compiler/translator/IntermTraverse.cpp
@@ -105,7 +105,7 @@
     : preVisit(preVisit),
       inVisit(inVisit),
       postVisit(postVisit),
-      mDepth(0),
+      mDepth(-1),
       mMaxDepth(0),
       mInGlobalScope(true),
       mTemporaryIndex(nullptr)
@@ -142,8 +142,16 @@
                                                      const TIntermSequence &insertionsAfter)
 {
     ASSERT(!mParentBlockStack.empty());
-    NodeInsertMultipleEntry insert(mParentBlockStack.back().node, mParentBlockStack.back().pos,
-                                   insertionsBefore, insertionsAfter);
+    const ParentBlock *parentBlock = &mParentBlockStack.back();
+    if (mPath.back() == parentBlock->node)
+    {
+        ASSERT(mParentBlockStack.size() >= 2u);
+        // The current node is a block node, so the parent block is the entry below the topmost
+        // one in the block stack. Re-point (never assign through) so the stack is not clobbered.
+        parentBlock = &mParentBlockStack.at(mParentBlockStack.size() - 2u);
+    }
+    NodeInsertMultipleEntry insert(parentBlock->node, parentBlock->pos, insertionsBefore,
+                                   insertionsAfter);
     mInsertions.push_back(insert);
 }
 
@@ -264,16 +272,20 @@
 //
 void TIntermTraverser::traverseSymbol(TIntermSymbol *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
     visitSymbol(node);
 }
 
 void TIntermTraverser::traverseConstantUnion(TIntermConstantUnion *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
     visitConstantUnion(node);
 }
 
 void TIntermTraverser::traverseSwizzle(TIntermSwizzle *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -281,11 +293,7 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         node->getOperand()->traverse(this);
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -297,6 +305,8 @@
 //
 void TIntermTraverser::traverseBinary(TIntermBinary *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     //
@@ -310,8 +320,6 @@
     //
     if (visit)
     {
-        incrementDepth(node);
-
         if (node->getLeft())
             node->getLeft()->traverse(this);
 
@@ -320,8 +328,6 @@
 
         if (visit && node->getRight())
             node->getRight()->traverse(this);
-
-        decrementDepth();
     }
 
     //
@@ -334,6 +340,8 @@
 
 void TLValueTrackingTraverser::traverseBinary(TIntermBinary *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     //
@@ -347,8 +355,6 @@
     //
     if (visit)
     {
-        incrementDepth(node);
-
         // Some binary operations like indexing can be inside an expression which must be an
         // l-value.
         bool parentOperatorRequiresLValue     = operatorRequiresLValue();
@@ -383,8 +389,6 @@
 
         setOperatorRequiresLValue(parentOperatorRequiresLValue);
         setInFunctionCallOutParameter(parentInFunctionCallOutParameter);
-
-        decrementDepth();
     }
 
     //
@@ -400,6 +404,8 @@
 //
 void TIntermTraverser::traverseUnary(TIntermUnary *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -407,11 +413,7 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         node->getOperand()->traverse(this);
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -420,6 +422,8 @@
 
 void TLValueTrackingTraverser::traverseUnary(TIntermUnary *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -427,8 +431,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         ASSERT(!operatorRequiresLValue());
         switch (node->getOp())
         {
@@ -445,8 +447,6 @@
         node->getOperand()->traverse(this);
 
         setOperatorRequiresLValue(false);
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -456,6 +456,8 @@
 // Traverse a function definition node.
 void TIntermTraverser::traverseFunctionDefinition(TIntermFunctionDefinition *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -463,7 +465,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
         mInGlobalScope = false;
 
         node->getFunctionPrototype()->traverse(this);
@@ -472,7 +473,6 @@
         node->getBody()->traverse(this);
 
         mInGlobalScope = true;
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -482,6 +482,9 @@
 // Traverse a block node.
 void TIntermTraverser::traverseBlock(TIntermBlock *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+    pushParentBlock(node);
+
     bool visit = true;
 
     TIntermSequence *sequence = node->getSequence();
@@ -491,9 +494,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-        pushParentBlock(node);
-
         for (auto *child : *sequence)
         {
             child->traverse(this);
@@ -505,17 +505,18 @@
 
             incrementParentBlockPos();
         }
-
-        popParentBlock();
-        decrementDepth();
     }
 
     if (visit && postVisit)
         visitBlock(PostVisit, node);
+
+    popParentBlock();
 }
 
 void TIntermTraverser::traverseInvariantDeclaration(TIntermInvariantDeclaration *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -536,6 +537,8 @@
 // Traverse a declaration node.
 void TIntermTraverser::traverseDeclaration(TIntermDeclaration *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     TIntermSequence *sequence = node->getSequence();
@@ -545,8 +548,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         for (auto *child : *sequence)
         {
             child->traverse(this);
@@ -556,8 +557,6 @@
                     visit = visitDeclaration(InVisit, node);
             }
         }
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -566,6 +565,8 @@
 
 void TIntermTraverser::traverseFunctionPrototype(TIntermFunctionPrototype *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     TIntermSequence *sequence = node->getSequence();
@@ -575,8 +576,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         for (auto *child : *sequence)
         {
             child->traverse(this);
@@ -586,8 +585,6 @@
                     visit = visitFunctionPrototype(InVisit, node);
             }
         }
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -597,6 +594,8 @@
 // Traverse an aggregate node.  Same comments in binary node apply here.
 void TIntermTraverser::traverseAggregate(TIntermAggregate *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     TIntermSequence *sequence = node->getSequence();
@@ -606,8 +605,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         for (auto *child : *sequence)
         {
             child->traverse(this);
@@ -617,8 +614,6 @@
                     visit = visitAggregate(InVisit, node);
             }
         }
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -635,6 +630,8 @@
 
 void TLValueTrackingTraverser::traverseAggregate(TIntermAggregate *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     TIntermSequence *sequence = node->getSequence();
@@ -656,8 +653,6 @@
             }
         }
 
-        incrementDepth(node);
-
         if (inFunctionMap)
         {
             TIntermSequence *params             = getFunctionParameters(node);
@@ -728,8 +723,6 @@
 
             setInFunctionCallOutParameter(false);
         }
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -741,6 +734,8 @@
 //
 void TIntermTraverser::traverseTernary(TIntermTernary *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -748,13 +743,11 @@
 
     if (visit)
     {
-        incrementDepth(node);
         node->getCondition()->traverse(this);
         if (node->getTrueExpression())
             node->getTrueExpression()->traverse(this);
         if (node->getFalseExpression())
             node->getFalseExpression()->traverse(this);
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -764,6 +757,8 @@
 // Traverse an if-else node.  Same comments in binary node apply here.
 void TIntermTraverser::traverseIfElse(TIntermIfElse *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -771,13 +766,11 @@
 
     if (visit)
     {
-        incrementDepth(node);
         node->getCondition()->traverse(this);
         if (node->getTrueBlock())
             node->getTrueBlock()->traverse(this);
         if (node->getFalseBlock())
             node->getFalseBlock()->traverse(this);
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -789,6 +782,8 @@
 //
 void TIntermTraverser::traverseSwitch(TIntermSwitch *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -796,13 +791,11 @@
 
     if (visit)
     {
-        incrementDepth(node);
         node->getInit()->traverse(this);
         if (inVisit)
             visit = visitSwitch(InVisit, node);
         if (visit && node->getStatementList())
             node->getStatementList()->traverse(this);
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -814,6 +807,8 @@
 //
 void TIntermTraverser::traverseCase(TIntermCase *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -821,9 +816,7 @@
 
     if (visit && node->getCondition())
     {
-        incrementDepth(node);
         node->getCondition()->traverse(this);
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -835,6 +828,8 @@
 //
 void TIntermTraverser::traverseLoop(TIntermLoop *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -842,8 +837,6 @@
 
     if (visit)
     {
-        incrementDepth(node);
-
         if (node->getInit())
             node->getInit()->traverse(this);
 
@@ -855,8 +848,6 @@
 
         if (node->getExpression())
             node->getExpression()->traverse(this);
-
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -868,6 +859,8 @@
 //
 void TIntermTraverser::traverseBranch(TIntermBranch *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
+
     bool visit = true;
 
     if (preVisit)
@@ -875,9 +868,7 @@
 
     if (visit && node->getExpression())
     {
-        incrementDepth(node);
         node->getExpression()->traverse(this);
-        decrementDepth();
     }
 
     if (visit && postVisit)
@@ -886,6 +877,7 @@
 
 void TIntermTraverser::traverseRaw(TIntermRaw *node)
 {
+    ScopedNodeInTraversalPath addToPath(this, node);
     visitRaw(node);
 }
 
diff --git a/src/compiler/translator/OutputGLSLBase.cpp b/src/compiler/translator/OutputGLSLBase.cpp
index d71c857..a2cef1c 100644
--- a/src/compiler/translator/OutputGLSLBase.cpp
+++ b/src/compiler/translator/OutputGLSLBase.cpp
@@ -761,7 +761,6 @@
     node->getCondition()->traverse(this);
     out << ")\n";
 
-    incrementDepth(node);
     visitCodeBlock(node->getTrueBlock());
 
     if (node->getFalseBlock())
@@ -769,7 +768,6 @@
         out << "else\n";
         visitCodeBlock(node->getFalseBlock());
     }
-    decrementDepth();
     return false;
 }
 
@@ -812,7 +810,6 @@
         out << "{\n";
     }
 
-    incrementDepth(node);
     for (TIntermSequence::const_iterator iter = node->getSequence()->begin();
          iter != node->getSequence()->end(); ++iter)
     {
@@ -823,7 +820,6 @@
         if (isSingleStatement(curNode))
             out << ";\n";
     }
-    decrementDepth();
 
     // Scope the blocks except when at the global scope.
     if (mDepth > 0)
@@ -835,11 +831,9 @@
 
 bool TOutputGLSLBase::visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node)
 {
-    incrementDepth(node);
     TIntermFunctionPrototype *prototype = node->getFunctionPrototype();
     prototype->traverse(this);
     visitCodeBlock(node->getBody());
-    decrementDepth();
 
     // Fully processed; no need to visit children.
     return false;
@@ -986,8 +980,6 @@
 {
     TInfoSinkBase &out = objSink();
 
-    incrementDepth(node);
-
     TLoopType loopType = node->getType();
 
     if (loopType == ELoopFor)  // for loop
@@ -1029,8 +1021,6 @@
         out << ");\n";
     }
 
-    decrementDepth();
-
     // No need to visit children. They have been already processed in
     // this function.
     return false;
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index 25d46f8..1ce25cf 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -919,11 +919,9 @@
     }
 }
 
-bool OutputHLSL::ancestorEvaluatesToSamplerInStruct(Visit visit)
+bool OutputHLSL::ancestorEvaluatesToSamplerInStruct()
 {
-    // Inside InVisit the current node is already in the path.
-    const unsigned int initialN = visit == InVisit ? 1u : 0u;
-    for (unsigned int n = initialN; getAncestorNode(n) != nullptr; ++n)
+    for (unsigned int n = 0u; getAncestorNode(n) != nullptr; ++n)
     {
         TIntermNode *ancestor               = getAncestorNode(n);
         const TIntermBinary *ancestorBinary = ancestor->getAsBinaryNode();
@@ -1126,7 +1124,7 @@
                     return false;
                 }
             }
-            else if (ancestorEvaluatesToSamplerInStruct(visit))
+            else if (ancestorEvaluatesToSamplerInStruct())
             {
                 // All parts of an expression that access a sampler in a struct need to use _ as
                 // separator to access the sampler variable that has been moved out of the struct.
@@ -1163,7 +1161,7 @@
             {
                 // All parts of an expression that access a sampler in a struct need to use _ as
                 // separator to access the sampler variable that has been moved out of the struct.
-                indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct(visit);
+                indexingReturnsSampler = ancestorEvaluatesToSamplerInStruct();
             }
             if (visit == InVisit)
             {
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index 46a721f..b1ccbc5 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -239,7 +239,7 @@
 
   private:
     TString samplerNamePrefixFromStruct(TIntermTyped *node);
-    bool ancestorEvaluatesToSamplerInStruct(Visit visit);
+    bool ancestorEvaluatesToSamplerInStruct();
 };
 }
 
diff --git a/src/compiler/translator/SimplifyLoopConditions.cpp b/src/compiler/translator/SimplifyLoopConditions.cpp
index 68fdb03..a2ed7f3 100644
--- a/src/compiler/translator/SimplifyLoopConditions.cpp
+++ b/src/compiler/translator/SimplifyLoopConditions.cpp
@@ -135,9 +135,7 @@
 
     // Mark that we're inside a loop condition or expression, and transform the loop if needed.
 
-    incrementDepth(node);
-
-    // Note: No need to traverse the loop init node.
+    ScopedNodeInTraversalPath addToPath(this, node);
 
     mInsideLoopInitConditionOrExpression = true;
     TLoopType loopType                   = node->getType();
@@ -274,8 +272,7 @@
                 ELoopWhile, nullptr, createTempSymbol(conditionInitializer->getType()), nullptr,
                 whileLoopBody);
             loopScope->getSequence()->push_back(whileLoop);
-            queueReplacementWithParent(getAncestorNode(1), node, loopScope,
-                                       OriginalNode::IS_DROPPED);
+            queueReplacement(node, loopScope, OriginalNode::IS_DROPPED);
         }
     }
 
@@ -283,8 +280,6 @@
 
     if (!mFoundLoopToChange && node->getBody())
         node->getBody()->traverse(this);
-
-    decrementDepth();
 }
 
 }  // namespace