AVX512BW: Support LLVM masked vector load/store intrinsics for i8/i16 element types on SKX

Differential Revision: http://reviews.llvm.org/D17913

llvm-svn: 262803
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 7baea19..4608df1 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1687,6 +1687,12 @@
     if (Subtarget.hasVLX())
       setTruncStoreAction(MVT::v8i16,   MVT::v8i8,  Legal);
 
+    LegalizeAction Action = Subtarget.hasVLX() ? Legal : Custom;
+    for (auto VT : { MVT::v32i8, MVT::v16i8, MVT::v16i16, MVT::v8i16 }) {
+      setOperationAction(ISD::MLOAD,               VT, Action);
+      setOperationAction(ISD::MSTORE,              VT, Action);
+    }
+
     if (Subtarget.hasCDI()) {
       setOperationAction(ISD::CTLZ,            MVT::v32i16, Custom);
       setOperationAction(ISD::CTLZ,            MVT::v64i8,  Custom);
@@ -1700,6 +1706,8 @@
       setOperationAction(ISD::SRL,                 VT, Custom);
       setOperationAction(ISD::SHL,                 VT, Custom);
       setOperationAction(ISD::SRA,                 VT, Custom);
+      setOperationAction(ISD::MLOAD,               VT, Legal);
+      setOperationAction(ISD::MSTORE,              VT, Legal);
 
       setOperationAction(ISD::AND,    VT, Promote);
       AddPromotedToType (ISD::AND,    VT, MVT::v8i64);
@@ -20786,31 +20794,36 @@
 
   MaskedLoadSDNode *N = cast<MaskedLoadSDNode>(Op.getNode());
   MVT VT = Op.getSimpleValueType();
+  MVT ScalarVT = VT.getScalarType();
   SDValue Mask = N->getMask();
   SDLoc dl(Op);
 
-  if (Subtarget.hasAVX512() && !Subtarget.hasVLX() &&
-      !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
-    // This operation is legal for targets with VLX, but without
-    // VLX the vector should be widened to 512 bit
-    unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
-    MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
-    MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
-    SDValue Src0 = N->getSrc0();
-    Src0 = ExtendToType(Src0, WideDataVT, DAG);
-    Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
-    SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
-                                        N->getBasePtr(), Mask, Src0,
-                                        N->getMemoryVT(), N->getMemOperand(),
-                                        N->getExtensionType());
+  assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
+         "Cannot lower masked load op.");
 
-    SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
-                                 NewLoad.getValue(0),
-                                 DAG.getIntPtrConstant(0, dl));
-    SDValue RetOps[] = {Exract, NewLoad.getValue(1)};
-    return DAG.getMergeValues(RetOps, dl);
-  }
-  return Op;
+  assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) ||
+          (Subtarget.hasBWI() &&
+              (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
+         "Unsupported masked load op.");
+
+  // This operation is legal for targets with VLX, but without
+  // VLX the vector must be widened to 512 bits.
+  unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+  MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
+  MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+  SDValue Src0 = N->getSrc0();
+  Src0 = ExtendToType(Src0, WideDataVT, DAG);
+  Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+  SDValue NewLoad = DAG.getMaskedLoad(WideDataVT, dl, N->getChain(),
+                                      N->getBasePtr(), Mask, Src0,
+                                      N->getMemoryVT(), N->getMemOperand(),
+                                      N->getExtensionType());
+
+  SDValue Exract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT,
+                               NewLoad.getValue(0),
+                               DAG.getIntPtrConstant(0, dl));
+  SDValue RetOps[] = {Exract, NewLoad.getValue(1)};
+  return DAG.getMergeValues(RetOps, dl);
 }
 
 static SDValue LowerMSTORE(SDValue Op, const X86Subtarget &Subtarget,
@@ -20818,23 +20831,28 @@
   MaskedStoreSDNode *N = cast<MaskedStoreSDNode>(Op.getNode());
   SDValue DataToStore = N->getValue();
   MVT VT = DataToStore.getSimpleValueType();
+  MVT ScalarVT = VT.getScalarType();
   SDValue Mask = N->getMask();
   SDLoc dl(Op);
 
-  if (Subtarget.hasAVX512() && !Subtarget.hasVLX() &&
-      !VT.is512BitVector() && Mask.getValueType() == MVT::v8i1) {
-    // This operation is legal for targets with VLX, but without
-    // VLX the vector should be widened to 512 bit
-    unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
-    MVT WideDataVT = MVT::getVectorVT(VT.getScalarType(), NumEltsInWideVec);
-    MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
-    DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
-    Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
-    return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
-                              Mask, N->getMemoryVT(), N->getMemOperand(),
-                              N->isTruncatingStore());
-  }
-  return Op;
+  assert(Subtarget.hasAVX512() && !Subtarget.hasVLX() && !VT.is512BitVector() &&
+         "Cannot lower masked store op.");
+
+  assert(((ScalarVT == MVT::i32 || ScalarVT == MVT::f32) ||
+          (Subtarget.hasBWI() &&
+              (ScalarVT == MVT::i8 || ScalarVT == MVT::i16))) &&
+          "Unsupported masked store op.");
+
+  // This operation is legal for targets with VLX, but without
+  // VLX the vector must be widened to 512 bits.
+  unsigned NumEltsInWideVec = 512/VT.getScalarSizeInBits();
+  MVT WideDataVT = MVT::getVectorVT(ScalarVT, NumEltsInWideVec);
+  MVT WideMaskVT = MVT::getVectorVT(MVT::i1, NumEltsInWideVec);
+  DataToStore = ExtendToType(DataToStore, WideDataVT, DAG);
+  Mask = ExtendToType(Mask, WideMaskVT, DAG, true);
+  return DAG.getMaskedStore(N->getChain(), dl, DataToStore, N->getBasePtr(),
+                            Mask, N->getMemoryVT(), N->getMemOperand(),
+                            N->isTruncatingStore());
 }
 
 static SDValue LowerMGATHER(SDValue Op, const X86Subtarget &Subtarget,
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 397a0f2..efa7feb 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -1438,7 +1438,8 @@
   int DataWidth = isa<PointerType>(ScalarTy) ?
     DL.getPointerSizeInBits() : ScalarTy->getPrimitiveSizeInBits();
 
-  return (DataWidth >= 32 && ST->hasAVX());
+  return (DataWidth >= 32 && ST->hasAVX()) ||
+         (DataWidth >= 8 && ST->hasBWI());
 }
 
 bool X86TTIImpl::isLegalMaskedStore(Type *DataType) {