[InstCombine] Check for out of range shift values using APInt before calling getZExtValue
Reduced from the oss-fuzz #4871 test case.
llvm-svn: 321748
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 44bbb84..3d379ad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -87,8 +87,7 @@
// Equal shift amounts in opposite directions become bitwise 'and':
// lshr (shl X, C), C --> and X, C'
// shl (lshr X, C), C --> and X, C'
- unsigned InnerShAmt = InnerShiftConst->getZExtValue();
- if (InnerShAmt == OuterShAmt)
+ if (*InnerShiftConst == OuterShAmt)
return true;
// If the 2nd shift is bigger than the 1st, we can fold:
@@ -98,7 +97,8 @@
// Also, check that the inner shift is valid (less than the type width) or
// we'll crash trying to produce the bit mask for the 'and'.
unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
- if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
+ if (InnerShiftConst->ugt(OuterShAmt) && InnerShiftConst->ult(TypeWidth)) {
+ unsigned InnerShAmt = InnerShiftConst->getZExtValue();
unsigned MaskShift =
IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
@@ -135,7 +135,7 @@
ConstantInt *CI = nullptr;
if ((IsLeftShift && match(I, m_LShr(m_Value(), m_ConstantInt(CI)))) ||
(!IsLeftShift && match(I, m_Shl(m_Value(), m_ConstantInt(CI))))) {
- if (CI->getZExtValue() == NumBits) {
+ if (CI->getValue() == NumBits) {
// TODO: Check that the input bits are already zero with MaskedValueIsZero
#if 0
// If this is a truncate of a logical shr, we can truncate it to a smaller