[mips][msa] Implemented build_vector using ldi, fill, and custom SelectionDAG nodes (VSPLAT and VSPLATD)
Note: There's a later patch on my branch that re-implements this to select
build_vector without the custom SelectionDAG nodes. The future patch avoids
the constant-folding problems stemming from the custom node (i.e. it doesn't
need to re-implement all the DAG combines related to BUILD_VECTOR).
Changes to MIPS specific SelectionDAG nodes:
* Added VSPLAT
This is a special case of BUILD_VECTOR that covers the case where the
BUILD_VECTOR is a splat operation (all elements are the same).
* Added VSPLATD
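This is a special case of VSPLAT that handles constant v2i64 splats whose
value fits in a signed 10-bit immediate, so they can be selected even when
v2i64 is legal but i64 is not (e.g. on MIPS32).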
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@191191 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/Mips/MipsISelLowering.cpp b/lib/Target/Mips/MipsISelLowering.cpp
index 220955e..21c5edb 100644
--- a/lib/Target/Mips/MipsISelLowering.cpp
+++ b/lib/Target/Mips/MipsISelLowering.cpp
@@ -212,6 +212,8 @@
case MipsISD::VANY_ZERO: return "MipsISD::VANY_ZERO";
case MipsISD::VALL_NONZERO: return "MipsISD::VALL_NONZERO";
case MipsISD::VANY_NONZERO: return "MipsISD::VANY_NONZERO";
+ case MipsISD::VSPLAT: return "MipsISD::VSPLAT";
+ case MipsISD::VSPLATD: return "MipsISD::VSPLATD";
default: return NULL;
}
}
diff --git a/lib/Target/Mips/MipsISelLowering.h b/lib/Target/Mips/MipsISelLowering.h
index 85aa162..57b5603 100644
--- a/lib/Target/Mips/MipsISelLowering.h
+++ b/lib/Target/Mips/MipsISelLowering.h
@@ -152,12 +152,18 @@
SETCC_DSP,
SELECT_CC_DSP,
- // Vector comparisons
+ // Vector comparisons.
VALL_ZERO,
VANY_ZERO,
VALL_NONZERO,
VANY_NONZERO,
+ // Special case of BUILD_VECTOR where all elements are the same.
+ VSPLAT,
+ // Special case of VSPLAT where the result is v2i64 and the operand is an
+ // i32 constant that fits in a signed 10-bit value.
+ VSPLATD,
+
// Load/Store Left/Right nodes.
LWL = ISD::FIRST_TARGET_MEMORY_OPCODE,
LWR,
diff --git a/lib/Target/Mips/MipsMSAInstrInfo.td b/lib/Target/Mips/MipsMSAInstrInfo.td
index d4dcbd1..68b835e 100644
--- a/lib/Target/Mips/MipsMSAInstrInfo.td
+++ b/lib/Target/Mips/MipsMSAInstrInfo.td
@@ -11,12 +11,20 @@
//
//===----------------------------------------------------------------------===//
+// Splat type profile: one vector result produced from one integer scalar
+// operand.
+def SDT_MipsSplat : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisInt<1>]>;
def SDT_MipsVecCond : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>;
def MipsVAllNonZero : SDNode<"MipsISD::VALL_NONZERO", SDT_MipsVecCond>;
def MipsVAnyNonZero : SDNode<"MipsISD::VANY_NONZERO", SDT_MipsVecCond>;
def MipsVAllZero : SDNode<"MipsISD::VALL_ZERO", SDT_MipsVecCond>;
def MipsVAnyZero : SDNode<"MipsISD::VANY_ZERO", SDT_MipsVecCond>;
+// MipsISD::VSPLAT: BUILD_VECTOR in which every element is the same value.
+def MipsVSplat : SDNode<"MipsISD::VSPLAT", SDT_MipsSplat>;
+// MipsISD::VSPLATD: v2i64 splat of a small constant given as an i32 operand.
+def MipsVSplatD : SDNode<"MipsISD::VSPLATD", SDT_MipsSplat>;
+
+// Splat fragments used by the fill.df / ldi.df patterns. Every variant takes
+// an i32 scalar; the 8- and 16-bit element variants therefore expect the
+// lowering code to have sign-extended the value to 32 bits.
+def vsplati8 : PatFrag<(ops node:$in), (v16i8 (MipsVSplat (i32 node:$in)))>;
+def vsplati16 : PatFrag<(ops node:$in), (v8i16 (MipsVSplat (i32 node:$in)))>;
+def vsplati32 : PatFrag<(ops node:$in), (v4i32 (MipsVSplat (i32 node:$in)))>;
+// NOTE: the 64-bit variant matches VSPLATD (not VSPLAT) with an i32 operand,
+// so a v2i64 splat only reaches patterns using vsplati64 (e.g. ldi.d) when it
+// was lowered to VSPLATD.
+def vsplati64 : PatFrag<(ops node:$in), (v2i64 (MipsVSplatD (i32 node:$in)))>;
// Immediates
def immSExt5 : ImmLeaf<i32, [{return isInt<5>(Imm);}]>;
@@ -1383,12 +1391,9 @@
class FFQR_D_DESC : MSA_2RF_DESC_BASE<"ffqr.d", int_mips_ffqr_d,
MSA128D, MSA128W>;
-class FILL_B_DESC : MSA_2R_DESC_BASE<"fill.b", int_mips_fill_b,
- MSA128B, GPR32>;
-class FILL_H_DESC : MSA_2R_DESC_BASE<"fill.h", int_mips_fill_h,
- MSA128H, GPR32>;
-class FILL_W_DESC : MSA_2R_DESC_BASE<"fill.w", int_mips_fill_w,
- MSA128W, GPR32>;
+class FILL_B_DESC : MSA_2R_DESC_BASE<"fill.b", vsplati8, MSA128B, GPR32>;
+class FILL_H_DESC : MSA_2R_DESC_BASE<"fill.h", vsplati16, MSA128H, GPR32>;
+class FILL_W_DESC : MSA_2R_DESC_BASE<"fill.w", vsplati32, MSA128W, GPR32>;
class FLOG2_W_DESC : MSA_2RF_DESC_BASE<"flog2.w", flog2, MSA128W>;
class FLOG2_D_DESC : MSA_2RF_DESC_BASE<"flog2.d", flog2, MSA128D>;
@@ -1573,10 +1578,10 @@
class LD_W_DESC : LD_DESC_BASE<"ld.w", load, v4i32, MSA128W>;
class LD_D_DESC : LD_DESC_BASE<"ld.d", load, v2i64, MSA128D>;
-class LDI_B_DESC : MSA_I10_DESC_BASE<"ldi.b", int_mips_ldi_b, MSA128B>;
-class LDI_H_DESC : MSA_I10_DESC_BASE<"ldi.h", int_mips_ldi_h, MSA128H>;
-class LDI_W_DESC : MSA_I10_DESC_BASE<"ldi.w", int_mips_ldi_w, MSA128W>;
-class LDI_D_DESC : MSA_I10_DESC_BASE<"ldi.d", int_mips_ldi_d, MSA128D>;
+class LDI_B_DESC : MSA_I10_DESC_BASE<"ldi.b", vsplati8, MSA128B>;
+class LDI_H_DESC : MSA_I10_DESC_BASE<"ldi.h", vsplati16, MSA128H>;
+class LDI_W_DESC : MSA_I10_DESC_BASE<"ldi.w", vsplati32, MSA128W>;
+class LDI_D_DESC : MSA_I10_DESC_BASE<"ldi.d", vsplati64, MSA128D>;
class LDX_DESC_BASE<string instr_asm, SDPatternOperator OpNode,
ValueType TyNode, RegisterClass RCWD,
@@ -2356,6 +2361,7 @@
def LDI_B : LDI_B_ENC, LDI_B_DESC;
def LDI_H : LDI_H_ENC, LDI_H_DESC;
def LDI_W : LDI_W_ENC, LDI_W_DESC;
+def LDI_D : LDI_D_ENC, LDI_D_DESC;
def LDX_B: LDX_B_ENC, LDX_B_DESC;
def LDX_H: LDX_H_ENC, LDX_H_DESC;
diff --git a/lib/Target/Mips/MipsSEISelLowering.cpp b/lib/Target/Mips/MipsSEISelLowering.cpp
index 879df6d..3b446c5 100644
--- a/lib/Target/Mips/MipsSEISelLowering.cpp
+++ b/lib/Target/Mips/MipsSEISelLowering.cpp
@@ -147,6 +147,7 @@
return new MipsSETargetLowering(TM);
}
+// Enable MSA support for the given integer type and Register class.
void MipsSETargetLowering::
addMSAIntType(MVT::SimpleValueType Ty, const TargetRegisterClass *RC) {
addRegisterClass(Ty, RC);
@@ -158,6 +159,7 @@
setOperationAction(ISD::BITCAST, Ty, Legal);
setOperationAction(ISD::LOAD, Ty, Legal);
setOperationAction(ISD::STORE, Ty, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, Ty, Custom);
setOperationAction(ISD::ADD, Ty, Legal);
setOperationAction(ISD::CTLZ, Ty, Legal);
@@ -170,6 +172,7 @@
setOperationAction(ISD::UDIV, Ty, Legal);
}
+// Enable MSA support for the given floating-point type and Register class.
void MipsSETargetLowering::
addMSAFloatType(MVT::SimpleValueType Ty, const TargetRegisterClass *RC) {
addRegisterClass(Ty, RC);
@@ -224,6 +227,7 @@
case ISD::INTRINSIC_WO_CHAIN: return lowerINTRINSIC_WO_CHAIN(Op, DAG);
case ISD::INTRINSIC_W_CHAIN: return lowerINTRINSIC_W_CHAIN(Op, DAG);
case ISD::INTRINSIC_VOID: return lowerINTRINSIC_VOID(Op, DAG);
+ case ISD::BUILD_VECTOR: return lowerBUILD_VECTOR(Op, DAG);
}
return MipsTargetLowering::LowerOperation(Op, DAG);
@@ -921,6 +925,10 @@
case Intrinsic::mips_fdiv_w:
case Intrinsic::mips_fdiv_d:
return lowerMSABinaryIntr(Op, DAG, ISD::FDIV);
+ case Intrinsic::mips_fill_b:
+ case Intrinsic::mips_fill_h:
+ case Intrinsic::mips_fill_w:
+ return lowerMSAUnaryIntr(Op, DAG, MipsISD::VSPLAT);
case Intrinsic::mips_flog2_w:
case Intrinsic::mips_flog2_d:
return lowerMSAUnaryIntr(Op, DAG, ISD::FLOG2);
@@ -936,6 +944,11 @@
case Intrinsic::mips_fsub_w:
case Intrinsic::mips_fsub_d:
return lowerMSABinaryIntr(Op, DAG, ISD::FSUB);
+ case Intrinsic::mips_ldi_b:
+ case Intrinsic::mips_ldi_h:
+ case Intrinsic::mips_ldi_w:
+ case Intrinsic::mips_ldi_d:
+ return lowerMSAUnaryIntr(Op, DAG, MipsISD::VSPLAT);
case Intrinsic::mips_mulv_b:
case Intrinsic::mips_mulv_h:
case Intrinsic::mips_mulv_w:
@@ -1073,6 +1086,102 @@
}
}
+/// \brief Check if the given BuildVectorSDNode is a splat, i.e. whether every
+/// operand is the same SDValue as operand 0.
+///
+/// Operands are compared by node identity, so this relies on equivalent DAG
+/// nodes being reused. It is therefore possible for this to return false even
+/// when isConstantSplat returns true.
+///
+/// \pre N produces a vector and has at least two operands (asserted).
+static bool isSplatVector(const BuildVectorSDNode *N) {
+  // Query the type inside the assert so release builds do no extra work.
+  assert(N->getValueType(0).isVector() && "Expected a vector type");
+
+  unsigned NumOps = N->getNumOperands();
+  assert(NumOps > 1 && "isSplat has 0 or 1 sized build vector");
+
+  SDValue Operand0 = N->getOperand(0);
+
+  // A single operand that differs from the first one rules out a splat.
+  for (unsigned I = 1; I != NumOps; ++I)
+    if (N->getOperand(I) != Operand0)
+      return false;
+
+  return true;
+}
+
+// Lowers ISD::BUILD_VECTOR into appropriate SelectionDAG nodes for the
+// backend.
+//
+// Lowers according to the following rules:
+// - Vectors of 128-bits may be legal subject to the other rules. Other sizes
+//   are not legal.
+// - Non-constant splats are legal and are lowered to MipsISD::VSPLAT.
+// - Constant splats with an element size of 32-bits or less are legal and are
+//   lowered to MipsISD::VSPLAT with an i32 operand.
+// - Constant splats with an element size of 64-bits whose value fits within a
+//   signed 10-bit immediate are legal and are lowered to MipsISD::VSPLATD with
+//   an i32 operand. This is done on MIPS64 too, because the vsplati64 pattern
+//   used by ldi.d only matches VSPLATD with an i32 operand.
+// - Other constant splats with an element size of 64-bits are lowered to
+//   MipsISD::VSPLAT with an i64 operand on MIPS64, and are not legal on
+//   MIPS32 (where i64 is an illegal type).
+// - All other ISD::BUILD_VECTORs are not legal.
+//
+// An empty SDValue() is returned for every case that is not handled here.
+SDValue MipsSETargetLowering::lowerBUILD_VECTOR(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
+  EVT ResTy = Op->getValueType(0);
+  SDLoc DL(Op);
+  APInt SplatValue, SplatUndef;
+  unsigned SplatBitSize;
+  bool HasAnyUndefs;
+
+  if (!Subtarget->hasMSA() || !ResTy.is128BitVector())
+    return SDValue();
+
+  if (!Node->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                             HasAnyUndefs, 8, !Subtarget->isLittle())) {
+    // Not a constant splat. A splat of a non-constant value can still be
+    // expressed as a single VSPLAT of the shared operand.
+    if (isSplatVector(Node))
+      return DAG.getNode(MipsISD::VSPLAT, DL, ResTy, Op->getOperand(0));
+    return SDValue();
+  }
+
+  EVT TmpVecTy;
+  EVT ConstTy = MVT::i32;
+  unsigned SplatOp = MipsISD::VSPLAT;
+
+  switch (SplatBitSize) {
+  default:
+    return SDValue();
+  case 64:
+    TmpVecTy = MVT::v2i64;
+
+    // If the constant fits into a signed 10-bit value, splat it with VSPLATD
+    // and an i32 constant. This works on MIPS32, where i64 is an illegal
+    // type, and is also the form the ldi.d pattern (vsplati64) matches on
+    // MIPS64. Larger constants need a real i64 operand, which only MIPS64 has.
+    if (isInt<10>(SplatValue.getSExtValue())) {
+      SplatValue = SplatValue.trunc(32);
+      SplatOp = MipsISD::VSPLATD;
+    } else if (HasMips64) {
+      ConstTy = MVT::i64;
+    } else {
+      return SDValue();
+    }
+    break;
+  case 32:
+    TmpVecTy = MVT::v4i32;
+    break;
+  case 16:
+    // The splat operand is always an i32 (see vsplati16), so widen it.
+    TmpVecTy = MVT::v8i16;
+    SplatValue = SplatValue.sext(32);
+    break;
+  case 8:
+    // The splat operand is always an i32 (see vsplati8), so widen it.
+    TmpVecTy = MVT::v16i8;
+    SplatValue = SplatValue.sext(32);
+    break;
+  }
+
+  SDValue Result = DAG.getNode(SplatOp, DL, TmpVecTy,
+                               DAG.getConstant(SplatValue, ConstTy));
+
+  // SplatBitSize need not match the element size of ResTy, so the splat may
+  // have been built with a different vector type; reinterpret its bits.
+  if (ResTy != Result.getValueType())
+    Result = DAG.getNode(ISD::BITCAST, DL, ResTy, Result);
+
+  return Result;
+}
+
MachineBasicBlock * MipsSETargetLowering::
emitBPOSGE32(MachineInstr *MI, MachineBasicBlock *BB) const{
// $bb:
diff --git a/lib/Target/Mips/MipsSEISelLowering.h b/lib/Target/Mips/MipsSEISelLowering.h
index 016d4ad..909ab7d 100644
--- a/lib/Target/Mips/MipsSEISelLowering.h
+++ b/lib/Target/Mips/MipsSEISelLowering.h
@@ -22,7 +22,11 @@
public:
explicit MipsSETargetLowering(MipsTargetMachine &TM);
+ /// \brief Enable MSA support for the given integer type and Register
+ /// class.
void addMSAIntType(MVT::SimpleValueType Ty, const TargetRegisterClass *RC);
+ /// \brief Enable MSA support for the given floating-point type and
+ /// Register class.
void addMSAFloatType(MVT::SimpleValueType Ty,
const TargetRegisterClass *RC);
@@ -69,6 +73,7 @@
SDValue lowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerINTRINSIC_VOID(SDValue Op, SelectionDAG &DAG) const;
+ SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
MachineBasicBlock *emitBPOSGE32(MachineInstr *MI,
MachineBasicBlock *BB) const;