Add more AVX2 shift lowering support. Move AVX2 variable shifts to use patterns instead of custom lowering code.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144457 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index e77b1df..f1c80a2 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -924,10 +924,6 @@
// FIXME: Do we need to handle scalar-to-vector here?
setOperationAction(ISD::MUL, MVT::v4i32, Legal);
- // Can turn SHL into an integer multiply.
- setOperationAction(ISD::SHL, MVT::v4i32, Custom);
- setOperationAction(ISD::SHL, MVT::v16i8, Custom);
-
setOperationAction(ISD::VSELECT, MVT::v2f64, Legal);
setOperationAction(ISD::VSELECT, MVT::v2i64, Legal);
setOperationAction(ISD::VSELECT, MVT::v16i8, Legal);
@@ -955,18 +951,32 @@
}
if (Subtarget->hasXMMInt()) {
- setOperationAction(ISD::SRL, MVT::v2i64, Custom);
- setOperationAction(ISD::SRL, MVT::v4i32, Custom);
- setOperationAction(ISD::SRL, MVT::v16i8, Custom);
setOperationAction(ISD::SRL, MVT::v8i16, Custom);
+ setOperationAction(ISD::SRL, MVT::v16i8, Custom);
- setOperationAction(ISD::SHL, MVT::v2i64, Custom);
- setOperationAction(ISD::SHL, MVT::v4i32, Custom);
setOperationAction(ISD::SHL, MVT::v8i16, Custom);
+ setOperationAction(ISD::SHL, MVT::v16i8, Custom);
- setOperationAction(ISD::SRA, MVT::v4i32, Custom);
setOperationAction(ISD::SRA, MVT::v8i16, Custom);
setOperationAction(ISD::SRA, MVT::v16i8, Custom);
+
+ if (Subtarget->hasAVX2()) {
+ setOperationAction(ISD::SRL, MVT::v2i64, Legal);
+ setOperationAction(ISD::SRL, MVT::v4i32, Legal);
+
+ setOperationAction(ISD::SHL, MVT::v2i64, Legal);
+ setOperationAction(ISD::SHL, MVT::v4i32, Legal);
+
+ setOperationAction(ISD::SRA, MVT::v4i32, Legal);
+ } else {
+ setOperationAction(ISD::SRL, MVT::v2i64, Custom);
+ setOperationAction(ISD::SRL, MVT::v4i32, Custom);
+
+ setOperationAction(ISD::SHL, MVT::v2i64, Custom);
+ setOperationAction(ISD::SHL, MVT::v4i32, Custom);
+
+ setOperationAction(ISD::SRA, MVT::v4i32, Custom);
+ }
}
if (Subtarget->hasSSE42() || Subtarget->hasAVX())
@@ -1009,18 +1019,14 @@
setOperationAction(ISD::CONCAT_VECTORS, MVT::v32i8, Custom);
setOperationAction(ISD::CONCAT_VECTORS, MVT::v16i16, Custom);
- setOperationAction(ISD::SRL, MVT::v4i64, Custom);
- setOperationAction(ISD::SRL, MVT::v8i32, Custom);
setOperationAction(ISD::SRL, MVT::v16i16, Custom);
setOperationAction(ISD::SRL, MVT::v32i8, Custom);
- setOperationAction(ISD::SHL, MVT::v4i64, Custom);
- setOperationAction(ISD::SHL, MVT::v8i32, Custom);
setOperationAction(ISD::SHL, MVT::v16i16, Custom);
setOperationAction(ISD::SHL, MVT::v32i8, Custom);
- setOperationAction(ISD::SRA, MVT::v8i32, Custom);
setOperationAction(ISD::SRA, MVT::v16i16, Custom);
+ setOperationAction(ISD::SRA, MVT::v32i8, Custom);
setOperationAction(ISD::SETCC, MVT::v32i8, Custom);
setOperationAction(ISD::SETCC, MVT::v16i16, Custom);
@@ -1053,6 +1059,14 @@
// Don't lower v32i8 because there is no 128-bit byte mul
setOperationAction(ISD::VSELECT, MVT::v32i8, Legal);
+
+ setOperationAction(ISD::SRL, MVT::v4i64, Legal);
+ setOperationAction(ISD::SRL, MVT::v8i32, Legal);
+
+ setOperationAction(ISD::SHL, MVT::v4i64, Legal);
+ setOperationAction(ISD::SHL, MVT::v8i32, Legal);
+
+ setOperationAction(ISD::SRA, MVT::v8i32, Legal);
} else {
setOperationAction(ISD::ADD, MVT::v4i64, Custom);
setOperationAction(ISD::ADD, MVT::v8i32, Custom);
@@ -1068,6 +1082,14 @@
setOperationAction(ISD::MUL, MVT::v8i32, Custom);
setOperationAction(ISD::MUL, MVT::v16i16, Custom);
// Don't lower v32i8 because there is no 128-bit byte mul
+
+ setOperationAction(ISD::SRL, MVT::v4i64, Custom);
+ setOperationAction(ISD::SRL, MVT::v8i32, Custom);
+
+ setOperationAction(ISD::SHL, MVT::v4i64, Custom);
+ setOperationAction(ISD::SHL, MVT::v8i32, Custom);
+
+ setOperationAction(ISD::SRA, MVT::v8i32, Custom);
}
// Custom lower several nodes for 256-bit types.
@@ -9510,6 +9532,14 @@
// Fix vector shift instructions where the last operand is a non-immediate
// i32 value.
+ case Intrinsic::x86_avx2_pslli_w:
+ case Intrinsic::x86_avx2_pslli_d:
+ case Intrinsic::x86_avx2_pslli_q:
+ case Intrinsic::x86_avx2_psrli_w:
+ case Intrinsic::x86_avx2_psrli_d:
+ case Intrinsic::x86_avx2_psrli_q:
+ case Intrinsic::x86_avx2_psrai_w:
+ case Intrinsic::x86_avx2_psrai_d:
case Intrinsic::x86_sse2_pslli_w:
case Intrinsic::x86_sse2_pslli_d:
case Intrinsic::x86_sse2_pslli_q:
@@ -9557,6 +9587,30 @@
case Intrinsic::x86_sse2_psrai_d:
NewIntNo = Intrinsic::x86_sse2_psra_d;
break;
+ case Intrinsic::x86_avx2_pslli_w:
+ NewIntNo = Intrinsic::x86_avx2_psll_w;
+ break;
+ case Intrinsic::x86_avx2_pslli_d:
+ NewIntNo = Intrinsic::x86_avx2_psll_d;
+ break;
+ case Intrinsic::x86_avx2_pslli_q:
+ NewIntNo = Intrinsic::x86_avx2_psll_q;
+ break;
+ case Intrinsic::x86_avx2_psrli_w:
+ NewIntNo = Intrinsic::x86_avx2_psrl_w;
+ break;
+ case Intrinsic::x86_avx2_psrli_d:
+ NewIntNo = Intrinsic::x86_avx2_psrl_d;
+ break;
+ case Intrinsic::x86_avx2_psrli_q:
+ NewIntNo = Intrinsic::x86_avx2_psrl_q;
+ break;
+ case Intrinsic::x86_avx2_psrai_w:
+ NewIntNo = Intrinsic::x86_avx2_psra_w;
+ break;
+ case Intrinsic::x86_avx2_psrai_d:
+ NewIntNo = Intrinsic::x86_avx2_psra_d;
+ break;
default: {
ShAmtVT = MVT::v2i32;
switch (IntNo) {
@@ -10251,52 +10305,6 @@
}
}
- // AVX2 variable shifts
- if (Subtarget->hasAVX2()) {
- if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psllv_d, MVT::i32),
- R, Amt);
- if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SHL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psllv_d_256, MVT::i32),
- R, Amt);
- if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SHL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psllv_q, MVT::i32),
- R, Amt);
- if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SHL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psllv_q_256, MVT::i32),
- R, Amt);
-
- if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrlv_d, MVT::i32),
- R, Amt);
- if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrlv_d_256, MVT::i32),
- R, Amt);
- if (VT == MVT::v2i64 && Op->getOpcode() == ISD::SRL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrlv_q, MVT::i32),
- R, Amt);
- if (VT == MVT::v4i64 && Op->getOpcode() == ISD::SRL)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrlv_q_256, MVT::i32),
- R, Amt);
-
- if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SRA)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrav_d, MVT::i32),
- R, Amt);
- if (VT == MVT::v8i32 && Op->getOpcode() == ISD::SRA)
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
- DAG.getConstant(Intrinsic::x86_avx2_psrav_d_256, MVT::i32),
- R, Amt);
- }
-
// Lower SHL with variable shift amount.
if (VT == MVT::v4i32 && Op->getOpcode() == ISD::SHL) {
Op = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT,
@@ -13464,7 +13472,9 @@
if (!Subtarget->hasXMMInt())
return SDValue();
- if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16)
+ if (VT != MVT::v2i64 && VT != MVT::v4i32 && VT != MVT::v8i16 &&
+ (!Subtarget->hasAVX2() ||
+ (VT != MVT::v4i64 && VT != MVT::v8i32 && VT != MVT::v16i16)))
return SDValue();
SDValue ShAmtOp = N->getOperand(1);
@@ -13537,6 +13547,18 @@
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
DAG.getConstant(Intrinsic::x86_sse2_pslli_w, MVT::i32),
ValOp, BaseShAmt);
+ if (VT == MVT::v4i64)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_pslli_q, MVT::i32),
+ ValOp, BaseShAmt);
+ if (VT == MVT::v8i32)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_pslli_d, MVT::i32),
+ ValOp, BaseShAmt);
+ if (VT == MVT::v16i16)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_pslli_w, MVT::i32),
+ ValOp, BaseShAmt);
break;
case ISD::SRA:
if (VT == MVT::v4i32)
@@ -13547,6 +13569,14 @@
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
DAG.getConstant(Intrinsic::x86_sse2_psrai_w, MVT::i32),
ValOp, BaseShAmt);
+ if (VT == MVT::v8i32)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_psrai_d, MVT::i32),
+ ValOp, BaseShAmt);
+ if (VT == MVT::v16i16)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_psrai_w, MVT::i32),
+ ValOp, BaseShAmt);
break;
case ISD::SRL:
if (VT == MVT::v2i64)
@@ -13561,6 +13591,18 @@
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
DAG.getConstant(Intrinsic::x86_sse2_psrli_w, MVT::i32),
ValOp, BaseShAmt);
+ if (VT == MVT::v4i64)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_psrli_q, MVT::i32),
+ ValOp, BaseShAmt);
+ if (VT == MVT::v8i32)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_psrli_d, MVT::i32),
+ ValOp, BaseShAmt);
+ if (VT == MVT::v16i16)
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+ DAG.getConstant(Intrinsic::x86_avx2_psrli_w, MVT::i32),
+ ValOp, BaseShAmt);
break;
}
return SDValue();