[InstCombine] "X - (X / C) * C == 0" to "X & C-1 == 0"
Summary:
"X % C == 0" is optimized to "X & C-1 == 0" (where C is a power-of-two)
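A minimal sketch of why the two forms agree, assuming 32-bit two's-complement integers. The function names are hypothetical and are not part of the patch. The mask is applied to the unsigned bit pattern so the check is well defined for negative X:

```cpp
#include <cstdint>

// Remainder form: the test as written with the % operator.
bool isMultipleRem(std::int32_t x, std::int32_t c) { return x % c == 0; }

// Mask form: checks that the low log2(c) bits of x are clear.
// Only equivalent to the remainder form when c is a positive power of two.
bool isMultipleMask(std::int32_t x, std::int32_t c) {
  return (static_cast<std::uint32_t>(x) &
          (static_cast<std::uint32_t>(c) - 1u)) == 0u;
}
```

For a C that is not a power of two the forms differ: with X = 6 and C = 6 the remainder form is true, but 6 & 5 is 4, so the mask form is false.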
However, "X % Y" can also be written as "X - (X / Y) * Y". When the initial expression is rewritten that way, as
"X - (X / C) * C == 0", it is not currently optimized to "X & C-1 == 0"; see godbolt: https://godbolt.org/z/KzuXUj
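The identity behind the new fold can be sketched as follows. This assumes that the shape named in the patch comment, "((X s/ C1) << C2) + X" with -C1 equal to 1 << C2, is what the source-level expression looks like by the time the fold runs. The helper names are hypothetical, and the shift and add are done on uint32_t to model wrapping i32 arithmetic without signed-overflow UB:

```cpp
#include <cstdint>

// Source-level form from the summary: X - (X / C) * C.
// Assumes c != 0 and excludes the x == INT32_MIN, c == -1 case.
std::int32_t remByDivMul(std::int32_t x, std::int32_t c) {
  return x - (x / c) * c;
}

// Shape matched by the fold: ((X s/ C1) << C2) + X, where C1 == -(1 << C2).
// k plays the role of C2; this sketch assumes 1 <= k <= 30.
std::int32_t remByShlAdd(std::int32_t x, unsigned k) {
  const std::int32_t c1 = -(std::int32_t{1} << k);
  const std::uint32_t q = static_cast<std::uint32_t>(x / c1); // X s/ C1
  const std::uint32_t sum = (q << k) + static_cast<std::uint32_t>(x);
  return static_cast<std::int32_t>(sum);
}
```

Both helpers should return the same value as x % (1 << k). For example, with x = -7 and k = 2: -7 / -4 is 1, 1 << 2 is 4, and 4 + (-7) is -3, which matches -7 % 4.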
This is my first contribution to LLVM so I hope I didn't mess things up
Reviewers: lebedev.ri, spatel
Reviewed By: lebedev.ri
Subscribers: hiraditya, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79369
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 288d0d1..68367b7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1314,6 +1314,17 @@
   // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
   if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V);
 
+  // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2
+  const APInt *C1, *C2;
+  if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) {
+    APInt one(C2->getBitWidth(), 1);
+    APInt minusC1 = -(*C1);
+    if (minusC1 == (one << *C2)) {
+      Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1);
+      return BinaryOperator::CreateSRem(RHS, NewRHS);
+    }
+  }
+
   // A+B --> A|B iff A and B have no bits set in common.
   if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
     return BinaryOperator::CreateOr(LHS, RHS);