Generalize LSR's OptimizeMax to handle the new kinds of max expressions
that indvars may use, now that indvars recognizes loops whose exit
comparison is le or ge (<= or >=), not just lt or gt.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@102235 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index a09b3dc..9fdbf4a 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -1461,12 +1461,26 @@
 
   // Add one to the backedge-taken count to get the trip count.
   const SCEV *IterationCount = SE.getAddExpr(BackedgeTakenCount, One);
+  if (IterationCount != SE.getSCEV(Sel)) return Cond;
 
-  // Check for a max calculation that matches the pattern.
-  if (!isa<SCEVSMaxExpr>(IterationCount) && !isa<SCEVUMaxExpr>(IterationCount))
+  // Check for a max calculation that matches the pattern. There's no check
+  // for ICMP_ULE here because the comparison would be with zero, which
+  // isn't interesting.
+  CmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
+  const SCEVNAryExpr *Max = 0;
+  if (const SCEVSMaxExpr *S = dyn_cast<SCEVSMaxExpr>(BackedgeTakenCount)) {
+    Pred = ICmpInst::ICMP_SLE;
+    Max = S;
+  } else if (const SCEVSMaxExpr *S = dyn_cast<SCEVSMaxExpr>(IterationCount)) {
+    Pred = ICmpInst::ICMP_SLT;
+    Max = S;
+  } else if (const SCEVUMaxExpr *U = dyn_cast<SCEVUMaxExpr>(IterationCount)) {
+    Pred = ICmpInst::ICMP_ULT;
+    Max = U;
+  } else {
+    // No match; bail.
     return Cond;
-  const SCEVNAryExpr *Max = cast<SCEVNAryExpr>(IterationCount);
-  if (Max != SE.getSCEV(Sel)) return Cond;
+  }
 
   // To handle a max with more than two operands, this optimization would
   // require additional checking and setup.
@@ -1475,7 +1489,13 @@
 
   const SCEV *MaxLHS = Max->getOperand(0);
   const SCEV *MaxRHS = Max->getOperand(1);
-  if (!MaxLHS || MaxLHS != One) return Cond;
+
+  // ScalarEvolution canonicalizes constants to the left. For < and >, look
+  // for a comparison with 1. For <= and >=, a comparison with zero.
+  if (!MaxLHS ||
+      (ICmpInst::isTrueWhenEqual(Pred) ? !MaxLHS->isZero() : (MaxLHS != One)))
+    return Cond;
+
   // Check the relevant induction variable for conformance to
   // the pattern.
   const SCEV *IV = SE.getSCEV(Cond->getOperand(0));
@@ -1491,16 +1511,29 @@
   // Check the right operand of the select, and remember it, as it will
   // be used in the new comparison instruction.
   Value *NewRHS = 0;
-  if (SE.getSCEV(Sel->getOperand(1)) == MaxRHS)
+  if (ICmpInst::isTrueWhenEqual(Pred)) {
+    // Look for n+1, and grab n.
+    if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(1)))
+      if (isa<ConstantInt>(BO->getOperand(1)) &&
+          cast<ConstantInt>(BO->getOperand(1))->isOne() &&
+          SE.getSCEV(BO->getOperand(0)) == MaxRHS)
+        NewRHS = BO->getOperand(0);
+    if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(2)))
+      if (isa<ConstantInt>(BO->getOperand(1)) &&
+          cast<ConstantInt>(BO->getOperand(1))->isOne() &&
+          SE.getSCEV(BO->getOperand(0)) == MaxRHS)
+        NewRHS = BO->getOperand(0);
+    if (!NewRHS)
+      return Cond;
+  } else if (SE.getSCEV(Sel->getOperand(1)) == MaxRHS)
     NewRHS = Sel->getOperand(1);
   else if (SE.getSCEV(Sel->getOperand(2)) == MaxRHS)
     NewRHS = Sel->getOperand(2);
-  if (!NewRHS) return Cond;
+  else
+    llvm_unreachable("Max doesn't match expected pattern!");
 
   // Determine the new comparison opcode. It may be signed or unsigned,
   // and the original comparison may be either equality or inequality.
-  CmpInst::Predicate Pred =
-    isa<SCEVSMaxExpr>(Max) ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
   if (Cond->getPredicate() == CmpInst::ICMP_EQ)
     Pred = CmpInst::getInversePredicate(Pred);