[InstCombine][X86] Add MULDQ/MULUDQ constant folding support
llvm-svn: 292793
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 9f9bf40..e6e126b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -510,16 +510,53 @@
return Builder.CreateAShr(Vec, ShiftVec);
}
-static Value *simplifyX86muldq(const IntrinsicInst &II) {
+static Value *simplifyX86muldq(const IntrinsicInst &II,
+ InstCombiner::BuilderTy &Builder) {
Value *Arg0 = II.getArgOperand(0);
Value *Arg1 = II.getArgOperand(1);
Type *ResTy = II.getType();
+ assert(Arg0->getType()->getScalarSizeInBits() == 32 &&
+ Arg1->getType()->getScalarSizeInBits() == 32 &&
+ ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types");
// muldq/muludq(undef, undef) -> zero (matches generic mul behavior)
if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
return ConstantAggregateZero::get(ResTy);
- return nullptr;
+ // Constant folding (attempted only when both operands are constants).
+ // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)),
+ // vXi64 sext(shuffle<0,2,..>(Arg1))))
+ // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)),
+ // vXi64 zext(shuffle<0,2,..>(Arg1))))
+ if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+ return nullptr; // Needs two constant operands.
+
+ unsigned NumElts = ResTy->getVectorNumElements(); // i64 result lanes
+ assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) &&
+ Arg1->getType()->getVectorNumElements() == (2 * NumElts) &&
+ "Unexpected muldq/muludq types");
+
+ unsigned IntrinsicID = II.getIntrinsicID();
+ bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID ||
+ Intrinsic::x86_avx2_pmul_dq == IntrinsicID ||
+ Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); // else: zext
+
+ SmallVector<unsigned, 16> ShuffleMask; // even lanes: <0, 2, 4, ...>
+ for (unsigned i = 0; i != NumElts; ++i)
+ ShuffleMask.push_back(i * 2);
+
+ auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask);
+ auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask);
+
+ if (IsSigned) {
+ LHS = Builder.CreateSExt(LHS, ResTy);
+ RHS = Builder.CreateSExt(RHS, ResTy);
+ } else {
+ LHS = Builder.CreateZExt(LHS, ResTy);
+ RHS = Builder.CreateZExt(RHS, ResTy);
+ }
+
+ return Builder.CreateMul(LHS, RHS); // one 64-bit product per lane
}
static Value *simplifyX86movmsk(const IntrinsicInst &II,
@@ -2154,7 +2191,7 @@
case Intrinsic::x86_avx2_pmulu_dq:
case Intrinsic::x86_avx512_pmul_dq_512:
case Intrinsic::x86_avx512_pmulu_dq_512: {
- if (Value *V = simplifyX86muldq(*II))
+ if (Value *V = simplifyX86muldq(*II, *Builder))
return replaceInstUsesWith(*II, V);
unsigned VWidth = II->getType()->getVectorNumElements();