Break 256-bit vector int add/sub/mul into two 128-bit operations to avoid costly scalarization. Fixes PR10711.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@138427 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 88566c6..a1f0732 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -998,6 +998,21 @@
setOperationAction(ISD::SELECT, MVT::v4i64, Custom);
setOperationAction(ISD::SELECT, MVT::v8f32, Custom);
+ setOperationAction(ISD::ADD, MVT::v4i64, Custom);
+ setOperationAction(ISD::ADD, MVT::v8i32, Custom);
+ setOperationAction(ISD::ADD, MVT::v16i16, Custom);
+ setOperationAction(ISD::ADD, MVT::v32i8, Custom);
+
+ setOperationAction(ISD::SUB, MVT::v4i64, Custom);
+ setOperationAction(ISD::SUB, MVT::v8i32, Custom);
+ setOperationAction(ISD::SUB, MVT::v16i16, Custom);
+ setOperationAction(ISD::SUB, MVT::v32i8, Custom);
+
+ setOperationAction(ISD::MUL, MVT::v4i64, Custom);
+ setOperationAction(ISD::MUL, MVT::v8i32, Custom);
+ setOperationAction(ISD::MUL, MVT::v16i16, Custom);
+ // Don't lower v32i8 MUL: there is no 128-bit byte multiply to split into.
+
// Custom lower several nodes for 256-bit types.
for (unsigned i = (unsigned)MVT::FIRST_VECTOR_VALUETYPE;
i <= (unsigned)MVT::LAST_VECTOR_VALUETYPE; ++i) {
@@ -9422,8 +9437,58 @@
return Op;
}
-SDValue X86TargetLowering::LowerMUL_V2I64(SDValue Op, SelectionDAG &DAG) const {
+// Lower256IntArith - Break a 256-bit integer operation into two 128-bit
+// operations, then concatenate the two results back into a 256-bit vector.
+static SDValue Lower256IntArith(SDValue Op, SelectionDAG &DAG) {
EVT VT = Op.getValueType();
+
+ assert(VT.getSizeInBits() == 256 && VT.isInteger() &&
+ "Unsupported value type for operation");
+
+ int NumElems = VT.getVectorNumElements();
+ DebugLoc dl = Op.getDebugLoc();
+ SDValue Idx0 = DAG.getConstant(0, MVT::i32);
+ SDValue Idx1 = DAG.getConstant(NumElems/2, MVT::i32);
+
+ // Extract the LHS vectors
+ SDValue LHS = Op.getOperand(0);
+ SDValue LHS1 = Extract128BitVector(LHS, Idx0, DAG, dl);
+ SDValue LHS2 = Extract128BitVector(LHS, Idx1, DAG, dl);
+
+ // Extract the RHS vectors
+ SDValue RHS = Op.getOperand(1);
+ SDValue RHS1 = Extract128BitVector(RHS, Idx0, DAG, dl);
+ SDValue RHS2 = Extract128BitVector(RHS, Idx1, DAG, dl);
+
+ MVT EltVT = VT.getVectorElementType().getSimpleVT();
+ EVT NewVT = MVT::getVectorVT(EltVT, NumElems/2);
+
+ return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
+ DAG.getNode(Op.getOpcode(), dl, NewVT, LHS1, RHS1),
+ DAG.getNode(Op.getOpcode(), dl, NewVT, LHS2, RHS2));
+}
+
+SDValue X86TargetLowering::LowerADD(SDValue Op, SelectionDAG &DAG) const {
+ assert(Op.getValueType().getSizeInBits() == 256 &&
+ Op.getValueType().isInteger() &&
+ "Only handle AVX 256-bit vector integer operation");
+ return Lower256IntArith(Op, DAG);
+}
+
+SDValue X86TargetLowering::LowerSUB(SDValue Op, SelectionDAG &DAG) const {
+ assert(Op.getValueType().getSizeInBits() == 256 &&
+ Op.getValueType().isInteger() &&
+ "Only handle AVX 256-bit vector integer operation");
+ return Lower256IntArith(Op, DAG);
+}
+
+SDValue X86TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+
+ // Decompose 256-bit ops into smaller 128-bit ops.
+ if (VT.getSizeInBits() == 256)
+ return Lower256IntArith(Op, DAG);
+
assert(VT == MVT::v2i64 && "Only know how to lower V2I64 multiply");
DebugLoc dl = Op.getDebugLoc();
@@ -10013,7 +10078,7 @@
case ISD::FLT_ROUNDS_: return LowerFLT_ROUNDS_(Op, DAG);
case ISD::CTLZ: return LowerCTLZ(Op, DAG);
case ISD::CTTZ: return LowerCTTZ(Op, DAG);
- case ISD::MUL: return LowerMUL_V2I64(Op, DAG);
+ case ISD::MUL: return LowerMUL(Op, DAG);
case ISD::SRA:
case ISD::SRL:
case ISD::SHL: return LowerShift(Op, DAG);
@@ -10029,6 +10094,8 @@
case ISD::ADDE:
case ISD::SUBC:
case ISD::SUBE: return LowerADDC_ADDE_SUBC_SUBE(Op, DAG);
+ case ISD::ADD: return LowerADD(Op, DAG);
+ case ISD::SUB: return LowerSUB(Op, DAG);
}
}