Implement discontinuous loops AST analysis

This will allow narrowing down which usages of
[[flatten]] and [[unroll]] are actually useful.

BUG=angleproject:937
BUG=395048

Change-Id: I091e647e3053d22edadd0cabb7c50bd5efa690b2
Reviewed-on: https://chromium-review.googlesource.com/263776
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ASTMetadataHLSL.cpp b/src/compiler/translator/ASTMetadataHLSL.cpp
index 9580667..910cfdc 100644
--- a/src/compiler/translator/ASTMetadataHLSL.cpp
+++ b/src/compiler/translator/ASTMetadataHLSL.cpp
@@ -138,6 +138,201 @@
     std::vector<TIntermNode*> mParents;
 };
 
+// Traverses the AST of one function definition, assuming the callees of that function were
+// already traversed; records the loops made discontinuous by break/continue/return and the
+// if statements that contain a discontinuous loop, directly or through a called function.
+class PullComputeDiscontinuousLoops : public TIntermTraverser
+{
+  public:
+    PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
+        : TIntermTraverser(true, false, true),
+          mMetadataList(metadataList),
+          mMetadata(&(*metadataList)[index]),
+          mIndex(index),
+          mDag(dag)
+    {
+    }
+
+    void traverse(TIntermAggregate *node)
+    {
+        node->traverse(this);
+        ASSERT(mLoops.empty());
+        ASSERT(mIfs.empty());
+    }
+
+    // Called when a discontinuous loop, or a call to a function that has a discontinuous
+    // loop in its call graph, is found while traversing this function.
+    void onDiscontinuousLoop()
+    {
+        mMetadata->mHasDiscontinuousLoopInCallGraph = true;
+        // Mark the innermost enclosing if (when there is one) as containing such a loop.
+        if (!mIfs.empty())
+        {
+            mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+        }
+    }
+
+    bool visitLoop(Visit visit, TIntermLoop *loop) override
+    {
+        if (visit == PreVisit)
+        {
+            mLoops.push_back(loop);
+        }
+        else if (visit == PostVisit)
+        {
+            ASSERT(mLoops.back() == loop);
+            mLoops.pop_back();
+        }
+
+        return true;
+    }
+
+    bool visitSelection(Visit visit, TIntermSelection *node) override
+    {
+        if (visit == PreVisit)
+        {
+            mIfs.push_back(node);
+        }
+        else if (visit == PostVisit)
+        {
+            ASSERT(mIfs.back() == node);
+            mIfs.pop_back();
+            // An if containing a discontinuous loop makes its parent if contain one as well.
+            if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty())
+            {
+                mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+            }
+        }
+
+        return true;
+    }
+
+    bool visitBranch(Visit visit, TIntermBranch *node) override
+    {
+        if (visit == PreVisit)
+        {
+            switch (node->getFlowOp())
+            {
+              case EOpKill:
+                break;
+              case EOpBreak:
+              case EOpContinue:
+                ASSERT(!mLoops.empty());
+                mMetadata->mDiscontinuousLoops.insert(mLoops.back());
+                onDiscontinuousLoop();
+                break;
+              case EOpReturn:
+                // A return jumps out of all the enclosing loops, so every one is discontinuous.
+                if (!mLoops.empty())
+                {
+                    for (TIntermLoop* loop : mLoops)
+                    {
+                        mMetadata->mDiscontinuousLoops.insert(loop);
+                    }
+                    onDiscontinuousLoop();
+                }
+                break;
+              default:
+                UNREACHABLE();
+            }
+        }
+
+        return true;
+    }
+
+    bool visitAggregate(Visit visit, TIntermAggregate *node) override
+    {
+        if (visit == PreVisit && node->getOp() == EOpFunctionCall)
+        {
+            if (node->isUserDefined())
+            {
+                size_t calleeIndex = mDag.findIndex(node);
+                ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
+
+                if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph)
+                {
+                    onDiscontinuousLoop();
+                }
+            }
+        }
+
+        return true;
+    }
+
+  private:
+    MetadataList *mMetadataList;
+    ASTMetadataHLSL *mMetadata;
+    size_t mIndex;
+    const CallDAG &mDag;
+
+    std::vector<TIntermLoop*> mLoops;  // stack of the loops currently being traversed
+    std::vector<TIntermSelection*> mIfs;  // stack of the ifs currently being traversed
+};
+
+// Tags the functions called from inside a discontinuous loop, directly or via a tagged caller.
+class PushDiscontinuousLoops : public TIntermTraverser
+{
+  public:
+    PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
+        : TIntermTraverser(true, true, true),
+          mMetadataList(metadataList),
+          mMetadata(&(*metadataList)[index]),
+          mIndex(index),
+          mDag(dag),
+          mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
+    {
+    }
+
+    void traverse(TIntermAggregate *node)
+    {
+        node->traverse(this);
+        ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
+    }
+
+    bool visitLoop(Visit visit, TIntermLoop *loop) override
+    {
+        bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
+
+        if (visit == PreVisit && isDiscontinuous)
+        {
+            mNestedDiscont++;
+        }
+        else if (visit == PostVisit && isDiscontinuous)
+        {
+            mNestedDiscont--;
+        }
+
+        return true;
+    }
+
+    bool visitAggregate(Visit visit, TIntermAggregate *node) override
+    {
+        switch (node->getOp())
+        {
+          case EOpFunctionCall:
+            if (visit == PreVisit && node->isUserDefined() && mNestedDiscont > 0)
+            {
+                size_t calleeIndex = mDag.findIndex(node);
+                ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
+
+                (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
+            }
+            break;
+          default:
+            break;
+        }
+        return true;
+    }
+
+  private:
+    MetadataList *mMetadataList;
+    ASTMetadataHLSL *mMetadata;
+    size_t mIndex;
+    const CallDAG &mDag;
+
+    int mNestedDiscont;  // > 0 while inside a discontinuous loop or a function called from one
+};
+
 }
 
 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
@@ -150,6 +345,11 @@
     return mControlFlowsContainingGradient.count(node) > 0;
 }
 
+bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node)
+{
+    return mIfsContainingDiscontinuousLoop.count(node) > 0;  // populated by the pull traversal
+}
+
 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
 {
     MetadataList metadataList(callDag.size());
@@ -177,5 +377,35 @@
         pull.traverse(callDag.getRecordFromIndex(i).node);
     }
 
+    // Compute which loops are discontinuous and which functions are called in
+    // these loops. The same way computing gradient usage is a "pull" process,
+    // computing "being used in a discont. loop" is a push process. However we also
+    // need to know what ifs have a discontinuous loop inside so we do the same type
+    // of callgraph analysis as for the gradient.
+
+    // First compute which loops are discontinuous (no specific order) and pull the ifs and
+    // functions using a discontinuous loop (callees, at lower indices, are processed first).
+    for (size_t i = 0; i < callDag.size(); i++)
+    {
+        PullComputeDiscontinuousLoops pull(&metadataList, i, callDag);
+        pull.traverse(callDag.getRecordFromIndex(i).node);
+    }
+
+    // Then push the information to callees, either from a local discontinuous
+    // loop or from the caller being called in a discontinuous loop already.
+    for (size_t i = callDag.size(); i-- > 0;)
+    {
+        PushDiscontinuousLoops push(&metadataList, i, callDag);
+        push.traverse(callDag.getRecordFromIndex(i).node);
+    }
+
+    // We create "Lod0" versions of functions with the gradient operations replaced
+    // by non-gradient operations so that the D3D compiler is happier with discontinuous
+    // loops.
+    for (auto &metadata : metadataList)
+    {
+        metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
+    }
+
     return metadataList;
 }