[InstCombine][X86] Add MULDQ/MULUDQ constant folding support

llvm-svn: 292793
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 9f9bf40..e6e126b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -510,16 +510,64 @@
   return Builder.CreateAShr(Vec, ShiftVec);
 }
 
-static Value *simplifyX86muldq(const IntrinsicInst &II) {
+/// Simplify or constant fold an X86 PMULDQ/PMULUDQ intrinsic call.
+///
+/// These multiply the even-indexed 32-bit elements of both operands into
+/// 64-bit products, sign-extending the inputs for PMULDQ and zero-extending
+/// them for PMULUDQ. Returns the replacement value, or nullptr if no
+/// simplification applies.
+static Value *simplifyX86muldq(const IntrinsicInst &II,
+                               InstCombiner::BuilderTy &Builder) {
   Value *Arg0 = II.getArgOperand(0);
   Value *Arg1 = II.getArgOperand(1);
   Type *ResTy = II.getType();
+  assert(Arg0->getType()->getScalarSizeInBits() == 32 &&
+         Arg1->getType()->getScalarSizeInBits() == 32 &&
+         ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types");
 
-  // muldq/muludq(undef, undef) -> zero (matches generic mul behavior)
-  if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
+  // muldq/muludq(undef, X) -> zero (matches generic mul behavior): the undef
+  // operand's elements may all be chosen as zero, making every product zero.
+  if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
     return ConstantAggregateZero::get(ResTy);
 
-  return nullptr;
+  // Constant folding.
+  // PMULDQ  = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)),
+  //                vXi64 sext(shuffle<0,2,..>(Arg1))))
+  // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)),
+  //                vXi64 zext(shuffle<0,2,..>(Arg1))))
+  // Both operands must be constants; the builder is then expected to fold the
+  // shuffle/extend/mul sequence below (assumes a constant-folding BuilderTy).
+  if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+    return nullptr;
+
+  unsigned NumElts = ResTy->getVectorNumElements();
+  assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) &&
+         Arg1->getType()->getVectorNumElements() == (2 * NumElts) &&
+         "Unexpected muldq/muludq types");
+
+  // PMULDQ variants sign-extend their inputs; PMULUDQ variants zero-extend.
+  unsigned IntrinsicID = II.getIntrinsicID();
+  bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID ||
+                   Intrinsic::x86_avx2_pmul_dq == IntrinsicID ||
+                   Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID);
+
+  // Only the even-indexed 32-bit source elements take part in the multiply.
+  SmallVector<unsigned, 16> ShuffleMask;
+  for (unsigned i = 0; i != NumElts; ++i)
+    ShuffleMask.push_back(i * 2);
+
+  auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask);
+  auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask);
+
+  if (IsSigned) {
+    LHS = Builder.CreateSExt(LHS, ResTy);
+    RHS = Builder.CreateSExt(RHS, ResTy);
+  } else {
+    LHS = Builder.CreateZExt(LHS, ResTy);
+    RHS = Builder.CreateZExt(RHS, ResTy);
+  }
+
+  return Builder.CreateMul(LHS, RHS);
 }
 
 static Value *simplifyX86movmsk(const IntrinsicInst &II,
@@ -2154,7 +2191,7 @@
   case Intrinsic::x86_avx2_pmulu_dq:
   case Intrinsic::x86_avx512_pmul_dq_512:
   case Intrinsic::x86_avx512_pmulu_dq_512: {
-    if (Value *V = simplifyX86muldq(*II))
+    if (Value *V = simplifyX86muldq(*II, *Builder))
       return replaceInstUsesWith(*II, V);
 
     unsigned VWidth = II->getType()->getVectorNumElements();