[InstCombine] use m_APInt to allow (X << C) >>u C --> X & (-1 >>u C) with splat vectors

llvm-svn: 293208
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 21a9758..a4ff5f7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -353,29 +353,37 @@
   // Combinations of right and left shifts will still be optimized in
   // DAGCombine where scalar evolution no longer applies.
 
+  Value *X = ShiftOp->getOperand(0);
+  unsigned ShiftAmt1 = ShAmt1->getLimitedValue();
+  unsigned ShiftAmt2 = COp1->getLimitedValue();
+  assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
+  if (ShiftAmt1 == 0)
+    return nullptr; // Will be simplified in the future.
+
+  if (ShiftAmt1 == ShiftAmt2) {
+    // FIXME: This repeats a fold that exists in foldShiftedShift(), but we're
+    // not handling the related fold here:
+    // (X >>u C) << C --> X & (-1 << C).
+    // foldShiftedShift() is always called before this, but it is restricted to
+    // only handle cases where the ShiftOp has one use. We don't have that
+    // restriction here.
+    if (I.getOpcode() != Instruction::LShr ||
+        ShiftOp->getOpcode() != Instruction::Shl)
+      return nullptr;
+
+    // (X << C) >>u C --> X & (-1 >>u C).
+    APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1));
+    return BinaryOperator::CreateAnd(X, ConstantInt::get(I.getType(), Mask));
+  }
+
   // FIXME: Everything under here should be extended to work with vector types.
 
   auto *ShiftAmt1C = dyn_cast<ConstantInt>(ShiftOp->getOperand(1));
   if (!ShiftAmt1C)
     return nullptr;
 
-  uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
-  uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
-  assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
-  if (ShiftAmt1 == 0)
-    return nullptr; // Will be simplified in the future.
-
-  Value *X = ShiftOp->getOperand(0);
   IntegerType *Ty = cast<IntegerType>(I.getType());
-  if (ShiftAmt1 == ShiftAmt2) {
-    // If we have ((X << C) >>u C), turn this into X & (-1 >>u C).
-    if (I.getOpcode() == Instruction::LShr &&
-        ShiftOp->getOpcode() == Instruction::Shl) {
-      APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1));
-      return BinaryOperator::CreateAnd(X,
-                                       ConstantInt::get(I.getContext(), Mask));
-    }
-  } else if (ShiftAmt1 < ShiftAmt2) {
+  if (ShiftAmt1 < ShiftAmt2) {
     uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1;
 
     // (X >>?,exact C1) << C2 --> X << (C2-C1)