InstCombine: Fold bitcast of vector to FP scalar
llvm-svn: 288978
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 04d36ff..2d1edfe 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -88,7 +88,7 @@
return nullptr;
}
-// Constant fold bitcast, symbolically evaluating it with DataLayout.
+/// Constant fold bitcast, symbolically evaluating it with DataLayout.
/// This always returns a non-null constant, but it may be a
/// ConstantExpr if unfoldable.
Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) {
@@ -99,31 +99,33 @@
!DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
return Constant::getAllOnesValue(DestTy);
- // Handle a vector->integer cast.
- if (auto *IT = dyn_cast<IntegerType>(DestTy)) {
- auto *VTy = dyn_cast<VectorType>(C->getType());
- if (!VTy)
- return ConstantExpr::getBitCast(C, DestTy);
+ if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
+ // Handle a vector->scalar integer/fp cast.
+ if (isa<IntegerType>(DestTy) || DestTy->isFloatingPointTy()) {
+ unsigned NumSrcElts = VTy->getNumElements();
+ Type *SrcEltTy = VTy->getElementType();
- unsigned NumSrcElts = VTy->getNumElements();
- Type *SrcEltTy = VTy->getElementType();
+ // If the vector is a vector of floating point, convert it to vector of int
+ // to simplify things.
+ if (SrcEltTy->isFloatingPointTy()) {
+ unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+ Type *SrcIVTy =
+ VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElts);
+ // Ask IR to do the conversion now that #elts line up.
+ C = ConstantExpr::getBitCast(C, SrcIVTy);
+ }
- // If the vector is a vector of floating point, convert it to vector of int
- // to simplify things.
- if (SrcEltTy->isFloatingPointTy()) {
- unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
- Type *SrcIVTy =
- VectorType::get(IntegerType::get(C->getContext(), FPWidth), NumSrcElts);
- // Ask IR to do the conversion now that #elts line up.
- C = ConstantExpr::getBitCast(C, SrcIVTy);
+ APInt Result(DL.getTypeSizeInBits(DestTy), 0);
+ if (Constant *CE = foldConstVectorToAPInt(Result, DestTy, C,
+ SrcEltTy, NumSrcElts, DL))
+ return CE;
+
+ if (isa<IntegerType>(DestTy))
+ return ConstantInt::get(DestTy, Result);
+
+ APFloat FP(DestTy->getFltSemantics(), Result);
+ return ConstantFP::get(DestTy->getContext(), FP);
}
-
- APInt Result(IT->getBitWidth(), 0);
- if (Constant *CE = foldConstVectorToAPInt(Result, DestTy, C,
- SrcEltTy, NumSrcElts, DL))
- return CE;
-
- return ConstantInt::get(IT, Result);
}
// The code below only handles casts to vectors currently.