[X86] Teach X86 codegen to use vector width preference to avoid promoting to 512-bit types when VLX is enabled and the preference is for a smaller size.

This change applies to places where we would turn 128/256-bit code into 512-bit in order to get a wider element type through sext/zext. Any 512-bit types that already existed in the IR/DAG will be left that way.

The width preference has no effect on codegen when the target does not have AVX512 enabled, so AVX/AVX2 codegen cannot yet be limited via this mechanism.

If the preference is lower than 256, we may still use a 256-bit type to perform the operation. Constraining to 128 bits would make it much more difficult to support some operations; many of these cases need to change element width while keeping the element count constant, which is most easily done by switching between 256-bit and 128-bit types.
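
For example (taken from the new SplitAndExtendv16i1 helper added below), a
v16i1 mask can be zero extended to v16i8 without touching 512-bit types by
going through 256 bits:

  SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
                           DAG.getIntPtrConstant(0, dl));
  SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
                           DAG.getIntPtrConstant(8, dl));
  Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v8i16, Lo);    // 128 bits
  Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::v8i16, Hi);    // 128 bits
  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i16, Lo, Hi);
  return DAG.getNode(ISD::TRUNCATE, dl, MVT::v16i8, Res);    // 128-bit result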

The preference is only obeyed when AVX512 and VLX are available. This means the preference is not obeyed for KNL, but is obeyed for SKX, Cannonlake, and Icelake. On KNL, the only way to do masked operations is with 512-bit registers, so we would have to completely disable masking to obey the preference. We would also lose support for gather, scatter, ctlz, vXi64 multiplies, etc. This may change in the future, but it simplifies the initial implementation.
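
The preference itself is read from the subtarget's PreferVectorWidth value
(see getPreferVectorWidth below). Assuming that value continues to be set
from the "prefer-vector-width" IR function attribute, a function opts in
with, e.g.:

  attributes #0 = { "prefer-vector-width"="256" }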

Differential Revision: https://reviews.llvm.org/D41895

llvm-svn: 323016
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 074374b7..f4049bf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -14323,10 +14323,15 @@
     ExtVT = Subtarget.hasVLX() ? MVT::v8i32 : MVT::v8i64;
     break;
   case MVT::v16i1:
-    ExtVT = MVT::v16i32;
+    // Take 512-bit type, unless we are avoiding 512-bit types and have the
+    // 256-bit operation available.
+    ExtVT = Subtarget.canExtendTo512DQ() ? MVT::v16i32 : MVT::v16i16;
     break;
   case MVT::v32i1:
-    ExtVT = MVT::v32i16;
+    // Take 512-bit type, unless we are avoiding 512-bit types and have the
+    // 256-bit operation available.
+    assert(Subtarget.hasBWI() && "Expected AVX512BW support");
+    ExtVT = Subtarget.canExtendTo512BW() ? MVT::v32i16 : MVT::v32i8;
     break;
   case MVT::v64i1:
     ExtVT = MVT::v64i8;
@@ -16361,6 +16366,20 @@
   return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, OpLo, OpHi);
 }
 
+// Helper to split and extend a v16i1 mask to v16i8 or v16i16.
+static SDValue SplitAndExtendv16i1(unsigned ExtOpc, MVT VT, SDValue In,
+                                   const SDLoc &dl, SelectionDAG &DAG) {
+  assert((VT == MVT::v16i8 || VT == MVT::v16i16) && "Unexpected VT.");
+  SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
+                           DAG.getIntPtrConstant(0, dl));
+  SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v8i1, In,
+                           DAG.getIntPtrConstant(8, dl));
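+  // Extend the halves to v8i16 and concatenate to a 256-bit v16i16; the
+  // final truncate produces the v16i8 result (and is a no-op for v16i16).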
+  Lo = DAG.getNode(ExtOpc, dl, MVT::v8i16, Lo);
+  Hi = DAG.getNode(ExtOpc, dl, MVT::v8i16, Hi);
+  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i16, Lo, Hi);
+  return DAG.getNode(ISD::TRUNCATE, dl, VT, Res);
+}
+
 static  SDValue LowerZERO_EXTEND_Mask(SDValue Op,
                                       const X86Subtarget &Subtarget,
                                       SelectionDAG &DAG) {
@@ -16374,8 +16393,13 @@
  // Extend VT if the scalar type is i8/i16 and BWI is not supported.
   MVT ExtVT = VT;
   if (!Subtarget.hasBWI() &&
-      (VT.getVectorElementType().getSizeInBits() <= 16))
+      (VT.getVectorElementType().getSizeInBits() <= 16)) {
+    // If v16i32 is to be avoided, we'll need to split and concatenate.
+    if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
+      return SplitAndExtendv16i1(ISD::ZERO_EXTEND, VT, In, DL, DAG);
+
     ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
+  }
 
   // Widen to 512-bits if VLX is not supported.
   MVT WideVT = ExtVT;
@@ -16549,6 +16573,33 @@
     assert((InVT.is256BitVector() || InVT.is128BitVector()) &&
            "Unexpected vector type.");
     unsigned NumElts = InVT.getVectorNumElements();
+    assert((NumElts == 8 || NumElts == 16) && "Unexpected number of elements");
+    // We need to change to a wider element type that we have support for.
+    // For 8 element vectors this is easy: we either extend to v8i32 or v8i64.
+    // For 16 element vectors we extend to v16i32, unless we are explicitly
+    // trying to avoid 512-bit vectors. In that case we need to split into two
+    // 8 element vectors which we can extend to v8i32, truncate, and concat
+    // the results. There's an additional complication if the original type is
+    // v16i8: we can't split a v16i8 (v8i8 is not a legal type), so we first
+    // pre-extend it to v16i16, which we can split to v8i16, then extend to
+    // v8i32, truncate that to v8i1, and concat the two halves.
+    if (NumElts == 16 && !Subtarget.canExtendTo512DQ()) {
+      if (InVT == MVT::v16i8) {
+        // First we need to sign extend up to 256-bits so we can split that.
+        InVT = MVT::v16i16;
+        In = DAG.getNode(ISD::SIGN_EXTEND, DL, InVT, In);
+      }
+      SDValue Lo = extract128BitVector(In, 0, DAG, DL);
+      SDValue Hi = extract128BitVector(In, 8, DAG, DL);
+      // We're split now, just emit two truncates and a concat. The two
+      // truncates will trigger legalization to come back to this function.
+      Lo = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Lo);
+      Hi = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i1, Hi);
+      return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+    }
+    // We either have 8 elements or we're allowed to use 512-bit vectors.
+    // If we have VLX, we want to use the narrowest vector that can get the
+    // job done so we use vXi32.
     MVT EltVT = Subtarget.hasVLX() ? MVT::i32 : MVT::getIntegerVT(512/NumElts);
     MVT ExtVT = MVT::getVectorVT(EltVT, NumElts);
     In = DAG.getNode(ISD::SIGN_EXTEND, DL, ExtVT, In);
@@ -16580,10 +16631,15 @@
   // vpmovqb/w/d, vpmovdb/w, vpmovwb
   if (Subtarget.hasAVX512()) {
     // word to byte only under BWI
-    if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) // v16i16 -> v16i8
-      return DAG.getNode(ISD::TRUNCATE, DL, VT,
-                         getExtendInVec(X86ISD::VSEXT, DL, MVT::v16i32, In, DAG));
-    return Op;
+    if (InVT == MVT::v16i16 && !Subtarget.hasBWI()) { // v16i16 -> v16i8
+      // Make sure we're allowed to promote to 512 bits.
+      if (Subtarget.canExtendTo512DQ())
+        return DAG.getNode(ISD::TRUNCATE, DL, VT,
+                           getExtendInVec(X86ISD::VSEXT, DL, MVT::v16i32, In,
+                                          DAG));
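+      // Otherwise fall through to the non-512-bit lowerings below.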
+    } else {
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, In);
+    }
   }
 
   // Truncate with PACKSS if we are truncating a vector with sign-bits that
@@ -18561,8 +18617,13 @@
 
  // Extend VT if the scalar type is i8/i16 and BWI is not supported.
   MVT ExtVT = VT;
-  if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16)
+  if (!Subtarget.hasBWI() && VTElt.getSizeInBits() <= 16) {
+    // If v16i32 is to be avoided, we'll need to split and concatenate.
+    if (NumElts == 16 && !Subtarget.canExtendTo512DQ())
+      return SplitAndExtendv16i1(ISD::SIGN_EXTEND, VT, In, dl, DAG);
+
     ExtVT = MVT::getVectorVT(MVT::i32, NumElts);
+  }
 
   // Widen to 512-bits if VLX is not supported.
   MVT WideVT = ExtVT;
@@ -21845,7 +21906,8 @@
 // ( sub(trunc(lzcnt(zext32(x)))) ). In case zext32(x) is illegal,
// split the vector, perform the operation on its Lo and Hi parts, and
 // concatenate the results.
-static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG) {
+static SDValue LowerVectorCTLZ_AVX512CDI(SDValue Op, SelectionDAG &DAG,
+                                         const X86Subtarget &Subtarget) {
   assert(Op.getOpcode() == ISD::CTLZ);
   SDLoc dl(Op);
   MVT VT = Op.getSimpleValueType();
@@ -21856,7 +21918,8 @@
           "Unsupported element type");
 
  // Split the vector; its Lo and Hi parts are handled in the next iteration.
-  if (16 < NumElems)
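+  // Also split 16-element vectors if the v16i32 promotion below would create
+  // a 512-bit type we are trying to avoid.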
+  if (NumElems > 16 ||
+      (NumElems == 16 && !Subtarget.canExtendTo512DQ()))
     return LowerVectorIntUnary(Op, DAG);
 
   MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
@@ -21961,8 +22024,10 @@
                                SelectionDAG &DAG) {
   MVT VT = Op.getSimpleValueType();
 
-  if (Subtarget.hasCDI())
-    return LowerVectorCTLZ_AVX512CDI(Op, DAG);
+  if (Subtarget.hasCDI() &&
+      // vXi8 vectors need to be promoted to vXi32, which requires 512 bits.
+      (Subtarget.canExtendTo512DQ() || VT.getVectorElementType() != MVT::i8))
+    return LowerVectorCTLZ_AVX512CDI(Op, DAG, Subtarget);
 
   // Decompose 256-bit ops into smaller 128-bit ops.
   if (VT.is256BitVector() && !Subtarget.hasInt256())
@@ -22378,7 +22443,7 @@
     SDValue Hi = DAG.getIntPtrConstant(NumElems / 2, dl);
 
     if (VT == MVT::v32i8) {
-      if (Subtarget.hasBWI()) {
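+      // Extending v32i8 to v32i16 creates a 512-bit type, so check the width
+      // preference via canExtendTo512BW rather than just hasBWI.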
+      if (Subtarget.canExtendTo512BW()) {
         SDValue ExA = DAG.getNode(ExAVX, dl, MVT::v32i16, A);
         SDValue ExB = DAG.getNode(ExAVX, dl, MVT::v32i16, B);
         SDValue Mul = DAG.getNode(ISD::MUL, dl, MVT::v32i16, ExA, ExB);
@@ -23167,10 +23232,12 @@
   // It's worth extending once and using the vXi16/vXi32 shifts for smaller
   // types, but without AVX512 the extra overheads to get from vXi8 to vXi32
   // make the existing SSE solution better.
+  // NOTE: We honor the preferred vector width before promoting to 512 bits.
   if ((Subtarget.hasInt256() && VT == MVT::v8i16) ||
-      (Subtarget.hasAVX512() && VT == MVT::v16i16) ||
-      (Subtarget.hasAVX512() && VT == MVT::v16i8) ||
-      (Subtarget.hasBWI() && VT == MVT::v32i8)) {
+      (Subtarget.canExtendTo512DQ() && VT == MVT::v16i16) ||
+      (Subtarget.canExtendTo512DQ() && VT == MVT::v16i8) ||
+      (Subtarget.canExtendTo512BW() && VT == MVT::v32i8) ||
+      (Subtarget.hasBWI() && Subtarget.hasVLX() && VT == MVT::v16i8)) {
     assert((!Subtarget.hasBWI() || VT == MVT::v32i8 || VT == MVT::v16i8) &&
            "Unexpected vector type");
     MVT EvtSVT = Subtarget.hasBWI() ? MVT::i16 : MVT::i32;
@@ -23995,7 +24062,7 @@
     unsigned NumElems = VT.getVectorNumElements();
     assert((VT.getVectorElementType() == MVT::i8 ||
             VT.getVectorElementType() == MVT::i16) && "Unexpected type");
-    if (NumElems <= 16) {
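+    // For 16 elements the vXi32 promotion below creates a 512-bit vector, so
+    // only take this path when 512-bit types are allowed.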
+    if (NumElems < 16 || (NumElems == 16 && Subtarget.canExtendTo512DQ())) {
       MVT NewVT = MVT::getVectorVT(MVT::i32, NumElems);
       Op = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, Op0);
       Op = DAG.getNode(ISD::CTPOP, DL, NewVT, Op);
diff --git a/llvm/lib/Target/X86/X86Subtarget.h b/llvm/lib/Target/X86/X86Subtarget.h
index 08cc28e..a5db980 100644
--- a/llvm/lib/Target/X86/X86Subtarget.h
+++ b/llvm/lib/Target/X86/X86Subtarget.h
@@ -597,6 +597,17 @@
 
   unsigned getPreferVectorWidth() const { return PreferVectorWidth; }
 
+  // Helper functions to determine when we should allow widening to 512-bit
+  // during codegen.
+  // TODO: Currently we're always allowing widening on CPUs without VLX,
+  // because for many cases we don't have a better option.
+  bool canExtendTo512DQ() const {
+    return hasAVX512() && (!hasVLX() || getPreferVectorWidth() >= 512);
+  }
+  bool canExtendTo512BW() const {
+    return hasBWI() && canExtendTo512DQ();
+  }
+
   bool isXRaySupported() const override { return is64Bit(); }
 
   X86ProcFamilyEnum getProcFamily() const { return X86ProcFamily; }