Only unswitch loops with uniform conditions

Loop unswitching can be extremely harmful for a SIMT target. If the
hoisted condition is not uniform, a SIMT machine will execute both
clones of the loop sequentially. Therefore LoopUnswitch checks that the
condition is non-divergent before unswitching.

Since DivergenceAnalysis adds an expensive PostDominatorTree analysis
that is not needed for non-SIMT targets, a new option is added to avoid
unneeded analysis initialization. The method getAnalysisUsage is called
when TargetTransformInfo is not yet available, so we cannot use it there.
For that reason a new field DivergentTarget is added to PassManagerBuilder
to control the behavior; a target is expected to set this field.

Differential Revision: https://reviews.llvm.org/D30796

llvm-svn: 298104
diff --git a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
index c3acd9f..842f395 100644
--- a/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
+++ b/llvm/lib/Transforms/IPO/PassManagerBuilder.cpp
@@ -168,6 +168,7 @@
     PGOInstrUse = RunPGOInstrUse;
     PrepareForThinLTO = EnablePrepareForThinLTO;
     PerformThinLTO = false;
+    DivergentTarget = false;
 }
 
 PassManagerBuilder::~PassManagerBuilder() {
@@ -307,7 +308,7 @@
   // Rotate Loop - disable header duplication at -Oz
   MPM.add(createLoopRotatePass(SizeLevel == 2 ? 0 : -1));
   MPM.add(createLICMPass());                  // Hoist loop invariants
-  MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3));
+  MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));
   MPM.add(createCFGSimplificationPass());
   addInstructionCombiningPass(MPM);
   MPM.add(createIndVarSimplifyPass());        // Canonicalize indvars
@@ -588,7 +589,7 @@
     MPM.add(createCorrelatedValuePropagationPass());
     addInstructionCombiningPass(MPM);
     MPM.add(createLICMPass());
-    MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3));
+    MPM.add(createLoopUnswitchPass(SizeLevel || OptLevel < 3, DivergentTarget));
     MPM.add(createCFGSimplificationPass());
     addInstructionCombiningPass(MPM);
   }
diff --git a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
index a670c6e..fe4161a 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DivergenceAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
@@ -180,12 +181,14 @@
     // NewBlocks contained cloned copy of basic blocks from LoopBlocks.
     std::vector<BasicBlock*> NewBlocks;
 
+    bool hasBranchDivergence;
+
   public:
     static char ID; // Pass ID, replacement for typeid
-    explicit LoopUnswitch(bool Os = false) :
+    explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) :
       LoopPass(ID), OptimizeForSize(Os), redoLoop(false),
       currentLoop(nullptr), DT(nullptr), loopHeader(nullptr),
-      loopPreheader(nullptr) {
+      loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) {
         initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
       }
 
@@ -198,6 +201,8 @@
     void getAnalysisUsage(AnalysisUsage &AU) const override {
       AU.addRequired<AssumptionCacheTracker>();
       AU.addRequired<TargetTransformInfoWrapperPass>();
+      if (hasBranchDivergence)
+        AU.addRequired<DivergenceAnalysis>();
       getLoopAnalysisUsage(AU);
     }
 
@@ -367,11 +372,12 @@
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(LoopPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis)
 INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",
                       false, false)
 
-Pass *llvm::createLoopUnswitchPass(bool Os) {
-  return new LoopUnswitch(Os);
+Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) {
+  return new LoopUnswitch(Os, hasBranchDivergence);
 }
 
 /// Operator chain lattice.
@@ -808,6 +814,15 @@
                  << ". Cost too high.\n");
     return false;
   }
+  if (hasBranchDivergence &&
+      getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) {
+    DEBUG(dbgs() << "NOT unswitching loop %"
+                 << currentLoop->getHeader()->getName()
+                 << " at non-trivial condition '" << *Val
+                 << "' == " << *LoopCond << "\n"
+                 << ". Condition is divergent.\n");
+    return false;
+  }
 
   UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI);
   return true;