[InstCombine] use m_APInt to allow lshr folds for vectors with splat constants

llvm-svn: 291972
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 5209dad..078c938 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -774,34 +774,31 @@
   if (Instruction *R = commonShiftTransforms(I))
     return R;
 
-  if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
-    unsigned ShAmt = Op1C->getZExtValue();
-
-    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
-      unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+  const APInt *ShAmtAPInt;
+  if (match(Op1, m_APInt(ShAmtAPInt))) {
+    unsigned ShAmt = ShAmtAPInt->getZExtValue();
+    unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+    auto *II = dyn_cast<IntrinsicInst>(Op0);
+    if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
+        (II->getIntrinsicID() == Intrinsic::ctlz ||
+         II->getIntrinsicID() == Intrinsic::cttz ||
+         II->getIntrinsicID() == Intrinsic::ctpop)) {
       // ctlz.i32(x)>>5  --> zext(x == 0)
       // cttz.i32(x)>>5  --> zext(x == 0)
       // ctpop.i32(x)>>5 --> zext(x == -1)
-      if ((II->getIntrinsicID() == Intrinsic::ctlz ||
-           II->getIntrinsicID() == Intrinsic::cttz ||
-           II->getIntrinsicID() == Intrinsic::ctpop) &&
-          isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
-        bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
-        Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
-        Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
-        return new ZExtInst(Cmp, II->getType());
-      }
+      bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
+      Constant *RHS = ConstantInt::getSigned(Op0->getType(), IsPop ? -1 : 0);
+      Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
+      return new ZExtInst(Cmp, II->getType());
     }
 
     // If the shifted-out value is known-zero, then this is an exact shift.
     if (!I.isExact() &&
-        MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
-                          0, &I)){
+        MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
       I.setIsExact();
       return &I;
     }
   }
-
   return nullptr;
 }