[SimplifyIndvar] Replace an srem used by an IV with urem if both of its operands are provably non-negative

Now that SCEV can handle 'urem', 'urem' is a better canonical form than 'srem' because its behavior is defined in more cases: unlike 'srem', it has no signed-overflow case (e.g. INT_MIN srem -1 is undefined behavior).

This is a follow-up to D34598.

Differential Revision: https://reviews.llvm.org/D38072

llvm-svn: 314125
diff --git a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
index 6d90e6b..deaddcb 100644
--- a/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyIndVar.cpp
@@ -39,6 +39,9 @@
 STATISTIC(
     NumSimplifiedSDiv,
     "Number of IV signed division operations converted to unsigned division");
+STATISTIC(
+    NumSimplifiedSRem,
+    "Number of IV signed remainder operations converted to unsigned remainder");
 STATISTIC(NumElimCmp     , "Number of IV comparisons eliminated");
 
 namespace {
@@ -77,8 +80,11 @@
     bool eliminateOverflowIntrinsic(CallInst *CI);
     bool eliminateIVUser(Instruction *UseInst, Instruction *IVOperand);
     void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand);
-    void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand,
-                              bool IsSigned);
+    void simplifyIVRemainder(BinaryOperator *Rem, Value *IVOperand,
+                             bool IsSigned);
+    void replaceRemWithNumerator(BinaryOperator *Rem);
+    void replaceRemWithNumeratorOrZero(BinaryOperator *Rem);
+    void replaceSRemWithURem(BinaryOperator *Rem);
     bool eliminateSDiv(BinaryOperator *SDiv);
     bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand);
     bool strengthenRightShift(BinaryOperator *BO, Value *IVOperand);
@@ -309,56 +315,92 @@
   return false;
 }
 
-/// SimplifyIVUsers helper for eliminating useless
-/// remainder operations operating on an induction variable.
-void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem,
-                                      Value *IVOperand,
-                                      bool IsSigned) {
-  // We're only interested in the case where we know something about
-  // the numerator.
-  if (IVOperand != Rem->getOperand(0))
-    return;
+/// i %s n -> i %u n; the caller must prove i >= 0 and n >= 0.
+void SimplifyIndvar::replaceSRemWithURem(BinaryOperator *Rem) {
+  auto *N = Rem->getOperand(0), *D = Rem->getOperand(1);
+  auto *URem = BinaryOperator::Create(BinaryOperator::URem, N, D,
+                                      Rem->getName() + ".urem", Rem);
+  URem->setDebugLoc(Rem->getDebugLoc());
+  Rem->replaceAllUsesWith(URem);
+  DEBUG(dbgs() << "INDVARS: Simplified srem: " << *Rem << '\n');
+  ++NumSimplifiedSRem;
+  Changed = true;
+  DeadInsts.emplace_back(Rem);
+}

-  // Get the SCEVs for the ICmp operands.
-  const SCEV *S = SE->getSCEV(Rem->getOperand(0));
-  const SCEV *X = SE->getSCEV(Rem->getOperand(1));
-
-  // Simplify unnecessary loops away.
-  const Loop *ICmpLoop = LI->getLoopFor(Rem->getParent());
-  S = SE->getSCEVAtScope(S, ICmpLoop);
-  X = SE->getSCEVAtScope(X, ICmpLoop);
-
-  // i % n  -->  i  if i is in [0,n).
-  if ((!IsSigned || SE->isKnownNonNegative(S)) &&
-      SE->isKnownPredicate(IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
-                           S, X))
-    Rem->replaceAllUsesWith(Rem->getOperand(0));
-  else {
-    // (i+1) % n  -->  (i+1)==n?0:(i+1)  if i is in [0,n).
-    const SCEV *LessOne = SE->getMinusSCEV(S, SE->getOne(S->getType()));
-    if (IsSigned && !SE->isKnownNonNegative(LessOne))
-      return;
-
-    if (!SE->isKnownPredicate(IsSigned ?
-                              ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
-                              LessOne, X))
-      return;
-
-    ICmpInst *ICmp = new ICmpInst(Rem, ICmpInst::ICMP_EQ,
-                                  Rem->getOperand(0), Rem->getOperand(1));
-    SelectInst *Sel =
-      SelectInst::Create(ICmp,
-                         ConstantInt::get(Rem->getType(), 0),
-                         Rem->getOperand(0), "tmp", Rem);
-    Rem->replaceAllUsesWith(Sel);
-  }
-
+/// i % n  -->  i  if i is in [0,n).
+void SimplifyIndvar::replaceRemWithNumerator(BinaryOperator *Rem) {
+  Rem->replaceAllUsesWith(Rem->getOperand(0));
   DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n');
   ++NumElimRem;
   Changed = true;
   DeadInsts.emplace_back(Rem);
 }
 
+/// (i+1) % n  -->  (i+1)==n?0:(i+1)  if i is in [0,n).
+void SimplifyIndvar::replaceRemWithNumeratorOrZero(BinaryOperator *Rem) {
+  auto *T = Rem->getType();
+  auto *N = Rem->getOperand(0), *D = Rem->getOperand(1);
+  ICmpInst *ICmp = new ICmpInst(Rem, ICmpInst::ICMP_EQ, N, D);
+  SelectInst *Sel =
+      SelectInst::Create(ICmp, ConstantInt::get(T, 0), N, "iv.rem", Rem);
+  Sel->setDebugLoc(Rem->getDebugLoc());
+  Rem->replaceAllUsesWith(Sel);
+  DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n');
+  ++NumElimRem;
+  Changed = true;
+  DeadInsts.emplace_back(Rem);
+}
+
+/// SimplifyIVUsers helper for eliminating useless remainder operations
+/// operating on an induction variable or replacing srem by urem.
+void SimplifyIndvar::simplifyIVRemainder(BinaryOperator *Rem, Value *IVOperand,
+                                         bool IsSigned) {
+  auto *NValue = Rem->getOperand(0);
+  auto *DValue = Rem->getOperand(1);
+  // We're only interested in the case where we know something about
+  // the numerator, unless it is a srem, because we want to replace srem by urem
+  // in general.
+  bool UsedAsNumerator = IVOperand == NValue;
+  if (!UsedAsNumerator && !IsSigned)
+    return;
+
+  const SCEV *N = SE->getSCEV(NValue);
+
+  // Simplify unnecessary loops away.
+  const Loop *ICmpLoop = LI->getLoopFor(Rem->getParent());
+  N = SE->getSCEVAtScope(N, ICmpLoop);
+
+  // Do not proceed if the numerator may be negative.
+  if (IsSigned && !SE->isKnownNonNegative(N))
+    return;
+
+  const SCEV *D = SE->getSCEV(DValue);
+  D = SE->getSCEVAtScope(D, ICmpLoop);
+
+  if (UsedAsNumerator) {
+    auto LT = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
+    if (SE->isKnownPredicate(LT, N, D)) {
+      replaceRemWithNumerator(Rem);
+      return;
+    }
+
+    auto *T = Rem->getType();
+    const auto *NLessOne = SE->getMinusSCEV(N, SE->getOne(T));
+    if (SE->isKnownPredicate(LT, NLessOne, D)) {
+      replaceRemWithNumeratorOrZero(Rem);
+      return;
+    }
+  }
+
+  // Try to replace SRem with URem, if both N and D are known non-negative.
+  // N has already been checked above, so only D remains to be checked.
+  if (!IsSigned || !SE->isKnownNonNegative(D))
+    return;
+
+  replaceSRemWithURem(Rem);
+}
+
 bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) {
   auto *F = CI->getCalledFunction();
   if (!F)
@@ -474,7 +516,7 @@
   if (BinaryOperator *Bin = dyn_cast<BinaryOperator>(UseInst)) {
     bool IsSRem = Bin->getOpcode() == Instruction::SRem;
     if (IsSRem || Bin->getOpcode() == Instruction::URem) {
-      eliminateIVRemainder(Bin, IVOperand, IsSRem);
+      simplifyIVRemainder(Bin, IVOperand, IsSRem);
       return true;
     }