[InstCombine] Replacing X86-specific rounding intrinsics with generic floor-ceil
This patch replaces calls to X86-specific rounding intrinsics whose immediate
encodes floor or ceil semantics with calls to the target-independent
@llvm.floor.* and @llvm.ceil.* intrinsics. This doesn't affect the resulting
machine code, since the generic intrinsics are lowered to the same
instructions, but it exposes these rounding cases to target-independent
optimizations.
Differential Revision: https://reviews.llvm.org/D48067
llvm-svn: 335039
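
For illustration only (a minimal sketch, not taken from the patch's test files;
the function name and the choice of immediate are hypothetical), an SSE4.1
round with immediate 1 such as

  define <4 x float> @round_ps_floor(<4 x float> %x) {
    ; immediate 1 = round toward negative infinity (floor)
    %r = call <4 x float> @llvm.x86.sse41.round.ps(<4 x float> %x, i32 1)
    ret <4 x float> %r
  }
  declare <4 x float> @llvm.x86.sse41.round.ps(<4 x float>, i32)

would now be rewritten by InstCombine into the equivalent of

  define <4 x float> @round_ps_floor(<4 x float> %x) {
    %r = call <4 x float> @llvm.floor.v4f32(<4 x float> %x)
    ret <4 x float> %r
  }
  declare <4 x float> @llvm.floor.v4f32(<4 x float>)

An immediate of 2 (round toward positive infinity) maps to @llvm.ceil.* in the
same way; other immediate values are left untouched by this patch.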
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 9eb8d5d..9e046c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -576,6 +576,105 @@
return ConstantVector::get(Vals);
}
+// Replace X86-specific intrinsics with generic floor-ceil where applicable.
+static Value *simplifyX86round(IntrinsicInst &II,
+ InstCombiner::BuilderTy &Builder) {
+ ConstantInt *Arg = nullptr;
+ Intrinsic::ID IntrinsicID = II.getIntrinsicID();
+
+ if (IntrinsicID == Intrinsic::x86_sse41_round_ss ||
+ IntrinsicID == Intrinsic::x86_sse41_round_sd)
+ Arg = dyn_cast<ConstantInt>(II.getArgOperand(2));
+ else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd)
+ Arg = dyn_cast<ConstantInt>(II.getArgOperand(4));
+ else
+ Arg = dyn_cast<ConstantInt>(II.getArgOperand(1));
+ if (!Arg)
+ return nullptr;
+ unsigned RoundControl = Arg->getZExtValue();
+
+ Arg = nullptr;
+ unsigned SAE = 0;
+ if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512)
+ Arg = dyn_cast<ConstantInt>(II.getArgOperand(4));
+ else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd)
+ Arg = dyn_cast<ConstantInt>(II.getArgOperand(5));
+ else
+ SAE = 4;
+ if (!SAE) {
+ if (!Arg)
+ return nullptr;
+ SAE = Arg->getZExtValue();
+ }
+
+ if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/))
+ return nullptr;
+
+ Value *Src, *Dst, *Mask;
+ bool IsScalar = false;
+ if (IntrinsicID == Intrinsic::x86_sse41_round_ss ||
+ IntrinsicID == Intrinsic::x86_sse41_round_sd ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+ IsScalar = true;
+ if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+ Mask = II.getArgOperand(3);
+ Value *Zero = Constant::getNullValue(Mask->getType());
+ Mask = Builder.CreateAnd(Mask, 1);
+ Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero);
+ Dst = II.getArgOperand(2);
+ } else
+ Dst = II.getArgOperand(0);
+ Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0);
+ } else {
+ Src = II.getArgOperand(0);
+ if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) {
+ Dst = II.getArgOperand(2);
+ Mask = II.getArgOperand(3);
+ } else {
+ Dst = Src;
+ Mask = ConstantInt::getAllOnesValue(
+ Builder.getIntNTy(Src->getType()->getVectorNumElements()));
+ }
+ }
+
+ Intrinsic::ID ID = (RoundControl == 2) ? Intrinsic::ceil : Intrinsic::floor;
+ Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II);
+ if (!IsScalar) {
+ if (auto *C = dyn_cast<Constant>(Mask))
+ if (C->isAllOnesValue())
+ return Res;
+ auto *MaskTy = VectorType::get(
+ Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth());
+ Mask = Builder.CreateBitCast(Mask, MaskTy);
+ unsigned Width = Src->getType()->getVectorNumElements();
+ if (MaskTy->getVectorNumElements() > Width) {
+ uint32_t Indices[4];
+ for (unsigned i = 0; i != Width; ++i)
+ Indices[i] = i;
+ Mask = Builder.CreateShuffleVector(Mask, Mask,
+ makeArrayRef(Indices, Width));
+ }
+ return Builder.CreateSelect(Mask, Res, Dst);
+ }
+ if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+ IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+ Dst = Builder.CreateExtractElement(Dst, (uint64_t)0);
+ Res = Builder.CreateSelect(Mask, Res, Dst);
+ Dst = II.getArgOperand(0);
+ }
+ return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
+}
+
static Value *simplifyX86movmsk(const IntrinsicInst &II) {
Value *Arg = II.getArgOperand(0);
Type *ResTy = II.getType();
@@ -2222,6 +2321,22 @@
break;
}
+ case Intrinsic::x86_sse41_round_ps:
+ case Intrinsic::x86_sse41_round_pd:
+ case Intrinsic::x86_avx_round_ps_256:
+ case Intrinsic::x86_avx_round_pd_256:
+ case Intrinsic::x86_avx512_mask_rndscale_ps_128:
+ case Intrinsic::x86_avx512_mask_rndscale_ps_256:
+ case Intrinsic::x86_avx512_mask_rndscale_ps_512:
+ case Intrinsic::x86_avx512_mask_rndscale_pd_128:
+ case Intrinsic::x86_avx512_mask_rndscale_pd_256:
+ case Intrinsic::x86_avx512_mask_rndscale_pd_512:
+ case Intrinsic::x86_avx512_mask_rndscale_ss:
+ case Intrinsic::x86_avx512_mask_rndscale_sd:
+ if (Value *V = simplifyX86round(*II, Builder))
+ return replaceInstUsesWith(*II, V);
+ break;
+
case Intrinsic::x86_mmx_pmovmskb:
case Intrinsic::x86_sse_movmsk_ps:
case Intrinsic::x86_sse2_movmsk_pd:
@@ -2438,8 +2553,6 @@
case Intrinsic::x86_sse2_cmp_sd:
case Intrinsic::x86_sse2_min_sd:
case Intrinsic::x86_sse2_max_sd:
- case Intrinsic::x86_sse41_round_ss:
- case Intrinsic::x86_sse41_round_sd:
case Intrinsic::x86_xop_vfrcz_ss:
case Intrinsic::x86_xop_vfrcz_sd: {
unsigned VWidth = II->getType()->getVectorNumElements();
@@ -2452,6 +2565,19 @@
}
break;
}
+ case Intrinsic::x86_sse41_round_ss:
+ case Intrinsic::x86_sse41_round_sd: {
+ unsigned VWidth = II->getType()->getVectorNumElements();
+ APInt UndefElts(VWidth, 0);
+ APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+ if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) {
+ if (V != II)
+ return replaceInstUsesWith(*II, V);
+ return II;
+ } else if (Value *V = simplifyX86round(*II, Builder))
+ return replaceInstUsesWith(*II, V);
+ break;
+ }
// Constant fold ashr( <A x Bi>, Ci ).
// Constant fold lshr( <A x Bi>, Ci ).