Add lowering for AVX2 shift instructions.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144380 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 93f7de8..e77b1df 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1050,21 +1050,9 @@
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Legal);
       setOperationAction(ISD::MUL,             MVT::v16i16, Legal);
+      // Don't lower v32i8 MUL because there is no 128-bit byte mul
 
       setOperationAction(ISD::VSELECT,         MVT::v32i8, Legal);
-
-      setOperationAction(ISD::SHL,         MVT::v4i32, Legal);
-      setOperationAction(ISD::SHL,         MVT::v2i64, Legal);
-      setOperationAction(ISD::SRL,         MVT::v4i32, Legal);
-      setOperationAction(ISD::SRL,         MVT::v2i64, Legal);
-      setOperationAction(ISD::SRA,         MVT::v4i32, Legal);
-
-      setOperationAction(ISD::SHL,         MVT::v8i32, Legal);
-      setOperationAction(ISD::SHL,         MVT::v4i64, Legal);
-      setOperationAction(ISD::SRL,         MVT::v8i32, Legal);
-      setOperationAction(ISD::SRL,         MVT::v4i64, Legal);
-      setOperationAction(ISD::SRA,         MVT::v8i32, Legal);
-      // Don't lower v32i8 because there is no 128-bit byte mul
     } else {
       setOperationAction(ISD::ADD,             MVT::v4i64, Custom);
       setOperationAction(ISD::ADD,             MVT::v8i32, Custom);
@@ -10130,47 +10118,6 @@
   if (!Subtarget->hasXMMInt())
     return SDValue();
 
-  // Decompose 256-bit shifts into smaller 128-bit shifts.
-  if (VT.getSizeInBits() == 256) {
-    int NumElems = VT.getVectorNumElements();
-    MVT EltVT = VT.getVectorElementType().getSimpleVT();
-    EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
-
-    // Extract the two vectors
-    SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
-    SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
-                                     DAG, dl);
-
-    // Recreate the shift amount vectors
-    SDValue Amt1, Amt2;
-    if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
-      // Constant shift amount
-      SmallVector<SDValue, 4> Amt1Csts;
-      SmallVector<SDValue, 4> Amt2Csts;
-      for (int i = 0; i < NumElems/2; ++i)
-        Amt1Csts.push_back(Amt->getOperand(i));
-      for (int i = NumElems/2; i < NumElems; ++i)
-        Amt2Csts.push_back(Amt->getOperand(i));
-
-      Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
-                                 &Amt1Csts[0], NumElems/2);
-      Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
-                                 &Amt2Csts[0], NumElems/2);
-    } else {
-      // Variable shift amount
-      Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
-      Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
-                                 DAG, dl);
-    }
-
-    // Issue new vector shifts for the smaller types
-    V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
-    V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);
-
-    // Concatenate the result back
-    return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
-  }
-
   // Optimize shl/srl/sra with constant shift amount.
   if (isSplatVector(Amt.getNode())) {
     SDValue SclrAmt = Amt->getOperand(0);
@@ -10259,9 +10206,97 @@
         Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
         return Res;
       }
+
+      if (Subtarget->hasAVX2()) {
+        if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRA)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRA)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+        }
     }
   }
 
+  // AVX2 variable shifts: lower directly to the psllv/psrlv/psrav intrinsics.
+  if (Subtarget->hasAVX2()) {
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
+                    R, Amt);
+
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
+                     R, Amt);
+
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
+                     R, Amt);
+  }
+
   // Lower SHL with variable shift amount.
   if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
     Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
@@ -10328,6 +10363,48 @@
                     R, DAG.getNode(ISD::ADD, dl, VT, R, R));
     return R;
   }
+
+  // Decompose 256-bit shifts into smaller 128-bit shifts.
+  if (VT.getSizeInBits() == 256) {
+    int NumElems = VT.getVectorNumElements();
+    MVT EltVT = VT.getVectorElementType().getSimpleVT();
+    EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+
+    // Extract the two vectors
+    SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
+    SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
+                                     DAG, dl);
+
+    // Recreate the shift amount vectors
+    SDValue Amt1, Amt2;
+    if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
+      // Constant shift amount
+      SmallVector<SDValue, 4> Amt1Csts;
+      SmallVector<SDValue, 4> Amt2Csts;
+      for (int i = 0; i < NumElems/2; ++i)
+        Amt1Csts.push_back(Amt->getOperand(i));
+      for (int i = NumElems/2; i < NumElems; ++i)
+        Amt2Csts.push_back(Amt->getOperand(i));
+
+      Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
+                                 &Amt1Csts[0], NumElems/2);
+      Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
+                                 &Amt2Csts[0], NumElems/2);
+    } else {
+      // Variable shift amount
+      Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
+      Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
+                                 DAG, dl);
+    }
+
+    // Issue new vector shifts for the smaller types
+    V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
+    V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);
+
+    // Concatenate the result back
+    return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
+  }
+
   return SDValue();
 }