[InstSimplify] allow icmp with constant folds for splat vectors, part 2
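Completes the conversion of simplifyICmpWithConstant() from ConstantInt to m_APInt matchers, so that splat-vector constants are folded the same way as scalar constants.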
Other commits in this series:
https://reviews.llvm.org/rL279492
https://reviews.llvm.org/rL279530
https://reviews.llvm.org/rL279534
https://reviews.llvm.org/rL279538
llvm-svn: 279543
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c048f31..9be4613 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2165,119 +2165,113 @@
if (RHS_CR.isFullSet())
return ConstantInt::getTrue(GetCompareTy(RHS));
- // FIXME: Use m_APInt below here to allow splat vector folds.
- ConstantInt *CI = dyn_cast<ConstantInt>(RHS);
- if (!CI)
- return nullptr;
-
// Many binary operators with constant RHS have easy to compute constant
// range. Use them to check whether the comparison is a tautology.
- unsigned Width = CI->getBitWidth();
+ unsigned Width = C->getBitWidth();
APInt Lower = APInt(Width, 0);
APInt Upper = APInt(Width, 0);
- ConstantInt *CI2;
- if (match(LHS, m_URem(m_Value(), m_ConstantInt(CI2)))) {
- // 'urem x, CI2' produces [0, CI2).
- Upper = CI2->getValue();
- } else if (match(LHS, m_SRem(m_Value(), m_ConstantInt(CI2)))) {
- // 'srem x, CI2' produces (-|CI2|, |CI2|).
- Upper = CI2->getValue().abs();
+ const APInt *C2;
+ if (match(LHS, m_URem(m_Value(), m_APInt(C2)))) {
+ // 'urem x, C2' produces [0, C2).
+ Upper = *C2;
+ } else if (match(LHS, m_SRem(m_Value(), m_APInt(C2)))) {
+ // 'srem x, C2' produces (-|C2|, |C2|).
+ Upper = C2->abs();
Lower = (-Upper) + 1;
- } else if (match(LHS, m_UDiv(m_ConstantInt(CI2), m_Value()))) {
- // 'udiv CI2, x' produces [0, CI2].
- Upper = CI2->getValue() + 1;
- } else if (match(LHS, m_UDiv(m_Value(), m_ConstantInt(CI2)))) {
- // 'udiv x, CI2' produces [0, UINT_MAX / CI2].
+ } else if (match(LHS, m_UDiv(m_APInt(C2), m_Value()))) {
+ // 'udiv C2, x' produces [0, C2].
+ Upper = *C2 + 1;
+ } else if (match(LHS, m_UDiv(m_Value(), m_APInt(C2)))) {
+ // 'udiv x, C2' produces [0, UINT_MAX / C2].
APInt NegOne = APInt::getAllOnesValue(Width);
- if (!CI2->isZero())
- Upper = NegOne.udiv(CI2->getValue()) + 1;
- } else if (match(LHS, m_SDiv(m_ConstantInt(CI2), m_Value()))) {
- if (CI2->isMinSignedValue()) {
+ if (*C2 != 0)
+ Upper = NegOne.udiv(*C2) + 1;
+ } else if (match(LHS, m_SDiv(m_APInt(C2), m_Value()))) {
+ if (C2->isMinSignedValue()) {
// 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
- Lower = CI2->getValue();
+ Lower = *C2;
Upper = Lower.lshr(1) + 1;
} else {
- // 'sdiv CI2, x' produces [-|CI2|, |CI2|].
- Upper = CI2->getValue().abs() + 1;
+ // 'sdiv C2, x' produces [-|C2|, |C2|].
+ Upper = C2->abs() + 1;
Lower = (-Upper) + 1;
}
- } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) {
+ } else if (match(LHS, m_SDiv(m_Value(), m_APInt(C2)))) {
APInt IntMin = APInt::getSignedMinValue(Width);
APInt IntMax = APInt::getSignedMaxValue(Width);
- const APInt &Val = CI2->getValue();
- if (Val.isAllOnesValue()) {
+ if (C2->isAllOnesValue()) {
// 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
- // where CI2 != -1 and CI2 != 0 and CI2 != 1
+ // where C2 != -1 and C2 != 0 and C2 != 1
Lower = IntMin + 1;
Upper = IntMax + 1;
- } else if (Val.countLeadingZeros() < Width - 1) {
- // 'sdiv x, CI2' produces [INT_MIN / CI2, INT_MAX / CI2]
- // where CI2 != -1 and CI2 != 0 and CI2 != 1
- Lower = IntMin.sdiv(Val);
- Upper = IntMax.sdiv(Val);
+ } else if (C2->countLeadingZeros() < Width - 1) {
+ // 'sdiv x, C2' produces [INT_MIN / C2, INT_MAX / C2]
+ // where C2 != -1 and C2 != 0 and C2 != 1
+ Lower = IntMin.sdiv(*C2);
+ Upper = IntMax.sdiv(*C2);
if (Lower.sgt(Upper))
std::swap(Lower, Upper);
Upper = Upper + 1;
assert(Upper != Lower && "Upper part of range has wrapped!");
}
- } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) {
- // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)]
- Lower = CI2->getValue();
+ } else if (match(LHS, m_NUWShl(m_APInt(C2), m_Value()))) {
+ // 'shl nuw C2, x' produces [C2, C2 << CLZ(C2)]
+ Lower = *C2;
Upper = Lower.shl(Lower.countLeadingZeros()) + 1;
- } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) {
- if (CI2->isNegative()) {
- // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2]
- unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1;
- Lower = CI2->getValue().shl(ShiftAmount);
- Upper = CI2->getValue() + 1;
+ } else if (match(LHS, m_NSWShl(m_APInt(C2), m_Value()))) {
+ if (C2->isNegative()) {
+ // 'shl nsw C2, x' produces [C2 << CLO(C2)-1, C2]
+ unsigned ShiftAmount = C2->countLeadingOnes() - 1;
+ Lower = C2->shl(ShiftAmount);
+ Upper = *C2 + 1;
} else {
- // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1]
- unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1;
- Lower = CI2->getValue();
- Upper = CI2->getValue().shl(ShiftAmount) + 1;
+ // 'shl nsw C2, x' produces [C2, C2 << CLZ(C2)-1]
+ unsigned ShiftAmount = C2->countLeadingZeros() - 1;
+ Lower = *C2;
+ Upper = C2->shl(ShiftAmount) + 1;
}
- } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) {
- // 'lshr x, CI2' produces [0, UINT_MAX >> CI2].
+ } else if (match(LHS, m_LShr(m_Value(), m_APInt(C2)))) {
+ // 'lshr x, C2' produces [0, UINT_MAX >> C2].
APInt NegOne = APInt::getAllOnesValue(Width);
- if (CI2->getValue().ult(Width))
- Upper = NegOne.lshr(CI2->getValue()) + 1;
- } else if (match(LHS, m_LShr(m_ConstantInt(CI2), m_Value()))) {
- // 'lshr CI2, x' produces [CI2 >> (Width-1), CI2].
+ if (C2->ult(Width))
+ Upper = NegOne.lshr(*C2) + 1;
+ } else if (match(LHS, m_LShr(m_APInt(C2), m_Value()))) {
+ // 'lshr C2, x' produces [C2 >> (Width-1), C2].
unsigned ShiftAmount = Width - 1;
- if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact())
- ShiftAmount = CI2->getValue().countTrailingZeros();
- Lower = CI2->getValue().lshr(ShiftAmount);
- Upper = CI2->getValue() + 1;
- } else if (match(LHS, m_AShr(m_Value(), m_ConstantInt(CI2)))) {
- // 'ashr x, CI2' produces [INT_MIN >> CI2, INT_MAX >> CI2].
+ if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
+ ShiftAmount = C2->countTrailingZeros();
+ Lower = C2->lshr(ShiftAmount);
+ Upper = *C2 + 1;
+ } else if (match(LHS, m_AShr(m_Value(), m_APInt(C2)))) {
+ // 'ashr x, C2' produces [INT_MIN >> C2, INT_MAX >> C2].
APInt IntMin = APInt::getSignedMinValue(Width);
APInt IntMax = APInt::getSignedMaxValue(Width);
- if (CI2->getValue().ult(Width)) {
- Lower = IntMin.ashr(CI2->getValue());
- Upper = IntMax.ashr(CI2->getValue()) + 1;
+ if (C2->ult(Width)) {
+ Lower = IntMin.ashr(*C2);
+ Upper = IntMax.ashr(*C2) + 1;
}
- } else if (match(LHS, m_AShr(m_ConstantInt(CI2), m_Value()))) {
+ } else if (match(LHS, m_AShr(m_APInt(C2), m_Value()))) {
unsigned ShiftAmount = Width - 1;
- if (!CI2->isZero() && cast<BinaryOperator>(LHS)->isExact())
- ShiftAmount = CI2->getValue().countTrailingZeros();
- if (CI2->isNegative()) {
- // 'ashr CI2, x' produces [CI2, CI2 >> (Width-1)]
- Lower = CI2->getValue();
- Upper = CI2->getValue().ashr(ShiftAmount) + 1;
+ if (*C2 != 0 && cast<BinaryOperator>(LHS)->isExact())
+ ShiftAmount = C2->countTrailingZeros();
+ if (C2->isNegative()) {
+ // 'ashr C2, x' produces [C2, C2 >> (Width-1)]
+ Lower = *C2;
+ Upper = C2->ashr(ShiftAmount) + 1;
} else {
- // 'ashr CI2, x' produces [CI2 >> (Width-1), CI2]
- Lower = CI2->getValue().ashr(ShiftAmount);
- Upper = CI2->getValue() + 1;
+ // 'ashr C2, x' produces [C2 >> (Width-1), C2]
+ Lower = C2->ashr(ShiftAmount);
+ Upper = *C2 + 1;
}
- } else if (match(LHS, m_Or(m_Value(), m_ConstantInt(CI2)))) {
- // 'or x, CI2' produces [CI2, UINT_MAX].
- Lower = CI2->getValue();
- } else if (match(LHS, m_And(m_Value(), m_ConstantInt(CI2)))) {
- // 'and x, CI2' produces [0, CI2].
- Upper = CI2->getValue() + 1;
- } else if (match(LHS, m_NUWAdd(m_Value(), m_ConstantInt(CI2)))) {
- // 'add nuw x, CI2' produces [CI2, UINT_MAX].
- Lower = CI2->getValue();
+ } else if (match(LHS, m_Or(m_Value(), m_APInt(C2)))) {
+ // 'or x, C2' produces [C2, UINT_MAX].
+ Lower = *C2;
+ } else if (match(LHS, m_And(m_Value(), m_APInt(C2)))) {
+ // 'and x, C2' produces [0, C2].
+ Upper = *C2 + 1;
+ } else if (match(LHS, m_NUWAdd(m_Value(), m_APInt(C2)))) {
+ // 'add nuw x, C2' produces [C2, UINT_MAX].
+ Lower = *C2;
}
ConstantRange LHS_CR =
@@ -2289,9 +2283,9 @@
if (!LHS_CR.isFullSet()) {
if (RHS_CR.contains(LHS_CR))
- return ConstantInt::getTrue(RHS->getContext());
+ return ConstantInt::getTrue(GetCompareTy(RHS));
if (RHS_CR.inverse().contains(LHS_CR))
- return ConstantInt::getFalse(RHS->getContext());
+ return ConstantInt::getFalse(GetCompareTy(RHS));
}
return nullptr;