[InstSimplify] allow icmp with constant folds for splat vectors, part 2

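Completes the m_APInt changes for simplifyICmpWithConstant(). Splat-vector compares can now reach the final range check, which was previously guarded by a scalar-only dyn_cast<ConstantInt>. Because of that, the true/false results there are now built from GetCompareTy(RHS) rather than RHS->getContext(), so the folded constant matches the compare's (possibly vector) result type.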

Other commits in this series:
https://reviews.llvm.org/rL279492
https://reviews.llvm.org/rL279530
https://reviews.llvm.org/rL279534
https://reviews.llvm.org/rL279538

llvm-svn: 279543
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c048f31..9be4613 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2165,119 +2165,113 @@
   if (RHS_CR.isFullSet())
     return ConstantInt::getTrue(GetCompareTy(RHS));
 
-  // FIXME: Use m_APInt below here to allow splat vector folds.
-  ConstantInt *CI = dyn_cast<ConstantInt>(RHS);
-  if (!CI)
-    return nullptr;
-
   // Many binary operators with constant RHS have easy to compute constant
   // range.  Use them to check whether the comparison is a tautology.
-  unsigned Width = CI->getBitWidth();
+  unsigned Width = C->getBitWidth();
   APInt Lower = APInt(Width, 0);
   APInt Upper = APInt(Width, 0);
-  ConstantInt *CI2;
-  if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) {
-    // 'urem x, CI2' produces [0, CI2).
-    Upper = CI2->getValue();
-  } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) {
-    // 'srem x, CI2' produces (-|CI2|, |CI2|).
-    Upper = CI2->getValue().abs();
+  const APInt *C2;
+  if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) {
+    // 'urem x, C2' produces [0, C2).
+    Upper = *C2;
+  } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) {
+    // 'srem x, C2' produces (-|C2|, |C2|).
+    Upper = C2->abs();
     Lower = (-Upper) + 1;
-  } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) {
-    // 'udiv CI2, x' produces [0, CI2].
-    Upper = CI2->getValue() + 1;
-  } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) {
-    // 'udiv x, CI2' produces [0, UINT_MAX / CI2].
+  } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) {
+    // 'udiv C2, x' produces [0, C2].
+    Upper = *C2 + 1;
+  } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) {
+    // 'udiv x, C2' produces [0, UINT_MAX / C2].
     APInt NegOne = APInt::getAllOnesValue(Width);
-    if (!CI2->isZero())
-      Upper = NegOne.udiv(CI2->getValue()) + 1;
-  } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) {
-    if (CI2->isMinSignedValue()) {
+    if (*C2 != 0)
+      Upper = NegOne.udiv(*C2) + 1;
+  } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) {
+    if (C2->isMinSignedValue()) {
       // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
-      Lower = CI2->getValue();
+      Lower = *C2;
       Upper = Lower.lshr(1) + 1;
     } else {
-      // 'sdiv CI2, x' produces [-|CI2|, |CI2|].
-      Upper = CI2->getValue().abs() + 1;
+      // 'sdiv C2, x' produces [-|C2|, |C2|].
+      Upper = C2->abs() + 1;
       Lower = (-Upper) + 1;
     }
-  } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) {
+  } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) {
     APInt IntMin = APInt::getSignedMinValue(Width);
     APInt IntMax = APInt::getSignedMaxValue(Width);
-    const APInt &Val = CI2->getValue();
-    if (Val.isAllOnesValue()) {
+    if (C2->isAllOnesValue()) {
       // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
-      //    where CI2 != -1 and CI2 != 0 and CI2 != 1
+      //    where C2 != -1 and C2 != 0 and C2 != 1
       Lower = IntMin + 1;
       Upper = IntMax + 1;
-    } else if (Val.countLeadingZeros() < Width - 1) {
-      // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2]
-      //    where CI2 != -1 and CI2 != 0 and CI2 != 1
-      Lower = IntMin.sdiv(Val);
-      Upper = IntMax.sdiv(Val);
+    } else if (C2->countLeadingZeros() < Width - 1) {
+      // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2]
+      //    where C2 != -1 and C2 != 0 and C2 != 1
+      Lower = IntMin.sdiv(*C2);
+      Upper = IntMax.sdiv(*C2);
       if (Lower.sgt(Upper))
         std::swap(Lower, Upper);
       Upper = Upper + 1;
       assert(Upper != Lower && "Upper part of range has wrapped!");
     }
-  } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) {
-    // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)]
-    Lower = CI2->getValue();
+  } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) {
+    // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)]
+    Lower = *C2;
     Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
-  } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) {
-    if (CI2->isNegative()) {
-      // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2]
-      unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1;
-      Lower = CI2->getValue().shl(ShiftAmount);
-      Upper = CI2->getValue() + 1;
+  } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) {
+    if (C2->isNegative()) {
+      // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2]
+      unsigned ShiftAmount = C2->countLeadingOnes() - 1;
+      Lower = C2->shl(ShiftAmount);
+      Upper = *C2 + 1;
     } else {
-      // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1]
-      unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1;
-      Lower = CI2->getValue();
-      Upper = CI2->getValue().shl(ShiftAmount) + 1;
+      // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1]
+      unsigned ShiftAmount = C2->countLeadingZeros() - 1;
+      Lower = *C2;
+      Upper = C2->shl(ShiftAmount) + 1;
     }
-  } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) {
-    // 'lshr x, CI2' produces [0, UINT_MAX >> CI2].
+  } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) {
+    // 'lshr x, C2' produces [0, UINT_MAX >> C2].
     APInt NegOne = APInt::getAllOnesValue(Width);
-    if (CI2->getValue().ult(Width))
-      Upper = NegOne.lshr(CI2->getValue()) + 1;
-  } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) {
-    // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2].
+    if (C2->ult(Width))
+      Upper = NegOne.lshr(*C2) + 1;
+  } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) {
+    // 'lshr C2, x' produces [C2 >> (Width-1), C2].
     unsigned ShiftAmount = Width - 1;
-    if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact())
-      ShiftAmount = CI2->getValue().countTrailingZeros();
-    Lower = CI2->getValue().lshr(ShiftAmount);
-    Upper = CI2->getValue() + 1;
-  } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) {
-    // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2].
+    if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
+      ShiftAmount = C2->countTrailingZeros();
+    Lower = C2->lshr(ShiftAmount);
+    Upper = *C2 + 1;
+  } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) {
+    // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2].
     APInt IntMin = APInt::getSignedMinValue(Width);
     APInt IntMax = APInt::getSignedMaxValue(Width);
-    if (CI2->getValue().ult(Width)) {
-      Lower = IntMin.ashr(CI2->getValue());
-      Upper = IntMax.ashr(CI2->getValue()) + 1;
+    if (C2->ult(Width)) {
+      Lower = IntMin.ashr(*C2);
+      Upper = IntMax.ashr(*C2) + 1;
     }
-  } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) {
+  } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) {
     unsigned ShiftAmount = Width - 1;
-    if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact())
-      ShiftAmount = CI2->getValue().countTrailingZeros();
-    if (CI2->isNegative()) {
-      // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)]
-      Lower = CI2->getValue();
-      Upper = CI2->getValue().ashr(ShiftAmount) + 1;
+    if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
+      ShiftAmount = C2->countTrailingZeros();
+    if (C2->isNegative()) {
+      // 'ashr C2, x' produces [C2, C2 >> (Width-1)]
+      Lower = *C2;
+      Upper = C2->ashr(ShiftAmount) + 1;
     } else {
-      // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2]
-      Lower = CI2->getValue().ashr(ShiftAmount);
-      Upper = CI2->getValue() + 1;
+      // 'ashr C2, x' produces [C2 >> (Width-1), C2]
+      Lower = C2->ashr(ShiftAmount);
+      Upper = *C2 + 1;
     }
-  } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) {
-    // 'or x, CI2' produces [CI2, UINT_MAX].
-    Lower = CI2->getValue();
-  } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) {
-    // 'and x, CI2' produces [0, CI2].
-    Upper = CI2->getValue() + 1;
-  } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) {
-    // 'add nuw x, CI2' produces [CI2, UINT_MAX].
-    Lower = CI2->getValue();
+  } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) {
+    // 'or x, C2' produces [C2, UINT_MAX].
+    Lower = *C2;
+  } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) {
+    // 'and x, C2' produces [0, C2].
+    Upper = *C2 + 1;
+  } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) {
+    // 'add nuw x, C2' produces [C2, UINT_MAX].
+    Lower = *C2;
   }
 
   ConstantRange LHS_CR =
@@ -2289,9 +2283,9 @@
 
   if (!LHS_CR.isFullSet()) {
     if (RHS_CR.contains(LHS_CR))
-      return ConstantInt::getTrue(RHS->getContext());
+      return ConstantInt::getTrue(GetCompareTy(RHS));
     if (RHS_CR.inverse().contains(LHS_CR))
-      return ConstantInt::getFalse(RHS->getContext());
+      return ConstantInt::getFalse(GetCompareTy(RHS));
   }
 
   return nullptr;