AVX512BW: Support LLVM masked vector load/store intrinsics for i8/i16 element types on SKX
Differential Revision: http://reviews.llvm.org/D17913
llvm-svn: 262803
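
For reference, the per-lane semantics that the llvm.masked.load / llvm.masked.store
intrinsics provide for byte elements, written as a minimal standalone C++ sketch
(the helper names and the fixed lane count are illustrative only, not LLVM code):
lanes whose mask bit is clear never touch memory, and a masked-off load lane takes
the pass-through value instead.

    #include <array>
    #include <cstdint>

    // Reference semantics of a 16-lane i8 masked load (illustrative sketch).
    std::array<int8_t, 16> maskedLoadV16I8(const int8_t *Ptr,
                                           const std::array<bool, 16> &Mask,
                                           const std::array<int8_t, 16> &PassThru) {
      std::array<int8_t, 16> Result;
      for (unsigned I = 0; I != 16; ++I)
        Result[I] = Mask[I] ? Ptr[I] : PassThru[I]; // disabled lane: no memory read
      return Result;
    }

    // Reference semantics of a 16-lane i8 masked store (illustrative sketch).
    void maskedStoreV16I8(int8_t *Ptr, const std::array<bool, 16> &Mask,
                          const std::array<int8_t, 16> &Data) {
      for (unsigned I = 0; I != 16; ++I)
        if (Mask[I])
          Ptr[I] = Data[I]; // disabled lane: memory left untouched
    }
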
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7baea19..4608df1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1687,6 +1687,12 @@
if (Subtarget.hasVLX())
setTruncStoreAction(MVT::v8i16, MVT::v8i8, Legal);
+ LegalizeAction Action = Subtarget.hasVLX() ? Legal : Custom;
+ for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) {
+ setOperationAction(ISD::MLOAD, VT, Action);
+ setOperationAction(ISD::MSTORE, VT, Action);
+ }
+
if (Subtarget.hasCDI()) {
setOperationAction(ISD::CTLZ, MVT::v32i16, Custom);
setOperationAction(ISD::CTLZ, MVT::v64i8, Custom);
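
The hunk above chooses between Legal and Custom for the 128/256-bit byte/word
vector types: with VLX the masked load/store maps directly onto an instruction,
while without VLX it is sent to the custom lowering below, which widens the
operation to 512 bits. A compressed, illustrative-only restatement of that
decision (enum and function invented for the example, not LLVM API):

    enum LegalizeActionSketch { LegalSketch, CustomSketch };

    // Classification of MLOAD/MSTORE for v32i8/v16i8/v16i16/v8i16 on an
    // AVX512BW target, as set up in the hunk above (sketch only).
    LegalizeActionSketch maskedByteWordAction(bool HasVLX) {
      return HasVLX ? LegalSketch    // native 128/256-bit masked load/store
                    : CustomSketch;  // widened to 512 bits in LowerMLOAD/LowerMSTORE
    }
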
@@ -1700,6 +1706,8 @@
setOperationAction(ISD::SRL, VT, Custom);
setOperationAction(ISD::SHL, VT, Custom);
setOperationAction(ISD::SRA, VT, Custom);
+ setOperationAction(ISD::MLOAD, VT, Legal);
+ setOperationAction(ISD::MSTORE, VT, Legal);
setOperationAction(ISD::AND, VT, Promote);
AddPromotedToType (ISD::AND, VT, MVT::v8i64);
@@ -20786,31 +20794,36 @@
MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
MVT VT = Op.getSimpleValueType();
+ MVT ScalarVT = VT.getScalarType();
SDValue Mask = N->getMask();
SDLoc dl(Op);
- if (Subtarget.hasAVX512() && !Subtarget.hasVLX() &&
- !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
- // This operation is legal for targets with VLX, but without
- // VLX the vector should be widened to 512 bit
- unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
- MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
- MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
- SDValue Src0 = N->getSrc0();
- Src0 = ExtendToType(Src0, WideDataVT, DAG);
- Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
- SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
- N->getBasePtr(), Mask, Src0,
- N->getMemoryVT(), N->getMemOperand(),
- N->getExtensionType());
+ assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
+ "Cannot lower masked load op.");
- SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
- NewLoad.getValue(0),
- DAG.getIntPtrConstant(0, dl));
- SDValue RetOps[] = {Exract, NewLoad.getValue(1)};
- return DAG.getMergeValues(RetOps, dl);
- }
- return Op;
+ assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) ||
+ (Subtarget.hasBWI() &&
+ (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
+ "Unsupported masked load op.");
+
+ // This operation is legal for targets with VLX, but without
+ // VLX the vector should be widened to 512 bits
+ unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+ MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
+ MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+ SDValue Src0 = N->getSrc0();
+ Src0 = ExtendToType(Src0, WideDataVT, DAG);
+ Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+ SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
+ N->getBasePtr(), Mask, Src0,
+ N->getMemoryVT(), N->getMemOperand(),
+ N->getExtensionType());
+
+ SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
+ NewLoad.getValue(0),
+ DAG.getIntPtrConstant(0, dl));
+ SDValue RetOps[] = {Extract, NewLoad.getValue(1)};
+ return DAG.getMergeValues(RetOps, dl);
}
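
The rewritten LowerMLOAD no longer falls back to returning Op; it is reached only
for the cases asserted above and always performs the widening. A standalone C++
sketch of what that widening computes for a v16i8 load on AVX512BW without VLX
(all names invented; plain scalar code stands in for the DAG nodes):

    #include <array>
    #include <cstdint>

    std::array<int8_t, 16> widenedMaskedLoadV16I8(const int8_t *Ptr,
                                                  const std::array<bool, 16> &Mask,
                                                  const std::array<int8_t, 16> &PassThru) {
      // Mirrors ExtendToType in the patch: the mask is zero-extended to 64
      // lanes and the pass-through is padded out to 64 lanes (zeroes here;
      // the real code pads with undef).
      std::array<bool, 64> WideMask{};   // lanes 16..63 stay false
      std::array<int8_t, 64> WidePass{};
      for (unsigned I = 0; I != 16; ++I) {
        WideMask[I] = Mask[I];
        WidePass[I] = PassThru[I];
      }
      // The wide v64i8 masked load: disabled lanes never read memory, so the
      // zero padding of the mask makes the widening safe.
      std::array<int8_t, 64> Wide;
      for (unsigned I = 0; I != 64; ++I)
        Wide[I] = WideMask[I] ? Ptr[I] : WidePass[I];
      // EXTRACT_SUBVECTOR at index 0: keep only the original 16 lanes.
      std::array<int8_t, 16> Result;
      for (unsigned I = 0; I != 16; ++I)
        Result[I] = Wide[I];
      return Result;
    }
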
static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
@@ -20818,23 +20831,28 @@
MaskedStoreSDNode *N = cast<MaskedStoreSDNode>(Op.getNode());
SDValue DataToStore = N->getValue();
MVT VT = DataToStore.getSimpleValueType();
+ MVT ScalarVT = VT.getScalarType();
SDValue Mask = N->getMask();
SDLoc dl(Op);
- if (Subtarget.hasAVX512() && !Subtarget.hasVLX() &&
- !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
- // This operation is legal for targets with VLX, but without
- // VLX the vector should be widened to 512 bit
- unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
- MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
- MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
- DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
- Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
- return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
- Mask, N->getMemoryVT(), N->getMemOperand(),
- N->isTruncatingStore());
- }
- return Op;
+ assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
+ "Cannot lower masked store op.");
+
+ assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) ||
+ (Subtarget.hasBWI() &&
+ (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
+ "Unsupported masked store op.");
+
+ // This operation is legal for targets with VLX, but without
+ // VLX the vector should be widened to 512 bits
+ unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+ MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
+ MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+ DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
+ Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+ return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
+ Mask, N->getMemoryVT(), N->getMemOperand(),
+ N->isTruncatingStore());
}
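
The store path mirrors the load path: the data and the mask are widened to 512
bits (64 byte lanes in this example), and because the padded mask lanes are zero
the wide masked store writes exactly the bytes that the original 16-lane store
would have written. A plain C++ stand-in (names invented):

    #include <array>
    #include <cstdint>

    void widenedMaskedStoreV16I8(int8_t *Ptr, const std::array<bool, 16> &Mask,
                                 const std::array<int8_t, 16> &Data) {
      std::array<bool, 64> WideMask{};   // lanes 16..63 stay false
      std::array<int8_t, 64> WideData{};
      for (unsigned I = 0; I != 16; ++I) {
        WideMask[I] = Mask[I];
        WideData[I] = Data[I];
      }
      // The wide v64i8 masked store: masked-off lanes leave memory untouched,
      // so only the bytes enabled by the original mask are written.
      for (unsigned I = 0; I != 64; ++I)
        if (WideMask[I])
          Ptr[I] = WideData[I];
    }
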
static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 397a0f2..efa7feb 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1438,7 +1438,8 @@
int DataWidth = isa<PointerType>(ScalarTy) ?
DL.getPointerSizeInBits() : ScalarTy->getPrimitiveSizeInBits();
- return (DataWidth >= 32 && ST->hasAVX());
+ return (DataWidth >= 32 && ST->hasAVX()) ||
+ (DataWidth >= 8 && ST->hasBWI());
}
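
This predicate is consulted by the vectorizer before it emits llvm.masked.load /
llvm.masked.store calls; with this change, 8- and 16-bit elements qualify once
AVX-512BW is available, in addition to the existing 32-bit-or-wider-with-AVX rule.
A standalone restatement of the check (helper name invented, not LLVM API):

    bool isMaskedMemOpLegalForWidth(unsigned DataWidthInBits, bool HasAVX,
                                    bool HasBWI) {
      return (DataWidthInBits >= 32 && HasAVX) ||
             (DataWidthInBits >= 8 && HasBWI);
    }
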
bool X86TTIImpl::isLegalMaskedStore(Type *DataType) {