[X86] Reimplement r321437 using custom lowering instead of as a DAG combine.

My original implementation ran as a DAG combine after type legalization, but it turns out that combine step is skipped entirely if type legalization didn't change anything. Attempts to also run the combine before type legalization hit other issues.

So just do it in LowerMUL, where we can catch more cases.
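
For illustration, a minimal sketch of the kind of pattern the LowerMUL path now
catches (this function is made up for this message, not taken from the committed
tests): both operands of a v2i64 multiply have their upper 32 bits known zero,
so a single pmuludq is still preferred even on DQI targets where vpmullq is
legal.

  define <2 x i64> @mul_high_bits_zero(<2 x i64> %a, <2 x i64> %b) {
    ; Masking clears the upper 32 bits of both operands, so
    ; MaskedValueIsZero proves a single pmuludq computes the full
    ; 64-bit product.
    %am = and <2 x i64> %a, <i64 4294967295, i64 4294967295>
    %bm = and <2 x i64> %b, <i64 4294967295, i64 4294967295>
    %m = mul <2 x i64> %am, %bm
    ret <2 x i64> %m
  }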

llvm-svn: 321496
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7d2bfd4..ad2d4b5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1310,8 +1310,6 @@
       setOperationAction(ISD::UINT_TO_FP, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_SINT, MVT::v8i64, Legal);
       setOperationAction(ISD::FP_TO_UINT, MVT::v8i64, Legal);
-
-      setOperationAction(ISD::MUL,        MVT::v8i64, Legal);
     }
 
     if (Subtarget.hasCDI()) {
@@ -1388,8 +1386,6 @@
         setOperationAction(ISD::UINT_TO_FP,     VT, Legal);
         setOperationAction(ISD::FP_TO_SINT,     VT, Legal);
         setOperationAction(ISD::FP_TO_UINT,     VT, Legal);
-
-        setOperationAction(ISD::MUL,            VT, Legal);
       }
     }
 
@@ -22140,6 +22136,11 @@
   bool AHiIsZero = DAG.MaskedValueIsZero(A, UpperBitsMask);
   bool BHiIsZero = DAG.MaskedValueIsZero(B, UpperBitsMask);
 
+  // If DQI is supported, we can use MULLQ, but MULUDQ is still better if the
+  // high bits are known to be zero.
+  if (Subtarget.hasDQI() && (!AHiIsZero || !BHiIsZero))
+    return Op;
+
   // Bit cast to 32-bit vectors for MULUDQ.
   SDValue Alo = DAG.getBitcast(MulVT, A);
   SDValue Blo = DAG.getBitcast(MulVT, B);
@@ -32423,41 +32424,6 @@
   return SDValue();
 }
 
-static SDValue combineVMUL(SDNode *N, SelectionDAG &DAG,
-                           const X86Subtarget &Subtarget) {
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
-
-  if (VT.getScalarType() != MVT::i64)
-    return SDValue();
-
-  // Don't try to lower 256 bit integer vectors on AVX1 targets.
-  if (!Subtarget.hasAVX2() && VT.getVectorNumElements() > 2)
-    return SDValue();
-
-  MVT MulVT = MVT::getVectorVT(MVT::i32, VT.getVectorNumElements() * 2);
-
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
-
-  // MULDQ returns the 64-bit result of the signed multiplication of the lower
-  // 32-bits. We can lower with this if the sign bits stretch that far.
-  if (Subtarget.hasSSE41() && DAG.ComputeNumSignBits(LHS) > 32 &&
-      DAG.ComputeNumSignBits(RHS) > 32) {
-    return DAG.getNode(X86ISD::PMULDQ, dl, VT, DAG.getBitcast(MulVT, LHS),
-                       DAG.getBitcast(MulVT, RHS));
-  }
-
-  // If the upper bits are zero we can use a single pmuludq.
-  APInt Mask = APInt::getHighBitsSet(64, 32);
-  if (DAG.MaskedValueIsZero(LHS, Mask) && DAG.MaskedValueIsZero(RHS, Mask)) {
-    return DAG.getNode(X86ISD::PMULUDQ, dl, VT, DAG.getBitcast(MulVT, LHS),
-                       DAG.getBitcast(MulVT, RHS));
-  }
-
-  return SDValue();
-}
-
 /// Optimize a single multiply with constant into two operations in order to
 /// implement it with two cheaper instructions, e.g. LEA + SHL, LEA + LEA.
 static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
@@ -32467,9 +32433,6 @@
   if (DCI.isBeforeLegalize() && VT.isVector())
     return reduceVMULWidth(N, DAG, Subtarget);
 
-  if (!DCI.isBeforeLegalize() && VT.isVector())
-    return combineVMUL(N, DAG, Subtarget);
-
   if (!MulConstantOptimization)
     return SDValue();
   // An imul is usually smaller than the alternative sequence.
@@ -34911,7 +34874,7 @@
     // X86 is rubbish at scalar and vector i64 multiplies (until AVX512DQ) - its
     // better to truncate if we have the chance.
     if (SrcVT.getScalarType() == MVT::i64 && TLI.isOperationLegal(Opcode, VT) &&
-        !TLI.isOperationLegal(Opcode, SrcVT))
+        !Subtarget.hasDQI())
       return TruncateArithmetic(Src.getOperand(0), Src.getOperand(1));
     LLVM_FALLTHROUGH;
   case ISD::ADD: {
diff --git a/llvm/test/CodeGen/X86/combine-pmuldq.ll b/llvm/test/CodeGen/X86/combine-pmuldq.ll
index 0c7b8d6..ebfe0d5 100644
--- a/llvm/test/CodeGen/X86/combine-pmuldq.ll
+++ b/llvm/test/CodeGen/X86/combine-pmuldq.ll
@@ -90,7 +90,7 @@
 ; AVX512DQVL-NEXT:    vpxor %xmm2, %xmm2, %xmm2
 ; AVX512DQVL-NEXT:    vpblendd {{.*#+}} xmm0 = xmm0[0],xmm2[1],xmm0[2],xmm2[3]
 ; AVX512DQVL-NEXT:    vpblendd {{.*#+}} xmm1 = xmm1[0],xmm2[1],xmm1[2],xmm2[3]
-; AVX512DQVL-NEXT:    vpmullq %xmm1, %xmm0, %xmm0
+; AVX512DQVL-NEXT:    vpmuludq %xmm1, %xmm0, %xmm0
 ; AVX512DQVL-NEXT:    retq
   %1 = shufflevector <4 x i32> %a0, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
   %2 = shufflevector <4 x i32> %a1, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
@@ -134,7 +134,7 @@
 ; AVX512DQVL-NEXT:    vpxor %xmm2, %xmm2, %xmm2
 ; AVX512DQVL-NEXT:    vpblendd {{.*#+}} ymm0 = ymm0[0],ymm2[1],ymm0[2],ymm2[3],ymm0[4],ymm2[5],ymm0[6],ymm2[7]
 ; AVX512DQVL-NEXT:    vpblendd {{.*#+}} ymm1 = ymm1[0],ymm2[1],ymm1[2],ymm2[3],ymm1[4],ymm2[5],ymm1[6],ymm2[7]
-; AVX512DQVL-NEXT:    vpmullq %ymm1, %ymm0, %ymm0
+; AVX512DQVL-NEXT:    vpmuludq %ymm1, %ymm0, %ymm0
 ; AVX512DQVL-NEXT:    retq
   %1 = shufflevector <8 x i32> %a0, <8 x i32> zeroinitializer, <8 x i32> <i32 0, i32 9, i32 2, i32 11, i32 4, i32 13, i32 6, i32 15>
   %2 = shufflevector <8 x i32> %a1, <8 x i32> zeroinitializer, <8 x i32> <i32 0, i32 9, i32 2, i32 11, i32 4, i32 13, i32 6, i32 15>