Use the [[flatten]] attribute only when a loop is present.

Flattening branch-heavy shaders that contained no loops caused regressions.
As a temporary workaround, we only flatten ifs when the shader contains at least one loop.

BUG=395048

Change-Id: I95c40f0249643b98c62304a0f2a4563561d1fbbc
Reviewed-on: https://chromium-review.googlesource.com/228722
Reviewed-by: Shannon Woods <shannonwoods@chromium.org>
Tested-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/DetectDiscontinuity.cpp b/src/compiler/translator/DetectDiscontinuity.cpp
index 334eb0b..f98d32b 100644
--- a/src/compiler/translator/DetectDiscontinuity.cpp
+++ b/src/compiler/translator/DetectDiscontinuity.cpp
@@ -14,6 +14,9 @@
 
 namespace sh
 {
+
+// Detect Loop Discontinuity
+
 bool DetectLoopDiscontinuity::traverse(TIntermNode *node)
 {
     mLoopDepth = 0;
@@ -74,6 +77,56 @@
     return detectLoopDiscontinuity.traverse(node);
 }
 
+// Detect Any Loop
+
+// Returns true if the tree rooted at |node| contains at least one loop node.
+bool DetectAnyLoop::traverse(TIntermNode *node)
+{
+    mHasLoop = false;
+    node->traverse(this);
+    return mHasLoop;
+}
+
+bool DetectAnyLoop::visitLoop(Visit, TIntermLoop *)
+{
+    mHasLoop = true;
+    return false;  // One loop is enough; no need to visit the loop's contents.
+}
+
+// Once a loop has been found, the visitors below stop descending into subtrees
+bool DetectAnyLoop::visitBinary(Visit, TIntermBinary *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitUnary(Visit, TIntermUnary *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitSelection(Visit, TIntermSelection *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitAggregate(Visit, TIntermAggregate *)
+{
+    return !mHasLoop;
+}
+
+bool DetectAnyLoop::visitBranch(Visit, TIntermBranch *)
+{
+    return !mHasLoop;
+}
+
+bool containsAnyLoop(TIntermNode *node)
+{
+    DetectAnyLoop detectAnyLoop;
+    return detectAnyLoop.traverse(node);
+}
+
+// Detect Gradient Operation
+
 bool DetectGradientOperation::traverse(TIntermNode *node)
 {
     mGradientOperation = false;
diff --git a/src/compiler/translator/DetectDiscontinuity.h b/src/compiler/translator/DetectDiscontinuity.h
index 35d66cb..67e37be 100644
--- a/src/compiler/translator/DetectDiscontinuity.h
+++ b/src/compiler/translator/DetectDiscontinuity.h
@@ -32,6 +32,25 @@
 
 bool containsLoopDiscontinuity(TIntermNode *node);
 
+// Checks whether a tree contains at least one loop node
+class DetectAnyLoop : public TIntermTraverser
+{
+public:
+    bool traverse(TIntermNode *node);
+
+protected:
+    bool visitBinary(Visit, TIntermBinary *);
+    bool visitUnary(Visit, TIntermUnary *);
+    bool visitSelection(Visit, TIntermSelection *);
+    bool visitAggregate(Visit, TIntermAggregate *);
+    bool visitLoop(Visit, TIntermLoop *);
+    bool visitBranch(Visit, TIntermBranch *);
+
+    bool mHasLoop;   // Reset by traverse(), set by visitLoop()
+};
+
+bool containsAnyLoop(TIntermNode *node);
+
 // Checks for intrinsic functions which compute gradients
 class DetectGradientOperation : public TIntermTraverser
 {
diff --git a/src/compiler/translator/OutputHLSL.cpp b/src/compiler/translator/OutputHLSL.cpp
index 8766bc7..30bbbff 100644
--- a/src/compiler/translator/OutputHLSL.cpp
+++ b/src/compiler/translator/OutputHLSL.cpp
@@ -135,6 +135,7 @@
     mUniqueIndex = 0;
 
     mContainsLoopDiscontinuity = false;
+    mContainsAnyLoop = false;
     mOutputLod0Function = false;
     mInsideDiscontinuousLoop = false;
     mNestedLoopDepth = 0;
@@ -172,6 +173,7 @@
 void OutputHLSL::output()
 {
     mContainsLoopDiscontinuity = mContext.shaderType == GL_FRAGMENT_SHADER && containsLoopDiscontinuity(mContext.treeRoot);
+    mContainsAnyLoop = mContext.shaderType == GL_FRAGMENT_SHADER && containsAnyLoop(mContext.treeRoot);
     const std::vector<TIntermTyped*> &flaggedStructs = FlagStd140ValueStructs(mContext.treeRoot);
     makeFlaggedStructMaps(flaggedStructs);
 
@@ -2301,7 +2303,16 @@
     {
         mUnfoldShortCircuit->traverse(node->getCondition());
 
-        out << "FLATTEN if (";
+        // D3D errors when a gradient operation is in a loop inside an unflattened if, but
+        // flattening every if in branch-heavy shaders made D3D error too. As a temporary
+        // workaround, ifs are flattened only when the shader contains at least one loop.
+        // mContainsAnyLoop is only ever set for fragment shaders (see output()).
+        if (mContainsAnyLoop)
+        {
+            out << "FLATTEN ";
+        }
+
+        out << "if (";
 
         node->getCondition()->traverse(this);
 
diff --git a/src/compiler/translator/OutputHLSL.h b/src/compiler/translator/OutputHLSL.h
index bec0247..5525e6e 100644
--- a/src/compiler/translator/OutputHLSL.h
+++ b/src/compiler/translator/OutputHLSL.h
@@ -144,6 +144,7 @@
     int mUniqueIndex;   // For creating unique names
 
     bool mContainsLoopDiscontinuity;
+    bool mContainsAnyLoop;   // For deciding whether to emit ifs with FLATTEN
     bool mOutputLod0Function;
     bool mInsideDiscontinuousLoop;
     int mNestedLoopDepth;