Add more AVX2 shift lowering support. Move AVX2 variable shifts to use TableGen patterns instead of custom lowering code.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144457 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index e77b1df..f1c80a2 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -924,10 +924,6 @@
     // FIXME: Do we need to handle scalar-to-vector here?
     setOperationAction(ISD::MUL,                MVT::v4i32, Legal);
 
-    // Can turn SHL into an integer multiply.
-    setOperationAction(ISD::SHL,                MVT::v4i32, Custom);
-    setOperationAction(ISD::SHL,                MVT::v16i8, Custom);
-
     setOperationAction(ISD::VSELECT,            MVT::v2f64, Legal);
     setOperationAction(ISD::VSELECT,            MVT::v2i64, Legal);
     setOperationAction(ISD::VSELECT,            MVT::v16i8, Legal);
@@ -955,18 +951,32 @@
   }
 
   if (Subtarget->hasXMMInt()) {
-    setOperationAction(ISD::SRL,               MVT::v2i64, Custom);
-    setOperationAction(ISD::SRL,               MVT::v4i32, Custom);
-    setOperationAction(ISD::SRL,               MVT::v16i8, Custom);
     setOperationAction(ISD::SRL,               MVT::v8i16, Custom);
+    setOperationAction(ISD::SRL,               MVT::v16i8, Custom);
 
-    setOperationAction(ISD::SHL,               MVT::v2i64, Custom);
-    setOperationAction(ISD::SHL,               MVT::v4i32, Custom);
     setOperationAction(ISD::SHL,               MVT::v8i16, Custom);
+    setOperationAction(ISD::SHL,               MVT::v16i8, Custom);
 
-    setOperationAction(ISD::SRA,               MVT::v4i32, Custom);
     setOperationAction(ISD::SRA,               MVT::v8i16, Custom);
     setOperationAction(ISD::SRA,               MVT::v16i8, Custom);
+
+    if (Subtarget->hasAVX2()) {
+      setOperationAction(ISD::SRL,             MVT::v2i64, Legal);
+      setOperationAction(ISD::SRL,             MVT::v4i32, Legal);
+
+      setOperationAction(ISD::SHL,             MVT::v2i64, Legal);
+      setOperationAction(ISD::SHL,             MVT::v4i32, Legal);
+
+      setOperationAction(ISD::SRA,             MVT::v4i32, Legal);
+    } else {
+      setOperationAction(ISD::SRL,             MVT::v2i64, Custom);
+      setOperationAction(ISD::SRL,             MVT::v4i32, Custom);
+
+      setOperationAction(ISD::SHL,             MVT::v2i64, Custom);
+      setOperationAction(ISD::SHL,             MVT::v4i32, Custom);
+
+      setOperationAction(ISD::SRA,             MVT::v4i32, Custom);
+    }
   }
 
   if (Subtarget->hasSSE42() || Subtarget->hasAVX())
@@ -1009,18 +1019,14 @@
     setOperationAction(ISD::CONCAT_VECTORS,     MVT::v32i8,  Custom);
     setOperationAction(ISD::CONCAT_VECTORS,     MVT::v16i16, Custom);
 
-    setOperationAction(ISD::SRL,               MVT::v4i64, Custom);
-    setOperationAction(ISD::SRL,               MVT::v8i32, Custom);
     setOperationAction(ISD::SRL,               MVT::v16i16, Custom);
     setOperationAction(ISD::SRL,               MVT::v32i8, Custom);
 
-    setOperationAction(ISD::SHL,               MVT::v4i64, Custom);
-    setOperationAction(ISD::SHL,               MVT::v8i32, Custom);
     setOperationAction(ISD::SHL,               MVT::v16i16, Custom);
     setOperationAction(ISD::SHL,               MVT::v32i8, Custom);
 
-    setOperationAction(ISD::SRA,               MVT::v8i32, Custom);
     setOperationAction(ISD::SRA,               MVT::v16i16, Custom);
+    setOperationAction(ISD::SRA,               MVT::v32i8, Custom);
 
     setOperationAction(ISD::SETCC,             MVT::v32i8, Custom);
     setOperationAction(ISD::SETCC,             MVT::v16i16, Custom);
@@ -1053,6 +1059,14 @@
       // Don't lower v32i8 because there is no 128-bit byte mul
 
       setOperationAction(ISD::VSELECT,         MVT::v32i8, Legal);
+
+      setOperationAction(ISD::SRL,             MVT::v4i64, Legal);
+      setOperationAction(ISD::SRL,             MVT::v8i32, Legal);
+
+      setOperationAction(ISD::SHL,             MVT::v4i64, Legal);
+      setOperationAction(ISD::SHL,             MVT::v8i32, Legal);
+
+      setOperationAction(ISD::SRA,             MVT::v8i32, Legal);
     } else {
       setOperationAction(ISD::ADD,             MVT::v4i64, Custom);
       setOperationAction(ISD::ADD,             MVT::v8i32, Custom);
@@ -1068,6 +1082,14 @@
       setOperationAction(ISD::MUL,             MVT::v8i32, Custom);
       setOperationAction(ISD::MUL,             MVT::v16i16, Custom);
       // Don't lower v32i8 because there is no 128-bit byte mul
+
+      setOperationAction(ISD::SRL,             MVT::v4i64, Custom);
+      setOperationAction(ISD::SRL,             MVT::v8i32, Custom);
+
+      setOperationAction(ISD::SHL,             MVT::v4i64, Custom);
+      setOperationAction(ISD::SHL,             MVT::v8i32, Custom);
+
+      setOperationAction(ISD::SRA,             MVT::v8i32, Custom);
     }
 
     // Custom lower several nodes for 256-bit types.
@@ -9510,6 +9532,14 @@
 
   // Fix vector shift instructions where the last operand is a non-immediate
   // i32 value.
+  case Intrinsic::x86_avx2_pslli_w:
+  case Intrinsic::x86_avx2_pslli_d:
+  case Intrinsic::x86_avx2_pslli_q:
+  case Intrinsic::x86_avx2_psrli_w:
+  case Intrinsic::x86_avx2_psrli_d:
+  case Intrinsic::x86_avx2_psrli_q:
+  case Intrinsic::x86_avx2_psrai_w:
+  case Intrinsic::x86_avx2_psrai_d:
   case Intrinsic::x86_sse2_pslli_w:
   case Intrinsic::x86_sse2_pslli_d:
   case Intrinsic::x86_sse2_pslli_q:
@@ -9557,6 +9587,30 @@
     case Intrinsic::x86_sse2_psrai_d:
       NewIntNo = Intrinsic::x86_sse2_psra_d;
       break;
+    case Intrinsic::x86_avx2_pslli_w:
+      NewIntNo = Intrinsic::x86_avx2_psll_w;
+      break;
+    case Intrinsic::x86_avx2_pslli_d:
+      NewIntNo = Intrinsic::x86_avx2_psll_d;
+      break;
+    case Intrinsic::x86_avx2_pslli_q:
+      NewIntNo = Intrinsic::x86_avx2_psll_q;
+      break;
+    case Intrinsic::x86_avx2_psrli_w:
+      NewIntNo = Intrinsic::x86_avx2_psrl_w;
+      break;
+    case Intrinsic::x86_avx2_psrli_d:
+      NewIntNo = Intrinsic::x86_avx2_psrl_d;
+      break;
+    case Intrinsic::x86_avx2_psrli_q:
+      NewIntNo = Intrinsic::x86_avx2_psrl_q;
+      break;
+    case Intrinsic::x86_avx2_psrai_w:
+      NewIntNo = Intrinsic::x86_avx2_psra_w;
+      break;
+    case Intrinsic::x86_avx2_psrai_d:
+      NewIntNo = Intrinsic::x86_avx2_psra_d;
+      break;
     default: {
       ShAmtVT = MVT::v2i32;
       switch (IntNo) {
@@ -10251,52 +10305,6 @@
     }
   }
 
-  // AVX2 variable shifts
-  if (Subtarget->hasAVX2()) {
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
-                    R, Amt);
-
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
-                     R, Amt);
-
-    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
-                     R, Amt);
-    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
-       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
-                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
-                     R, Amt);
-  }
-
   // Lower SHL with variable shift amount.
   if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
     Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
@@ -13464,7 +13472,9 @@
   if (!Subtarget->hasXMMInt())
     return SDValue();
 
-  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16)
+  if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
+      (!Subtarget->hasAVX2() ||
+       (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
     return SDValue();
 
   SDValue ShAmtOp = N->getOperand(1);
@@ -13537,6 +13547,18 @@
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_pslli_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v4i64)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   case ISD::SRA:
     if (VT == MVT::v4i32)
@@ -13547,6 +13569,14 @@
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_psrai_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   case ISD::SRL:
     if (VT == MVT::v2i64)
@@ -13561,6 +13591,18 @@
       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
                          DAG.getConstant(Intrinsic::x86_sse2_psrli_w, MVT::i32),
                          ValOp, BaseShAmt);
+    if (VT == MVT::v4i64)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT == MVT::v8i32)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
+                         ValOp, BaseShAmt);
+    if (VT ==  MVT::v16i16)
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                         DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
+                         ValOp, BaseShAmt);
     break;
   }
   return SDValue();
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index 10f527c..735a30f 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -7713,3 +7713,52 @@
 defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", int_x86_avx2_psrav_d,
                               int_x86_avx2_psrav_d_256>;
 
+let Predicates = [HasAVX2] in { // shl/srl/sra by a per-element vector amount
+  def : Pat<(v4i32 (shl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSLLVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
+            (VPSLLVQrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v4i32 (srl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSRLVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
+            (VPSRLVQrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v4i32 (sra (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
+            (VPSRAVDrr VR128:$src1, VR128:$src2)>;
+  def : Pat<(v8i32 (shl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSLLVDYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
+            (VPSLLVQYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v8i32 (srl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSRLVDYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
+            (VPSRLVQYrr VR256:$src1, VR256:$src2)>;
+  def : Pat<(v8i32 (sra (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
+            (VPSRAVDYrr VR256:$src1, VR256:$src2)>;
+  // Same shifts with the shift-amount vector ($src2) folded from memory.
+  def : Pat<(v4i32 (shl (v4i32 VR128:$src1),
+                    (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSLLVDrm VR128:$src1,  addr:$src2)>;
+  def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (memopv2i64 addr:$src2))),
+            (VPSLLVQrm VR128:$src1,  addr:$src2)>;
+  def : Pat<(v4i32 (srl (v4i32 VR128:$src1),
+                    (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSRLVDrm VR128:$src1,  addr:$src2)>;
+  def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (memopv2i64 addr:$src2))),
+            (VPSRLVQrm VR128:$src1,  addr:$src2)>;
+  def : Pat<(v4i32 (sra (v4i32 VR128:$src1),
+                    (v4i32 (bitconvert (memopv2i64 addr:$src2))))),
+            (VPSRAVDrm VR128:$src1,  addr:$src2)>;
+  def : Pat<(v8i32 (shl (v8i32 VR256:$src1),
+                    (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSLLVDYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (memopv4i64 addr:$src2))),
+            (VPSLLVQYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v8i32 (srl (v8i32 VR256:$src1),
+                    (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSRLVDYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (memopv4i64 addr:$src2))),
+            (VPSRLVQYrm VR256:$src1, addr:$src2)>;
+  def : Pat<(v8i32 (sra (v8i32 VR256:$src1),
+                    (v8i32 (bitconvert (memopv4i64 addr:$src2))))),
+            (VPSRAVDYrm VR256:$src1, addr:$src2)>;
+} // let Predicates = [HasAVX2]