Change the FLATTEN heuristic to "ifs with a loop with a gradient"

This heuristic makes more sense than the previous "ifs with a
discontinuous loop" as the reason we need to flatten is that we need
gradients to be in branchless code.

Change the UnrollFlatten test accordingly.

Tested with:
 - the WebGL CTS
 - dev.miaumiau.cat/rayTracer "Skull Demo"
 - The Turbulenz engine GPU particle demo
 - Lots of ShaderToy Samples (inc. Volcanic, Metropolis and Hierarchical
   Voronoi)
 - Google Maps Earth mode
 - Lots of Chrome experiments
 - madebyevan.com/webgl-water

BUG=524297

Change-Id: Iaa727036fffcfde3952716a1ef33b6ee0546b69d
Reviewed-on: https://chromium-review.googlesource.com/296442
Reviewed-by: Jamie Madill <jmadill@chromium.org>
Tested-by: Corentin Wallez <cwallez@chromium.org>
diff --git a/src/compiler/translator/ASTMetadataHLSL.cpp b/src/compiler/translator/ASTMetadataHLSL.cpp
index aa71645..cc21a00 100644
--- a/src/compiler/translator/ASTMetadataHLSL.cpp
+++ b/src/compiler/translator/ASTMetadataHLSL.cpp
@@ -139,13 +139,16 @@
     std::vector<TIntermNode*> mParents;
 };
 
-// Traverses the AST of a function definition, assuming it has already been used to
-// traverse the callees of that function; computes the discontinuous loops and the if
-// statements that contain a discontinuous loop in their call graph.
-class PullComputeDiscontinuousLoops : public TIntermTraverser
+// Traverses the AST of a function definition to compute the discontinuous loops
+// and the if statements containing gradient loops. It assumes that the gradient loops
+// (loops that contain a gradient) have already been computed and that it has already
+// traversed the current function's callees.
+class PullComputeDiscontinuousAndGradientLoops : public TIntermTraverser
 {
   public:
-    PullComputeDiscontinuousLoops(MetadataList *metadataList, size_t index, const CallDAG &dag)
+    PullComputeDiscontinuousAndGradientLoops(MetadataList *metadataList,
+                                             size_t index,
+                                             const CallDAG &dag)
         : TIntermTraverser(true, false, true),
           mMetadataList(metadataList),
           mMetadata(&(*metadataList)[index]),
@@ -161,15 +164,15 @@
         ASSERT(mIfs.empty());
     }
 
-    // Called when a discontinuous loop or a call to a function with a discontinuous loop
-    // in its call graph is found.
-    void onDiscontinuousLoop()
+    // Called when traversing a gradient loop or a call to a function with a
+    // gradient loop in its call graph.
+    void onGradientLoop()
     {
-        mMetadata->mHasDiscontinuousLoopInCallGraph = true;
+        mMetadata->mHasGradientLoopInCallGraph = true;
         // Mark the latest if as using a discontinuous loop.
         if (!mIfs.empty())
         {
-            mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+            mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
         }
     }
 
@@ -178,6 +181,11 @@
         if (visit == PreVisit)
         {
             mLoopsAndSwitches.push_back(loop);
+
+            if (mMetadata->hasGradientInCallGraph(loop))
+            {
+                onGradientLoop();
+            }
         }
         else if (visit == PostVisit)
         {
@@ -199,9 +207,9 @@
             ASSERT(mIfs.back() == node);
             mIfs.pop_back();
             // An if using a discontinuous loop means its parents ifs are also discontinuous.
-            if (mMetadata->mIfsContainingDiscontinuousLoop.count(node) > 0 && !mIfs.empty())
+            if (mMetadata->mIfsContainingGradientLoop.count(node) > 0 && !mIfs.empty())
             {
-                mMetadata->mIfsContainingDiscontinuousLoop.insert(mIfs.back());
+                mMetadata->mIfsContainingGradientLoop.insert(mIfs.back());
             }
         }
 
@@ -221,7 +229,6 @@
                     if (loop != nullptr)
                     {
                         mMetadata->mDiscontinuousLoops.insert(loop);
-                        onDiscontinuousLoop();
                     }
                 }
                 break;
@@ -237,7 +244,6 @@
                     }
                     ASSERT(loop != nullptr);
                     mMetadata->mDiscontinuousLoops.insert(loop);
-                    onDiscontinuousLoop();
                 }
                 break;
               case EOpKill:
@@ -253,7 +259,6 @@
                             mMetadata->mDiscontinuousLoops.insert(loop);
                         }
                     }
-                    onDiscontinuousLoop();
                 }
                 break;
               default:
@@ -274,9 +279,9 @@
                 ASSERT(calleeIndex != CallDAG::InvalidIndex && calleeIndex < mIndex);
                 UNUSED_ASSERTION_VARIABLE(mIndex);
 
-                if ((*mMetadataList)[calleeIndex].mHasDiscontinuousLoopInCallGraph)
+                if ((*mMetadataList)[calleeIndex].mHasGradientLoopInCallGraph)
                 {
-                    onDiscontinuousLoop();
+                    onGradientLoop();
                 }
             }
         }
@@ -375,19 +380,14 @@
 
 }
 
-bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermSelection *node)
-{
-    return mControlFlowsContainingGradient.count(node) > 0;
-}
-
 bool ASTMetadataHLSL::hasGradientInCallGraph(TIntermLoop *node)
 {
     return mControlFlowsContainingGradient.count(node) > 0;
 }
 
-bool ASTMetadataHLSL::hasDiscontinuousLoop(TIntermSelection *node)
+bool ASTMetadataHLSL::hasGradientLoop(TIntermSelection *node)
 {
-    return mIfsContainingDiscontinuousLoop.count(node) > 0;
+    return mIfsContainingGradientLoop.count(node) > 0;
 }
 
 MetadataList CreateASTMetadataHLSL(TIntermNode *root, const CallDAG &callDag)
@@ -424,10 +424,10 @@
     // of callgraph analysis as for the gradient.
 
     // First compute which loops are discontinuous (no specific order) and pull
-    // the ifs and functions using a discontinuous loop.
+    // the ifs and functions using a gradient loop.
     for (size_t i = 0; i < callDag.size(); i++)
     {
-        PullComputeDiscontinuousLoops pull(&metadataList, i, callDag);
+        PullComputeDiscontinuousAndGradientLoops pull(&metadataList, i, callDag);
         pull.traverse(callDag.getRecordFromIndex(i).node);
     }