[X86][AVX512BW] Vectorize v64i8 vector shifts

Differential Revision: https://reviews.llvm.org/D28447

llvm-svn: 291665
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 84c5118..31978ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -21752,14 +21752,26 @@
   }
 
   if (VT == MVT::v16i8 ||
-      (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP())) {
+      (VT == MVT::v32i8 && Subtarget.hasInt256() && !Subtarget.hasXOP()) ||
+      (VT == MVT::v64i8 && Subtarget.hasBWI())) {
     MVT ExtVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2);
     unsigned ShiftOpcode = Op->getOpcode();
 
     auto SignBitSelect = [&](MVT SelVT, SDValue Sel, SDValue V0, SDValue V1) {
-      // On SSE41 targets we make use of the fact that VSELECT lowers
-      // to PBLENDVB which selects bytes based just on the sign bit.
-      if (Subtarget.hasSSE41()) {
+      if (VT.is512BitVector()) {
+        // On AVX512BW targets we make use of the fact that VSELECT lowers
+        // to a masked blend which selects bytes based just on the sign bit
+        // extracted to a mask.
+        MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorNumElements());
+        V0 = DAG.getBitcast(VT, V0);
+        V1 = DAG.getBitcast(VT, V1);
+        Sel = DAG.getBitcast(VT, Sel);
+        Sel = DAG.getNode(X86ISD::CVT2MASK, dl, MaskVT, Sel);
+        return DAG.getBitcast(SelVT,
+                              DAG.getNode(ISD::VSELECT, dl, VT, Sel, V0, V1));
+      } else if (Subtarget.hasSSE41()) {
+        // On SSE41 targets we make use of the fact that VSELECT lowers
+        // to PBLENDVB which selects bytes based just on the sign bit.
         V0 = DAG.getBitcast(VT, V0);
         V1 = DAG.getBitcast(VT, V1);
         Sel = DAG.getBitcast(VT, Sel);
diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
index 586786d..5715d82 100644
--- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
+++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp
@@ -323,6 +323,10 @@
     { ISD::SRL,   MVT::v32i16,     1 }, // vpsrlvw
     { ISD::SRA,   MVT::v32i16,     1 }, // vpsravw
 
+    { ISD::SHL,   MVT::v64i8,     11 }, // sign-bit-to-mask + masked blend sequence.
+    { ISD::SRL,   MVT::v64i8,     11 }, // sign-bit-to-mask + masked blend sequence.
+    { ISD::SRA,   MVT::v64i8,     24 }, // sign-bit-to-mask + masked blend sequence.
+
     { ISD::MUL,   MVT::v64i8,     11 }, // extend/pmullw/trunc sequence.
     { ISD::MUL,   MVT::v32i8,      4 }, // extend/pmullw/trunc sequence.
     { ISD::MUL,   MVT::v16i8,      4 }, // extend/pmullw/trunc sequence.