[InstCombine] Simplify Add with remainder expressions as operands.

Summary:
Simplify integer add expression X % C0 + (( X / C0 ) % C1) * C0 to
X % (C0 * C1).  This is a common pattern seen in code generated by the XLA
GPU backend.

Add test cases for this new optimization.

Patch by Bixia Zheng!

Reviewers: sanjoy

Reviewed By: sanjoy

Subscribers: efriedma, craig.topper, lebedev.ri, llvm-commits, jlebar

Differential Revision: https://reviews.llvm.org/D45976

llvm-svn: 330992
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index f5b0d5b..b7e115e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1032,6 +1032,112 @@
   return nullptr;
 }
 
+// Recognizes E as `Op * C`, where the constant multiply may also be spelled
+// as a left shift by a constant N (then C = 1 << N). On a match, fills in Op
+// and C and returns true; returns false otherwise.
+static bool MatchMul(Value *E, Value *&Op, APInt &C) {
+  const APInt *Const;
+  if (match(E, m_Mul(m_Value(Op), m_APInt(Const)))) {
+    C = *Const;
+  } else if (match(E, m_Shl(m_Value(Op), m_APInt(Const)))) {
+    // Build the multiplier 1 << N at the bit width of the shift amount.
+    C = APInt(Const->getBitWidth(), 1);
+    C <<= *Const;
+  } else {
+    return false;
+  }
+  return true;
+}
+
+// Recognizes E as a remainder by a constant, `Op % C`. Besides srem and urem
+// this also accepts `Op & Mask` when Mask + 1 is a power of two, treating it
+// as an unsigned remainder by Mask + 1.
+// On a match, fills in Op and C and returns true; returns false otherwise.
+// IsSigned is set to true only when E is an srem and to false in every other
+// case, including a failed match.
+static bool MatchRem(Value *E, Value *&Op, APInt &C, bool &IsSigned) {
+  const APInt *Const;
+  IsSigned = match(E, m_SRem(m_Value(Op), m_APInt(Const)));
+  if (IsSigned || match(E, m_URem(m_Value(Op), m_APInt(Const)))) {
+    C = *Const;
+    return true;
+  }
+  // `Op & Mask` acts as an unsigned remainder when Mask + 1 is 2^k.
+  if (!match(E, m_And(m_Value(Op), m_APInt(Const))))
+    return false;
+  const APInt Divisor = *Const + 1;
+  if (!Divisor.isPowerOf2())
+    return false;
+  C = Divisor;
+  return true;
+}
+
+// Recognizes E as a division by a constant, `Op / C`, of the signedness
+// requested by IsSigned: only sdiv is accepted when IsSigned is true; udiv,
+// or a logical right shift by a constant N (then C = 1 << N), otherwise.
+// On a match, fills in Op and C and returns true; returns false otherwise.
+static bool MatchDiv(Value *E, Value *&Op, APInt &C, bool IsSigned) {
+  const APInt *Const;
+  if (IsSigned) {
+    if (!match(E, m_SDiv(m_Value(Op), m_APInt(Const))))
+      return false;
+    C = *Const;
+    return true;
+  }
+  if (match(E, m_UDiv(m_Value(Op), m_APInt(Const)))) {
+    C = *Const;
+    return true;
+  }
+  if (!match(E, m_LShr(m_Value(Op), m_APInt(Const))))
+    return false;
+  C = APInt(Const->getBitWidth(), 1);
+  C <<= *Const;
+  return true;
+}
+
+// Reports whether the product C0 * C1 overflows when it is computed with
+// the signedness selected by IsSigned (smul_ov if signed, umul_ov if not).
+static bool MulWillOverflow(APInt &C0, APInt &C1, bool IsSigned) {
+  bool Overflowed;
+  // Only the overflow flag matters here; the product itself is discarded.
+  (void)(IsSigned ? C0.smul_ov(C1, Overflowed)
+                  : C0.umul_ov(C1, Overflowed));
+  return Overflowed;
+}
+
+// Simplifies X % C0 + (( X / C0 ) % C1) * C0 to X % (C0 * C1), where (C0 * C1)
+// does not overflow. Returns nullptr if I does not match.
+Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) {
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *X, *MulOpV;
+  APInt C0, MulOpC;
+  bool IsSigned;
+  // Match I = X % C0 + MulOpV * C0
+  if (((MatchRem(LHS, X, C0, IsSigned) && MatchMul(RHS, MulOpV, MulOpC)) ||
+       (MatchRem(RHS, X, C0, IsSigned) && MatchMul(LHS, MulOpV, MulOpC))) &&
+      C0 == MulOpC) {
+    Value *RemOpV;
+    APInt C1;
+    bool Rem2IsSigned;
+    // Match MulOpV = RemOpV % C1
+    if (MatchRem(MulOpV, RemOpV, C1, Rem2IsSigned) &&
+        IsSigned == Rem2IsSigned) {
+      Value *DivOpV;
+      APInt DivOpC;
+      // Match RemOpV = X / C0
+      if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV &&
+          C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) {
+        // Typed as X so the divisor also matches when X is a vector.
+        Value *NewDivisor = ConstantInt::get(X->getType(), C0 * C1);
+        return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem")
+                        : Builder.CreateURem(X, NewDivisor, "urem");
+      }
+    }
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   bool Changed = SimplifyAssociativeOrCommutative(I);
   if (Value *V = SimplifyVectorOp(I))
@@ -1124,6 +1230,9 @@
   if (Value *V = checkForNegativeOperand(I, Builder))
     return replaceInstUsesWith(I, V);
 
+  // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
+  if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);
+
   // A+B --> A|B iff A and B have no bits set in common.
   if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
     return BinaryOperator::CreateOr(LHS, RHS);