[InstCombine] Check for out-of-range ashr shift amounts using APInt before calling getZExtValue
Reduced from the oss-fuzz #5032 test case.
llvm-svn: 322078
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 3d379ad..a04a3ce 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -818,7 +818,7 @@
Type *Ty = I.getType();
unsigned BitWidth = Ty->getScalarSizeInBits();
const APInt *ShAmtAPInt;
- if (match(Op1, m_APInt(ShAmtAPInt))) {
+ if (match(Op1, m_APInt(ShAmtAPInt)) && ShAmtAPInt->ult(BitWidth)) {
unsigned ShAmt = ShAmtAPInt->getZExtValue();
// If the shift amount equals the difference in width of the destination
@@ -832,7 +832,8 @@
// We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
// we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
const APInt *ShOp1;
- if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
+ if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1))) &&
+ ShOp1->ult(BitWidth)) {
unsigned ShlAmt = ShOp1->getZExtValue();
if (ShlAmt < ShAmt) {
// (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
@@ -850,7 +851,8 @@
}
}
- if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
+ if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1))) &&
+ ShOp1->ult(BitWidth)) {
unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
// Oversized arithmetic shifts replicate the sign bit.
AmtSum = std::min(AmtSum, BitWidth - 1);