Use the [[flatten]] attribute only when a loop is present.
Flattening every if statement in branch-heavy shaders that contained no loops caused regressions.
As a temporary workaround, we only flatten if statements when the shader contains at least one loop.
BUG=395048
Change-Id: I95c40f0249643b98c62304a0f2a4563561d1fbbc
Reviewed-on: https://chromium-review.googlesource.com/228722
Reviewed-by: Shannon Woods <shannonwoods@chromium.org>
Tested-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/DetectDiscontinuity.cpp b/src/compiler/translator/DetectDiscontinuity.cpp
index 334eb0b..f98d32b 100644
--- a/src/compiler/translator/DetectDiscontinuity.cpp
+++ b/src/compiler/translator/DetectDiscontinuity.cpp
@@ -14,6 +14,9 @@
namespace sh
{
+
+// Detect Loop Discontinuity
+
bool DetectLoopDiscontinuity::traverse(TIntermNode *node)
{
mLoopDepth = 0;
@@ -74,6 +77,55 @@
return detectLoopDiscontinuity.traverse(node);
}
+// Detect Any Loop
+
+// Returns true if the tree rooted at node contains at least one loop.
+// mHasLoop is reset first so that one instance can be reused on several trees.
+bool DetectAnyLoop::traverse(TIntermNode *node)
+{
+    mHasLoop = false;
+    node->traverse(this);
+    return mHasLoop;
+}
+
+// Finding any loop node settles the answer: record it and return false so the
+// traverser does not descend into the loop's children.
+bool DetectAnyLoop::visitLoop(Visit, TIntermLoop *)
+{
+    mHasLoop = true;
+    return false;
+}
+
+// The visitors below only gate descent. Returning false tells the traverser not
+// to enter a node's children, so once a loop has been found every subtree
+// visited afterwards is pruned (the nodes themselves are still visited).
+bool DetectAnyLoop::visitBinary(Visit, TIntermBinary *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitUnary(Visit, TIntermUnary *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitSelection(Visit, TIntermSelection *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitAggregate(Visit, TIntermAggregate *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitBranch(Visit, TIntermBranch *)
+{
+    return !mHasLoop;
+}
+
+// Reports whether the subtree rooted at node contains any loop, using a
+// throwaway DetectAnyLoop traverser per query.
+bool containsAnyLoop(TIntermNode *node)
+{
+    return DetectAnyLoop().traverse(node);
+}
+
+// Detect Gradient Operation
+
bool DetectGradientOperation::traverse(TIntermNode *node)
{
mGradientOperation = false;
diff --git a/src/compiler/translator/DetectDiscontinuity.h b/src/compiler/translator/DetectDiscontinuity.h
index 35d66cb..67e37be 100644
--- a/src/compiler/translator/DetectDiscontinuity.h
+++ b/src/compiler/translator/DetectDiscontinuity.h
@@ -32,6 +32,25 @@
bool containsLoopDiscontinuity(TIntermNode *node);
+// Traverser that checks whether a subtree contains any loop (any TIntermLoop
+// node). Use traverse() or the containsAnyLoop() wrapper below.
+class DetectAnyLoop : public TIntermTraverser
+{
+public:
+    // Start with a well-defined result even before traverse() is called.
+    DetectAnyLoop() : mHasLoop(false) {}
+
+    // Returns true if the tree rooted at node contains at least one loop.
+    // The previous result is reset, so an instance may be reused.
+    bool traverse(TIntermNode *node);
+
+protected:
+    // These return !mHasLoop so that descent stops once a loop is found.
+    bool visitBinary(Visit, TIntermBinary *);
+    bool visitUnary(Visit, TIntermUnary *);
+    bool visitSelection(Visit, TIntermSelection *);
+    bool visitAggregate(Visit, TIntermAggregate *);
+    // Records that a loop exists and does not descend into it.
+    bool visitLoop(Visit, TIntermLoop *);
+    bool visitBranch(Visit, TIntermBranch *);
+
+    // Result of the most recent traverse() call.
+    bool mHasLoop;
+};
+
+// Returns true if the subtree rooted at node contains at least one loop.
+bool containsAnyLoop(TIntermNode *node);
+
// Checks for intrinsic functions which compute gradients
class DetectGradientOperation : public TIntermTraverser
{
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index 8766bc7..30bbbff 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -135,6 +135,7 @@
mUniqueIndex = 0;
mContainsLoopDiscontinuity = false;
+ mContainsAnyLoop = false;
mOutputLod0Function = false;
mInsideDiscontinuousLoop = false;
mNestedLoopDepth = 0;
@@ -172,6 +173,7 @@
void OutputHLSL::output()
{
mContainsLoopDiscontinuity = mContext.shaderType == GL_FRAGMENT_SHADER && containsLoopDiscontinuity(mContext.treeRoot);
+ mContainsAnyLoop = containsAnyLoop(mContext.treeRoot);
const std::vector<TIntermTyped*> &flaggedStructs = FlagStd140ValueStructs(mContext.treeRoot);
makeFlaggedStructMaps(flaggedStructs);
@@ -2301,7 +2303,16 @@
{
mUnfoldShortCircuit->traverse(node->getCondition());
- out << "FLATTEN if (";
+        // D3D reports an error for a gradient operation inside a loop that sits in an
+        // unflattened if; however, flattening every if in branch-heavy shaders also made
+        // D3D fail. As a temporary workaround, we flatten ifs only when the shader
+        // contains at least one loop.
+ if (mContext.shaderType == GL_FRAGMENT_SHADER && mContainsAnyLoop)
+ {
+ out << "FLATTEN ";
+ }
+
+ out << "if (";
node->getCondition()->traverse(this);
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index bec0247..5525e6e 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -144,6 +144,7 @@
int mUniqueIndex; // For creating unique names
bool mContainsLoopDiscontinuity;
+ bool mContainsAnyLoop;
bool mOutputLod0Function;
bool mInsideDiscontinuousLoop;
int mNestedLoopDepth;