[X86] Constant folding of adds/subs intrinsics
Summary: This adds constant folding of signed add/sub with saturation intrinsics.
Reviewers: craig.topper, spatel, RKSimon, chandlerc, efriedma
Reviewed By: craig.topper
Subscribers: rnk, llvm-commits
Differential Revision: https://reviews.llvm.org/D50499
llvm-svn: 339659
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index f0c4ceb..38d6f29 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -243,6 +243,86 @@
return nullptr;
}
+// Constant-fold the X86 signed saturating add/sub intrinsics
+// (SSE2/AVX2 padds/psubs and the AVX-512 masked 512-bit forms) when both
+// vector operands -- and, for the masked forms, the mask -- are constants.
+// Returns the folded vector value, or nullptr when folding is not possible.
+static Value *simplifyX86AddsSubs(const IntrinsicInst &II,
+                                  InstCombiner::BuilderTy &Builder) {
+  bool IsAddition = false;
+  bool IsMasked = false;
+
+  switch (II.getIntrinsicID()) {
+  default: llvm_unreachable("Unexpected intrinsic!");
+  case Intrinsic::x86_sse2_padds_b:
+  case Intrinsic::x86_sse2_padds_w:
+  case Intrinsic::x86_avx2_padds_b:
+  case Intrinsic::x86_avx2_padds_w:
+    IsAddition = true; IsMasked = false;
+    break;
+  case Intrinsic::x86_sse2_psubs_b:
+  case Intrinsic::x86_sse2_psubs_w:
+  case Intrinsic::x86_avx2_psubs_b:
+  case Intrinsic::x86_avx2_psubs_w:
+    IsAddition = false; IsMasked = false;
+    break;
+  case Intrinsic::x86_avx512_mask_padds_b_512:
+  case Intrinsic::x86_avx512_mask_padds_w_512:
+    IsAddition = true; IsMasked = true;
+    break;
+  case Intrinsic::x86_avx512_mask_psubs_b_512:
+  case Intrinsic::x86_avx512_mask_psubs_w_512:
+    IsAddition = false; IsMasked = true;
+    break;
+  }
+
+  auto *Arg0 = dyn_cast<Constant>(II.getOperand(0));
+  auto *Arg1 = dyn_cast<Constant>(II.getOperand(1));
+  auto VT = cast<VectorType>(II.getType());
+  auto SVT = VT->getElementType();
+  unsigned NumElems = VT->getNumElements();
+
+  // Only fold when every operand we inspect is constant (including the
+  // passthru operand of the masked forms, which feeds the select below).
+  if (!Arg0 || !Arg1 || (IsMasked && !isa<Constant>(II.getOperand(2))))
+    return nullptr;
+
+  SmallVector<Constant *, 64> Result;
+
+  APInt MaxValue = APInt::getSignedMaxValue(SVT->getIntegerBitWidth());
+  APInt MinValue = APInt::getSignedMinValue(SVT->getIntegerBitWidth());
+  for (unsigned i = 0; i < NumElems; ++i) {
+    auto *Elt0 = Arg0->getAggregateElement(i);
+    auto *Elt1 = Arg1->getAggregateElement(i);
+    // getAggregateElement() returns null for constants it cannot
+    // decompose (e.g. a ConstantExpr vector); bail out rather than
+    // dereferencing a null element below.
+    if (!Elt0 || !Elt1)
+      return nullptr;
+
+    // An undef lane yields an undef result lane.
+    if (isa<UndefValue>(Elt0) || isa<UndefValue>(Elt1)) {
+      Result.push_back(UndefValue::get(SVT));
+      continue;
+    }
+
+    if (!isa<ConstantInt>(Elt0) || !isa<ConstantInt>(Elt1))
+      return nullptr;
+
+    // Signed add/sub with overflow detection; on overflow clamp to the
+    // saturation bound on the side the first operand overflowed toward.
+    const APInt &Val0 = cast<ConstantInt>(Elt0)->getValue();
+    const APInt &Val1 = cast<ConstantInt>(Elt1)->getValue();
+    bool Overflow = false;
+    APInt ResultElem = IsAddition ? Val0.sadd_ov(Val1, Overflow)
+                                  : Val0.ssub_ov(Val1, Overflow);
+    if (Overflow)
+      ResultElem = Val0.isNegative() ? MinValue : MaxValue;
+    Result.push_back(Constant::getIntegerValue(SVT, ResultElem));
+  }
+
+  Value *ResultVec = ConstantVector::get(Result);
+
+  if (IsMasked) {
+    // Masked intrinsics: select between the folded result and the
+    // passthru operand lane-by-lane.  An all-ones mask selects the
+    // folded result everywhere, so no select is needed.
+    Value *Src = II.getOperand(2);
+    auto Mask = II.getOperand(3);
+    if (auto *C = dyn_cast<Constant>(Mask))
+      if (C->isAllOnesValue())
+        return ResultVec;
+    auto *MaskTy = VectorType::get(
+        Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth());
+    Mask = Builder.CreateBitCast(Mask, MaskTy);
+    ResultVec = Builder.CreateSelect(Mask, ResultVec, Src);
+  }
+
+  return ResultVec;
+}
+
static Value *simplifyX86immShift(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder) {
bool LogicalShift = false;
@@ -2525,6 +2605,24 @@
break;
}
+ // Constant fold add/sub with saturation intrinsics.
+ case Intrinsic::x86_sse2_padds_b:
+ case Intrinsic::x86_sse2_padds_w:
+ case Intrinsic::x86_sse2_psubs_b:
+ case Intrinsic::x86_sse2_psubs_w:
+ case Intrinsic::x86_avx2_padds_b:
+ case Intrinsic::x86_avx2_padds_w:
+ case Intrinsic::x86_avx2_psubs_b:
+ case Intrinsic::x86_avx2_psubs_w:
+ case Intrinsic::x86_avx512_mask_padds_b_512:
+ case Intrinsic::x86_avx512_mask_padds_w_512:
+ case Intrinsic::x86_avx512_mask_psubs_b_512:
+ case Intrinsic::x86_avx512_mask_psubs_w_512:
+ if (Value *V = simplifyX86AddsSubs(*II, Builder))
+ return replaceInstUsesWith(*II, V);
+ break;
+
+
// Constant fold ashr( <A x Bi>, Ci ).
// Constant fold lshr( <A x Bi>, Ci ).
// Constant fold shl( <A x Bi>, Ci ).