Add lowering for AVX2 shift instructions.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144380 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 93f7de8..e77b1df 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -1050,21 +1050,9 @@
       setOperationAction(ISD::MUL,             MVT::v4i64, Custom);
       setOperationAction(ISD::MUL,             MVT::v8i32, Legal);
       setOperationAction(ISD::MUL,             MVT::v16i16, Legal);
+      // Don't lower v32i8 because there is no 128-bit byte mul
 
       setOperationAction(ISD::VSELECT,         MVT::v32i8, Legal);
-
-      setOperationAction(ISD::SHL,         MVT::v4i32, Legal);
-      setOperationAction(ISD::SHL,         MVT::v2i64, Legal);
-      setOperationAction(ISD::SRL,         MVT::v4i32, Legal);
-      setOperationAction(ISD::SRL,         MVT::v2i64, Legal);
-      setOperationAction(ISD::SRA,         MVT::v4i32, Legal);
-
-      setOperationAction(ISD::SHL,         MVT::v8i32, Legal);
-      setOperationAction(ISD::SHL,         MVT::v4i64, Legal);
-      setOperationAction(ISD::SRL,         MVT::v8i32, Legal);
-      setOperationAction(ISD::SRL,         MVT::v4i64, Legal);
-      setOperationAction(ISD::SRA,         MVT::v8i32, Legal);
-      // Don't lower v32i8 because there is no 128-bit byte mul
     } else {
       setOperationAction(ISD::ADD,             MVT::v4i64, Custom);
       setOperationAction(ISD::ADD,             MVT::v8i32, Custom);
@@ -10130,47 +10118,6 @@
   if (!Subtarget->hasXMMInt())
     return SDValue();
 
-  // Decompose 256-bit shifts into smaller 128-bit shifts.
-  if (VT.getSizeInBits() == 256) {
-    int NumElems = VT.getVectorNumElements();
-    MVT EltVT = VT.getVectorElementType().getSimpleVT();
-    EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
-
-    // Extract the two vectors
-    SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
-    SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
-                                     DAG, dl);
-
-    // Recreate the shift amount vectors
-    SDValue Amt1, Amt2;
-    if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
-      // Constant shift amount
-      SmallVector<SDValue, 4> Amt1Csts;
-      SmallVector<SDValue, 4> Amt2Csts;
-      for (int i = 0; i < NumElems/2; ++i)
-        Amt1Csts.push_back(Amt->getOperand(i));
-      for (int i = NumElems/2; i < NumElems; ++i)
-        Amt2Csts.push_back(Amt->getOperand(i));
-
-      Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
-                                 &Amt1Csts[0], NumElems/2);
-      Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
-                                 &Amt2Csts[0], NumElems/2);
-    } else {
-      // Variable shift amount
-      Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
-      Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
-                                 DAG, dl);
-    }
-
-    // Issue new vector shifts for the smaller types
-    V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
-    V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);
-
-    // Concatenate the result back
-    return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
-  }
-
   // Optimize shl/srl/sra with constant shift amount.
   if (isSplatVector(Amt.getNode())) {
     SDValue SclrAmt = Amt->getOperand(0);
@@ -10259,9 +10206,97 @@
         Res = DAG.getNode(ISD::SUB, dl, VT, Res, Mask);
         return Res;
       }
+
+      if (Subtarget->hasAVX2()) {
+        if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SHL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v4i64 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRL)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v8i32 && Op.getOpcode() == ISD::SRA)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+
+        if (VT == MVT::v16i16 && Op.getOpcode() == ISD::SRA)
+         return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                       DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
+                       R, DAG.getConstant(ShiftAmt, MVT::i32));
+      }
     }
   }
 
+  // AVX2 variable shifts
+  if (Subtarget->hasAVX2()) {
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
+                     R, Amt);
+
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
+                     R, Amt);
+
+    if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
+                     R, Amt);
+    if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
+       return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
+                     DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
+                     R, Amt);
+  }
+
   // Lower SHL with variable shift amount.
   if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
     Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
@@ -10328,6 +10363,48 @@
                     R, DAG.getNode(ISD::ADD, dl, VT, R, R));
     return R;
   }
+
+  // Decompose 256-bit shifts into smaller 128-bit shifts.
+  if (VT.getSizeInBits() == 256) {
+    int NumElems = VT.getVectorNumElements();
+    MVT EltVT = VT.getVectorElementType().getSimpleVT();
+    EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+
+    // Extract the two vectors
+    SDValue V1 = Extract128BitVector(R, DAG.getConstant(0, MVT::i32), DAG, dl);
+    SDValue V2 = Extract128BitVector(R, DAG.getConstant(NumElems/2, MVT::i32),
+                                     DAG, dl);
+
+    // Recreate the shift amount vectors
+    SDValue Amt1, Amt2;
+    if (Amt.getOpcode() == ISD::BUILD_VECTOR) {
+      // Constant shift amount
+      SmallVector<SDValue, 4> Amt1Csts;
+      SmallVector<SDValue, 4> Amt2Csts;
+      for (int i = 0; i < NumElems/2; ++i)
+        Amt1Csts.push_back(Amt->getOperand(i));
+      for (int i = NumElems/2; i < NumElems; ++i)
+        Amt2Csts.push_back(Amt->getOperand(i));
+
+      Amt1 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
+                                 &Amt1Csts[0], NumElems/2);
+      Amt2 = DAG.getNode(ISD::BUILD_VECTOR, dl, NewVT,
+                                 &Amt2Csts[0], NumElems/2);
+    } else {
+      // Variable shift amount
+      Amt1 = Extract128BitVector(Amt, DAG.getConstant(0, MVT::i32), DAG, dl);
+      Amt2 = Extract128BitVector(Amt, DAG.getConstant(NumElems/2, MVT::i32),
+                                 DAG, dl);
+    }
+
+    // Issue new vector shifts for the smaller types
+    V1 = DAG.getNode(Op.getOpcode(), dl, NewVT, V1, Amt1);
+    V2 = DAG.getNode(Op.getOpcode(), dl, NewVT, V2, Amt2);
+
+    // Concatenate the result back
+    return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, V1, V2);
+  }
+
   return SDValue();
 }
 
diff --git a/lib/Target/X86/X86InstrSSE.td b/lib/Target/X86/X86InstrSSE.td
index 91c84dd..10f527c 100644
--- a/lib/Target/X86/X86InstrSSE.td
+++ b/lib/Target/X86/X86InstrSSE.td
@@ -7655,7 +7655,6 @@
 // Variable Bit Shifts
 //
 multiclass avx2_var_shift<bits<8> opc, string OpcodeStr,
-                          PatFrag pf128, PatFrag pf256,
                           Intrinsic Int128, Intrinsic Int256> {
   def rr  : AVX28I<opc, MRMSrcReg, (outs VR128:$dst),
              (ins VR128:$src1, VR128:$src2),
@@ -7664,7 +7663,8 @@
   def rm  : AVX28I<opc, MRMSrcMem, (outs VR128:$dst),
              (ins VR128:$src1, i128mem:$src2),
              !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
-             [(set VR128:$dst, (Int128 VR128:$src1, (pf128 addr:$src2)))]>,
+             [(set VR128:$dst,
+              (Int128 VR128:$src1, (bitconvert (memopv2i64 addr:$src2))))]>,
              VEX_4V;
   def Yrr : AVX28I<opc, MRMSrcReg, (outs VR256:$dst),
              (ins VR256:$src1, VR256:$src2),
@@ -7673,70 +7673,43 @@
   def Yrm : AVX28I<opc, MRMSrcMem, (outs VR256:$dst),
              (ins VR256:$src1, i256mem:$src2),
              !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
-             [(set VR256:$dst, (Int256 VR256:$src1, (pf256 addr:$src2)))]>,
+             [(set VR256:$dst,
+               (Int256 VR256:$src1, (bitconvert (memopv4i64 addr:$src2))))]>,
              VEX_4V;
 }
 
-defm VPSLLVD : avx2_var_shift<0x47, "vpsllvd", memopv4i32, memopv8i32,
-                              int_x86_avx2_psllv_d, int_x86_avx2_psllv_d_256>;
-defm VPSLLVQ : avx2_var_shift<0x47, "vpsllvq", memopv2i64, memopv4i64,
-                              int_x86_avx2_psllv_q, int_x86_avx2_psllv_q_256>,
-                              VEX_W;
-defm VPSRLVD : avx2_var_shift<0x45, "vpsrlvd", memopv4i32, memopv8i32,
-                              int_x86_avx2_psrlv_d, int_x86_avx2_psrlv_d_256>;
-defm VPSRLVQ : avx2_var_shift<0x45, "vpsrlvq", memopv2i64, memopv4i64,
-                              int_x86_avx2_psrlv_q, int_x86_avx2_psrlv_q_256>,
-                              VEX_W;
-defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", memopv4i32, memopv8i32,
-                              int_x86_avx2_psrav_d, int_x86_avx2_psrav_d_256>;
-
-
-let Predicates = [HasAVX2] in {
-
-  def : Pat<(v4i32 (shl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
-            (VPSLLVDrr VR128:$src1, VR128:$src2)>;
-  def : Pat<(v2i64 (shl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
-            (VPSLLVQrr VR128:$src1, VR128:$src2)>;
-  def : Pat<(v4i32 (srl (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
-            (VPSRLVDrr VR128:$src1, VR128:$src2)>;
-  def : Pat<(v2i64 (srl (v2i64 VR128:$src1), (v2i64 VR128:$src2))),
-            (VPSRLVQrr VR128:$src1, VR128:$src2)>;
-  def : Pat<(v4i32 (sra (v4i32 VR128:$src1), (v4i32 VR128:$src2))),
-            (VPSRAVDrr VR128:$src1, VR128:$src2)>;
-  def : Pat<(v8i32 (shl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
-            (VPSLLVDYrr VR256:$src1, VR256:$src2)>;
-  def : Pat<(v4i64 (shl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
-            (VPSLLVQYrr VR256:$src1, VR256:$src2)>;
-  def : Pat<(v8i32 (srl (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
-            (VPSRLVDYrr VR256:$src1, VR256:$src2)>;
-  def : Pat<(v4i64 (srl (v4i64 VR256:$src1), (v4i64 VR256:$src2))),
-            (VPSRLVQYrr VR256:$src1, VR256:$src2)>;
-  def : Pat<(v8i32 (sra (v8i32 VR256:$src1), (v8i32 VR256:$src2))),
-            (VPSRAVDYrr VR256:$src1, VR256:$src2)>;
-
-  def : Pat<(v4i32 (shl (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
-            (VPSLLVDrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v4i32 (shl (v4i32 VR128:$src1),(loadv2i64 addr:$src2))),
-            (VPSLLVDrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v2i64 (shl (v2i64 VR128:$src1),(loadv2i64 addr:$src2))),
-            (VPSLLVQrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v4i32 (srl (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
-            (VPSRLVDrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v2i64 (srl (v2i64 VR128:$src1),(loadv2i64 addr:$src2))),
-            (VPSRLVQrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v4i32 (sra (v4i32 VR128:$src1),(loadv4i32 addr:$src2))),
-            (VPSRAVDrm VR128:$src1,  addr:$src2)>;
-  def : Pat<(v8i32 (shl (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
-            (VPSLLVDYrm VR256:$src1, addr:$src2)>;
-  def : Pat<(v4i64 (shl (v4i64 VR256:$src1),(loadv4i64 addr:$src2))),
-            (VPSLLVQYrm VR256:$src1, addr:$src2)>;
-  def : Pat<(v8i32 (srl (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
-            (VPSRLVDYrm VR256:$src1, addr:$src2)>;
-  def : Pat<(v4i64 (srl (v4i64 VR256:$src1),(loadv4i64 addr:$src2))),
-            (VPSRLVQYrm VR256:$src1, addr:$src2)>;
-  def : Pat<(v8i32 (sra (v8i32 VR256:$src1),(loadv8i32 addr:$src2))),
-            (VPSRAVDYrm VR256:$src1, addr:$src2)>;
+multiclass avx2_var_shift_i64<bits<8> opc, string OpcodeStr,
+                              Intrinsic Int128, Intrinsic Int256> {
+  def rr  : AVX28I<opc, MRMSrcReg, (outs VR128:$dst),
+             (ins VR128:$src1, VR128:$src2),
+             !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+             [(set VR128:$dst, (Int128 VR128:$src1, VR128:$src2))]>, VEX_4V;
+  def rm  : AVX28I<opc, MRMSrcMem, (outs VR128:$dst),
+             (ins VR128:$src1, i128mem:$src2),
+             !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+             [(set VR128:$dst,
+              (Int128 VR128:$src1, (memopv2i64 addr:$src2)))]>,
+             VEX_4V;
+  def Yrr : AVX28I<opc, MRMSrcReg, (outs VR256:$dst),
+             (ins VR256:$src1, VR256:$src2),
+             !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+             [(set VR256:$dst, (Int256 VR256:$src1, VR256:$src2))]>, VEX_4V;
+  def Yrm : AVX28I<opc, MRMSrcMem, (outs VR256:$dst),
+             (ins VR256:$src1, i256mem:$src2),
+             !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+             [(set VR256:$dst,
+               (Int256 VR256:$src1, (memopv4i64 addr:$src2)))]>,
+             VEX_4V;
 }
 
-
+defm VPSLLVD : avx2_var_shift<0x47, "vpsllvd", int_x86_avx2_psllv_d,
+                              int_x86_avx2_psllv_d_256>;
+defm VPSLLVQ : avx2_var_shift_i64<0x47, "vpsllvq", int_x86_avx2_psllv_q,
+                                  int_x86_avx2_psllv_q_256>, VEX_W;
+defm VPSRLVD : avx2_var_shift<0x45, "vpsrlvd", int_x86_avx2_psrlv_d,
+                              int_x86_avx2_psrlv_d_256>;
+defm VPSRLVQ : avx2_var_shift_i64<0x45, "vpsrlvq", int_x86_avx2_psrlv_q,
+                                  int_x86_avx2_psrlv_q_256>, VEX_W;
+defm VPSRAVD : avx2_var_shift<0x46, "vpsravd", int_x86_avx2_psrav_d,
+                              int_x86_avx2_psrav_d_256>;