[ConstantFold] ExtractConstantBytes - handle shifts on large integer types
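Keep the shift amount as an APInt instead of calling getZExtValue() on the ConstantInt, and only convert it after confirming that it is in range. getZExtValue() requires the value to fit in 64 bits, which a shift amount on a very wide integer type need not do. The old code also truncated the result to `unsigned`.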
Reduced from OSS-Fuzz #14169 - https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=14169
llvm-svn: 358192
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3784405..3c01a48 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -268,19 +268,20 @@
     ConstantInt *Amt = dyn_cast<ConstantInt>(CE->getOperand(1));
     if (!Amt)
       return nullptr;
-    unsigned ShAmt = Amt->getZExtValue();
+    APInt ShAmt = Amt->getValue();
     // Cannot analyze non-byte shifts.
     if ((ShAmt & 7) != 0)
       return nullptr;
-    ShAmt >>= 3;
+    ShAmt.lshrInPlace(3);
 
     // If the extract is known to be all zeros, return zero.
-    if (ByteStart >= CSize-ShAmt)
-      return Constant::getNullValue(IntegerType::get(CE->getContext(),
-                                                     ByteSize*8));
+    if (ShAmt.uge(CSize - ByteStart))
+      return Constant::getNullValue(
+          IntegerType::get(CE->getContext(), ByteSize * 8));
     // If the extract is known to be fully in the input, extract it.
-    if (ByteStart+ByteSize+ShAmt <= CSize)
-      return ExtractConstantBytes(CE->getOperand(0), ByteStart+ShAmt, ByteSize);
+    if (ShAmt.ule(CSize - (ByteStart + ByteSize)))
+      return ExtractConstantBytes(CE->getOperand(0),
+                                  ByteStart + ShAmt.getZExtValue(), ByteSize);
 
     // TODO: Handle the 'partially zero' case.
     return nullptr;
@@ -290,19 +291,20 @@
     ConstantInt *Amt = dyn_cast<ConstantInt>(CE->getOperand(1));
     if (!Amt)
       return nullptr;
-    unsigned ShAmt = Amt->getZExtValue();
+    APInt ShAmt = Amt->getValue();
     // Cannot analyze non-byte shifts.
     if ((ShAmt & 7) != 0)
       return nullptr;
-    ShAmt >>= 3;
+    ShAmt.lshrInPlace(3);
 
     // If the extract is known to be all zeros, return zero.
-    if (ByteStart+ByteSize <= ShAmt)
-      return Constant::getNullValue(IntegerType::get(CE->getContext(),
-                                                     ByteSize*8));
+    if (ShAmt.uge(ByteStart + ByteSize))
+      return Constant::getNullValue(
+          IntegerType::get(CE->getContext(), ByteSize * 8));
     // If the extract is known to be fully in the input, extract it.
-    if (ByteStart >= ShAmt)
-      return ExtractConstantBytes(CE->getOperand(0), ByteStart-ShAmt, ByteSize);
+    if (ShAmt.ule(ByteStart))
+      return ExtractConstantBytes(CE->getOperand(0),
+                                  ByteStart - ShAmt.getZExtValue(), ByteSize);
 
     // TODO: Handle the 'partially zero' case.
     return nullptr;