Implement discontinuous loops AST analysis
This will allow narrowing down which usages of the HLSL
[flatten] and [unroll] attributes are actually useful.
BUG=angleproject:937
BUG=395048
Change-Id: I091e647e3053d22edadd0cabb7c50bd5efa690b2
Reviewed-on: https://chromium-review.googlesource.com/263776
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ASTMetadataHLSL.cpp b/src/compiler/translator/ASTMetadataHLSL.cpp
index 9580667..910cfdc 100644
--- a/src/compiler/translator/ASTMetadataHLSL.cpp
+++ b/src/compiler/translator/ASTMetadataHLSL.cpp
@@ -138,6 +138,201 @@
std::vector<TIntermNode*> mParents;
};
+// Traverses the AST of a function definition and computes which of its loops are
+// discontinuous (left early by a break, continue or return) and which of its if
+// statements contain a discontinuous loop in their call graph.
+// This traversal must already have been run on every callee of the function (callees
+// have a lower CallDAG index, see the ASSERT in visitAggregate) so that the callees'
+// mHasDiscontinuousLoopInCallGraph flags are final when their call sites are visited.
+class PullComputeDiscontinuousLoops : public TIntermTraverser
+{
+  public:
+    PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
+        : TIntermTraverser(true, false, true),
+          mMetadataList(metadataList),
+          mMetadata(&(*metadataList)[index]),
+          mIndex(index),
+          mDag(dag)
+    {
+    }
+
+    // Traverses the function definition rooted at |node|. The loop and if stacks must be
+    // balanced again once the traversal is over.
+    void traverse(TIntermAggregate *node)
+    {
+        node->traverse(this);
+        ASSERT(mLoops.empty());
+        ASSERT(mIfs.empty());
+    }
+
+    // Called when a discontinuous loop or a call to a function with a discontinuous loop
+    // in its call graph is found.
+    void onDiscontinuousLoop()
+    {
+        mMetadata->mHasDiscontinuousLoopInCallGraph = true;
+        // Mark the innermost enclosing if; visitSelection propagates the mark to the outer
+        // ifs as they are left.
+        // NOTE(review): when the branch sits in an if nested inside the loop, that inner if
+        // is the one marked, even though it does not itself contain the loop.
+        if (!mIfs.empty())
+        {
+            mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+        }
+    }
+
+    // Maintains the stack of loops enclosing the node being visited.
+    bool visitLoop(Visit visit, TIntermLoop *loop) override
+    {
+        if (visit == PreVisit)
+        {
+            mLoops.push_back(loop);
+        }
+        else if (visit == PostVisit)
+        {
+            ASSERT(mLoops.back() == loop);
+            mLoops.pop_back();
+        }
+
+        return true;
+    }
+
+    // Maintains the stack of enclosing ifs and propagates the "contains a discontinuous
+    // loop" mark from an if to its parent if when the traversal leaves it.
+    bool visitSelection(Visit visit, TIntermSelection *node) override
+    {
+        if (visit == PreVisit)
+        {
+            mIfs.push_back(node);
+        }
+        else if (visit == PostVisit)
+        {
+            ASSERT(mIfs.back() == node);
+            mIfs.pop_back();
+            // An if using a discontinuous loop means its parent ifs are also discontinuous.
+            if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty())
+            {
+                mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+            }
+        }
+
+        return true;
+    }
+
+    // Records the loops that a break, continue or return makes discontinuous.
+    bool visitBranch(Visit visit, TIntermBranch *node) override
+    {
+        if (visit == PreVisit)
+        {
+            switch (node->getFlowOp())
+            {
+                case EOpKill:
+                    // NOTE(review): unlike return, discard does not mark the enclosing loops
+                    // as discontinuous here; confirm this is intended.
+                    break;
+                case EOpBreak:
+                case EOpContinue:
+                    // NOTE(review): the branch is always attributed to the innermost loop. A
+                    // break leaving a switch statement would be mis-attributed (or trip this
+                    // ASSERT outside of any loop) if switches can reach this traverser.
+                    ASSERT(!mLoops.empty());
+                    mMetadata->mDiscontinuousLoops.insert(mLoops.back());
+                    onDiscontinuousLoop();
+                    break;
+                case EOpReturn:
+                    // A return jumps out of all the enclosing loops.
+                    if (!mLoops.empty())
+                    {
+                        for (TIntermLoop *loop : mLoops)
+                        {
+                            mMetadata->mDiscontinuousLoops.insert(loop);
+                        }
+                        onDiscontinuousLoop();
+                    }
+                    break;
+                default:
+                    UNREACHABLE();
+            }
+        }
+
+        return true;
+    }
+
+    // A call to a user-defined function whose call graph contains a discontinuous loop is
+    // treated like a discontinuous loop at the call site.
+    bool visitAggregate(Visit visit, TIntermAggregate *node) override
+    {
+        if (visit == PreVisit && node->getOp() == EOpFunctionCall)
+        {
+            if (node->isUserDefined())
+            {
+                size_t calleeIndex = mDag.findIndex(node);
+                // Callees are traversed before their callers, so the flag read below is final.
+                ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
+
+                if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph)
+                {
+                    onDiscontinuousLoop();
+                }
+            }
+        }
+
+        return true;
+    }
+
+  private:
+    MetadataList *mMetadataList;
+    // Metadata of the function being traversed, i.e. &(*mMetadataList)[mIndex].
+    ASTMetadataHLSL *mMetadata;
+    // Index of the function being traversed in the CallDAG and in *mMetadataList.
+    size_t mIndex;
+    const CallDAG &mDag;
+
+    // Stacks of the loops and if statements enclosing the node being visited.
+    std::vector<TIntermLoop *> mLoops;
+    std::vector<TIntermSelection *> mIfs;
+};
+
+// Tags all the functions called in a discontinuous loop. A user-defined callee is tagged
+// when it is called from inside a discontinuous loop of the traversed function, or from
+// anywhere in the traversed function if that function is itself tagged. Callers must be
+// traversed before their callees for the tag to propagate through the whole call graph.
+class PushDiscontinuousLoops : public TIntermTraverser
+{
+  public:
+    PushDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
+        : TIntermTraverser(true, true, true),
+          mMetadataList(metadataList),
+          mMetadata(&(*metadataList)[index]),
+          mIndex(index),
+          mDag(dag),
+          // A function called in a discontinuous loop behaves as if its whole body were
+          // nested in one. mMetadata is declared (hence initialized) before mNestedDiscont.
+          mNestedDiscont(mMetadata->mCalledInDiscontinuousLoop ? 1 : 0)
+    {
+    }
+
+    // Traverses the function definition rooted at |node|; the nesting count must be back
+    // to its initial value once the traversal is over.
+    void traverse(TIntermAggregate *node)
+    {
+        node->traverse(this);
+        ASSERT(mNestedDiscont == (mMetadata->mCalledInDiscontinuousLoop ? 1 : 0));
+    }
+
+    // Counts the discontinuous loops enclosing the node being visited.
+    bool visitLoop(Visit visit, TIntermLoop *loop) override
+    {
+        bool isDiscontinuous = mMetadata->mDiscontinuousLoops.count(loop) > 0;
+
+        if (visit == PreVisit && isDiscontinuous)
+        {
+            mNestedDiscont++;
+        }
+        else if (visit == PostVisit && isDiscontinuous)
+        {
+            mNestedDiscont--;
+        }
+
+        return true;
+    }
+
+    // Tags the user-defined callee of a function call made inside a discontinuous loop.
+    bool visitAggregate(Visit visit, TIntermAggregate *node) override
+    {
+        if (node->getOp() == EOpFunctionCall && visit == PreVisit && node->isUserDefined() &&
+            mNestedDiscont > 0)
+        {
+            size_t calleeIndex = mDag.findIndex(node);
+            // Callees have a lower index than their callers; the driver traverses functions
+            // in decreasing index order, so the callee sees this tag when it is traversed.
+            ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
+
+            (*mMetadataList)[calleeIndex].mCalledInDiscontinuousLoop = true;
+        }
+        return true;
+    }
+
+  private:
+    MetadataList *mMetadataList;
+    // Metadata of the function being traversed, i.e. &(*mMetadataList)[mIndex].
+    ASTMetadataHLSL *mMetadata;
+    // Index of the function being traversed in the CallDAG and in *mMetadataList.
+    size_t mIndex;
+    const CallDAG &mDag;
+
+    // Number of discontinuous loops enclosing the current node, plus one when the
+    // traversed function is itself called in a discontinuous loop.
+    int mNestedDiscont;
+};
+
}
bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
@@ -150,6 +345,11 @@
return mControlFlowsContainingGradient.count(node) > 0;
}
+// Returns true if |node| was recorded as an if statement that has a discontinuous
+// loop in its call graph.
+bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node)
+{
+    const auto found = mIfsContainingDiscontinuousLoop.find(node);
+    return found != mIfsContainingDiscontinuousLoop.end();
+}
+
MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
{
MetadataList metadataList(callDag.size());
@@ -177,5 +377,35 @@
pull.traverse(callDag.getRecordFromIndex(i).node);
}
+ // Compute which loops are discontinuous and which functions are called in
+ // these loops. Whereas computing gradient usage is a "pull" process,
+ // computing "being used in a discontinuous loop" is a "push" process. However,
+ // we also need to know which ifs have a discontinuous loop inside, so we do the
+ // same type of call graph analysis as for the gradient.
+
+ // First compute which loops are discontinuous (no specific order) and pull
+ // the ifs and functions using a discontinuous loop.
+ for (size_t i = 0; i < callDag.size(); i++)
+ {
+ PullComputeDiscontinuousLoops pull(&metadataList, i, callDag);
+ pull.traverse(callDag.getRecordFromIndex(i).node);
+ }
+
+ // Then push the information to callees, either from a local discontinuous
+ // loop or from the caller already being called in a discontinuous loop.
+ for (size_t i = callDag.size(); i-- > 0;)
+ {
+ PushDiscontinuousLoops push(&metadataList, i, callDag);
+ push.traverse(callDag.getRecordFromIndex(i).node);
+ }
+
+ // We create "Lod0" versions of functions with the gradient operations replaced
+ // by non-gradient operations so that the D3D compiler is happier with
+ // discontinuous loops.
+ for (auto &metadata : metadataList)
+ {
+ metadata.mNeedsLod0 = metadata.mCalledInDiscontinuousLoop && metadata.mUsesGradient;
+ }
+
return metadataList;
}