[X86] Don't make 512-bit vectors legal when preferred vector width is 256 bits and 512 bits aren't required

This patch adds a new function attribute "required-vector-width" that can be set by the frontend to indicate the maximum vector width present in the original source code. The idea is that this would be set based on ABI requirements, the use of intrinsics or explicit vector types, possibly SIMD pragmas, etc. The backend will then use this information to determine if it's safe to make 512-bit vectors illegal when the preference is for 256-bit vectors.
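
As a quick illustration of how the two attributes interact, here is a minimal standalone sketch. Only useAVX512Regs mirrors the predicate added to X86Subtarget.h in this patch; WidthModel, its defaults, and the body of canExtendTo512DQ are simplified stand-ins for the example:

  #include <cstdint>
  #include <cstdio>

  // Standalone model of the new decision (illustrative only).
  struct WidthModel {
    bool HasAVX512 = true;
    bool HasVLX = true;
    unsigned PreferVectorWidth = 256;          // "prefer-vector-width"
    unsigned RequiredVectorWidth = UINT32_MAX; // "required-vector-width"

    // Roughly: with VLX we can honor a narrower preference; without it we
    // keep 512-bit types legal so there is somewhere to widen to.
    bool canExtendTo512DQ() const {
      return HasAVX512 && (!HasVLX || PreferVectorWidth >= 512);
    }

    // 512-bit register classes stay legal if the preference allows them or
    // the source itself required vectors wider than 256 bits.
    bool useAVX512Regs() const {
      return HasAVX512 && (canExtendTo512DQ() || RequiredVectorWidth > 256);
    }
  };

  int main() {
    WidthModel M;
    M.RequiredVectorWidth = 256;  // widest vector in the source was 256-bit
    std::printf("%d\n", M.useAVX512Regs()); // 0: 512-bit types left illegal
    M.RequiredVectorWidth = 512;  // source used 512-bit types/intrinsics
    std::printf("%d\n", M.useAVX512Regs()); // 1: 512-bit types stay legal
  }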

For code that has no vectors in it originally and only gets vectors through the loop and SLP vectorizers, this allows us to generate code largely similar to our AVX2-only output while still enabling AVX512 features like mask registers and gather/scatter. The loop vectorizer doesn't always obey TTI and will create oversized vectors with the expectation that the backend will legalize them. To avoid changing the vectorizer and potentially harming our AVX2 codegen, this patch tries to make the legalizer's behavior similar.
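
To make that concrete, here is a toy standalone sketch of how an oversized operation ends up as multiple legal-width pieces; countLegalParts and its arguments are invented for this illustration and are not the SelectionDAG implementation:

  #include <cstdio>

  // Toy illustration of type legalization by splitting: a vector operation
  // wider than the widest legal register is recursively halved.
  unsigned countLegalParts(unsigned VectorBits, unsigned MaxLegalBits) {
    unsigned Parts = 1;
    while (VectorBits > MaxLegalBits) {
      VectorBits /= 2; // split the vector in half
      Parts *= 2;
    }
    return Parts;
  }

  int main() {
    // The vectorizer emits a 512-bit <16 x i32> operation.
    // With 512-bit types legal it maps to one zmm op; with the new
    // attribute limiting us to 256 bits it splits into two ymm ops.
    std::printf("512-bit legal: %u part(s)\n", countLegalParts(512, 512));
    std::printf("256-bit limit: %u part(s)\n", countLegalParts(512, 256));
  }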

This is restricted to CPUs that support AVX512F and AVX512VL so that we have good fallback options: we can use 128-bit and 256-bit vectors and still get masking.

I've qualified every place I could find in X86ISelLowering.cpp and added test cases for many of them with two different values for the attribute to see the codegen differences.

We still need to do frontend work for the attribute and teach the inliner how to merge it, etc. But this gets the codegen layer ready for it.
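
As a loose sketch of what that merge could look like (mergeRequiredVectorWidth and the max rule are hypothetical and not part of this patch):

  #include <algorithm>
  #include <cstdio>

  // Hypothetical inliner merge rule: after inlining, the caller must
  // accommodate any vector width the callee required, so one plausible rule
  // is to take the maximum of the two attribute values.
  unsigned mergeRequiredVectorWidth(unsigned CallerWidth, unsigned CalleeWidth) {
    return std::max(CallerWidth, CalleeWidth);
  }

  int main() {
    // Caller required at most 128-bit vectors; callee used 512-bit intrinsics.
    std::printf("%u\n", mergeRequiredVectorWidth(128, 512)); // 512
  }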

Differential Revision: https://reviews.llvm.org/D42724

llvm-svn: 324834
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 88411c1..ac3d7b5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1144,12 +1144,10 @@
     }
   }
 
+  // This block controls legalization of the mask vector sizes that are
+  // available with AVX512. 512-bit vectors are in a separate block controlled
+  // by useAVX512Regs.
   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
-    addRegisterClass(MVT::v16i32, &X86::VR512RegClass);
-    addRegisterClass(MVT::v16f32, &X86::VR512RegClass);
-    addRegisterClass(MVT::v8i64,  &X86::VR512RegClass);
-    addRegisterClass(MVT::v8f64,  &X86::VR512RegClass);
-
     addRegisterClass(MVT::v1i1,   &X86::VK1RegClass);
     addRegisterClass(MVT::v2i1,   &X86::VK2RegClass);
     addRegisterClass(MVT::v4i1,   &X86::VK4RegClass);
@@ -1160,8 +1158,6 @@
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v1i1, Custom);
     setOperationAction(ISD::BUILD_VECTOR,       MVT::v1i1, Custom);
 
-    setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i1, MVT::v16i32);
-    setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i1, MVT::v16i32);
     setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i1,  MVT::v8i32);
     setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v8i1,  MVT::v8i32);
     setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v4i1,  MVT::v4i32);
@@ -1200,6 +1196,16 @@
     setOperationAction(ISD::INSERT_SUBVECTOR,   MVT::v16i1, Custom);
     for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
+  }
+
+  // This block controls legalization for 512-bit operations with 32/64 bit
+  // elements. 512-bits can be disabled based on prefer-vector-width and
+  // required-vector-width function attributes.
+  if (!Subtarget.useSoftFloat() && Subtarget.useAVX512Regs()) {
+    addRegisterClass(MVT::v16i32, &X86::VR512RegClass);
+    addRegisterClass(MVT::v16f32, &X86::VR512RegClass);
+    addRegisterClass(MVT::v8i64,  &X86::VR512RegClass);
+    addRegisterClass(MVT::v8f64,  &X86::VR512RegClass);
 
     for (MVT VT : MVT::fp_vector_valuetypes())
       setLoadExtAction(ISD::EXTLOAD, VT, MVT::v8f32, Legal);
@@ -1222,7 +1228,9 @@
     setOperationAction(ISD::FP_TO_SINT,         MVT::v16i32, Legal);
     setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i16, MVT::v16i32);
     setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i8, MVT::v16i32);
+    setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v16i1, MVT::v16i32);
     setOperationAction(ISD::FP_TO_UINT,         MVT::v16i32, Legal);
+    setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i1, MVT::v16i32);
     setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i8, MVT::v16i32);
     setOperationPromotedToType(ISD::FP_TO_UINT, MVT::v16i16, MVT::v16i32);
     setOperationAction(ISD::SINT_TO_FP,         MVT::v16i32, Legal);
@@ -1352,6 +1360,9 @@
     }
   }// has  AVX-512
 
+  // This block controls legalization for operations that don't have
+  // pre-AVX512 equivalents. Without VLX we use 512-bit operations for
+  // narrower widths.
   if (!Subtarget.useSoftFloat() && Subtarget.hasAVX512()) {
     // These operations are handled on non-VLX by artificially widening in
     // isel patterns.
@@ -1406,10 +1417,10 @@
     }
   }
 
+  // This block controls legalization of v32i1/v64i1, which are available with
+  // AVX512BW. 512-bit v32i16 and v64i8 vector legalization is controlled with
+  // useBWIRegs.
   if (!Subtarget.useSoftFloat() && Subtarget.hasBWI()) {
-    addRegisterClass(MVT::v32i16, &X86::VR512RegClass);
-    addRegisterClass(MVT::v64i8,  &X86::VR512RegClass);
-
     addRegisterClass(MVT::v32i1,  &X86::VK32RegClass);
     addRegisterClass(MVT::v64i1,  &X86::VK64RegClass);
 
@@ -1439,6 +1450,15 @@
     setOperationAction(ISD::SIGN_EXTEND,        MVT::v32i8, Custom);
     setOperationAction(ISD::ZERO_EXTEND,        MVT::v32i8, Custom);
     setOperationAction(ISD::ANY_EXTEND,         MVT::v32i8, Custom);
+  }
+
+  // This block controls legalization for v32i16 and v64i8. 512-bits can be
+  // disabled based on prefer-vector-width and required-vector-width function
+  // attributes.
+  if (!Subtarget.useSoftFloat() && Subtarget.useBWIRegs()) {
+    addRegisterClass(MVT::v32i16, &X86::VR512RegClass);
+    addRegisterClass(MVT::v64i8,  &X86::VR512RegClass);
+
     // Extends from v64i1 masks to 512-bit vectors.
     setOperationAction(ISD::SIGN_EXTEND,        MVT::v64i8, Custom);
     setOperationAction(ISD::ZERO_EXTEND,        MVT::v64i8, Custom);
@@ -30049,7 +30069,7 @@
   EVT VT = N->getValueType(0);
   if ((!Subtarget.hasSSE3() || (VT != MVT::v4f32 && VT != MVT::v2f64)) &&
       (!Subtarget.hasAVX() || (VT != MVT::v8f32 && VT != MVT::v4f64)) &&
-      (!Subtarget.hasAVX512() || (VT != MVT::v16f32 && VT != MVT::v8f64)))
+      (!Subtarget.useAVX512Regs() || (VT != MVT::v16f32 && VT != MVT::v8f64)))
     return false;
 
   // We only handle target-independent shuffles.
@@ -31086,7 +31106,7 @@
     return SDValue();
 
   unsigned RegSize = 128;
-  if (Subtarget.hasBWI())
+  if (Subtarget.useBWIRegs())
     RegSize = 512;
   else if (Subtarget.hasAVX2())
     RegSize = 256;
@@ -32664,7 +32684,7 @@
   if (Subtarget.getProcFamily() != X86Subtarget::IntelKNL &&
       ((VT == MVT::v4i32 && Subtarget.hasSSE2()) ||
        (VT == MVT::v8i32 && Subtarget.hasAVX2()) ||
-       (VT == MVT::v16i32 && Subtarget.hasBWI()))) {
+       (VT == MVT::v16i32 && Subtarget.useBWIRegs()))) {
     SDValue N0 = N->getOperand(0);
     SDValue N1 = N->getOperand(1);
     APInt Mask17 = APInt::getHighBitsSet(32, 17);
@@ -34190,7 +34210,7 @@
                                SDValue Op1, F Builder) {
   assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
   unsigned NumSubs = 1;
-  if (Subtarget.hasBWI()) {
+  if (Subtarget.useBWIRegs()) {
     if (VT.getSizeInBits() > 512) {
       NumSubs = VT.getSizeInBits() / 512;
       assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size");
@@ -36181,7 +36201,7 @@
   // Also use this if we don't have SSE41 to allow the legalizer do its job.
   if (!Subtarget.hasSSE41() || VT.is128BitVector() ||
       (VT.is256BitVector() && Subtarget.hasInt256()) ||
-      (VT.is512BitVector() && Subtarget.hasAVX512())) {
+      (VT.is512BitVector() && Subtarget.useAVX512Regs())) {
     SDValue ExOp = ExtendVecSize(DL, N0, VT.getSizeInBits());
     return Opcode == ISD::SIGN_EXTEND
                ? DAG.getSignExtendVectorInReg(ExOp, DL, VT)
@@ -36214,7 +36234,7 @@
 
   // On pre-AVX512 targets, split into 256-bit nodes of
   // ISD::*_EXTEND_VECTOR_INREG.
-  if (!Subtarget.hasAVX512() && !(VT.getSizeInBits() % 256))
+  if (!Subtarget.useAVX512Regs() && !(VT.getSizeInBits() % 256))
     return SplitAndExtendInReg(256);
 
   return SDValue();
@@ -37169,7 +37189,7 @@
   EVT VT = N->getValueType(0);
 
   unsigned RegSize = 128;
-  if (Subtarget.hasBWI())
+  if (Subtarget.useBWIRegs())
     RegSize = 512;
   else if (Subtarget.hasAVX2())
     RegSize = 256;
@@ -37214,7 +37234,7 @@
     return SDValue();
 
   unsigned RegSize = 128;
-  if (Subtarget.hasBWI())
+  if (Subtarget.useBWIRegs())
     RegSize = 512;
   else if (Subtarget.hasAVX2())
     RegSize = 256;
@@ -37442,8 +37462,8 @@
   if (!(Subtarget.hasSSE2() && (VT == MVT::v16i8 || VT == MVT::v8i16)) &&
       !(Subtarget.hasSSE41() && (VT == MVT::v8i32)) &&
       !(Subtarget.hasAVX2() && (VT == MVT::v32i8 || VT == MVT::v16i16)) &&
-      !(Subtarget.hasBWI() && (VT == MVT::v64i8 || VT == MVT::v32i16 ||
-                               VT == MVT::v16i32 || VT == MVT::v8i64)))
+      !(Subtarget.useBWIRegs() && (VT == MVT::v64i8 || VT == MVT::v32i16 ||
+                                   VT == MVT::v16i32 || VT == MVT::v8i64)))
     return SDValue();
 
   SDValue SubusLHS, SubusRHS;
diff --git a/llvm/lib/Target/X86/X86Subtarget.cpp b/llvm/lib/Target/X86/X86Subtarget.cpp
index 7c4af0a..0a6eda7 100644
--- a/llvm/lib/Target/X86/X86Subtarget.cpp
+++ b/llvm/lib/Target/X86/X86Subtarget.cpp
@@ -373,11 +373,13 @@
 X86Subtarget::X86Subtarget(const Triple &TT, StringRef CPU, StringRef FS,
                            const X86TargetMachine &TM,
                            unsigned StackAlignOverride,
-                           unsigned PreferVectorWidthOverride)
+                           unsigned PreferVectorWidthOverride,
+                           unsigned RequiredVectorWidth)
     : X86GenSubtargetInfo(TT, CPU, FS), X86ProcFamily(Others),
       PICStyle(PICStyles::None), TM(TM), TargetTriple(TT),
       StackAlignOverride(StackAlignOverride),
       PreferVectorWidthOverride(PreferVectorWidthOverride),
+      RequiredVectorWidth(RequiredVectorWidth),
       In64BitMode(TargetTriple.getArch() == Triple::x86_64),
       In32BitMode(TargetTriple.getArch() == Triple::x86 &&
                   TargetTriple.getEnvironment() != Triple::CODE16),
diff --git a/llvm/lib/Target/X86/X86Subtarget.h b/llvm/lib/Target/X86/X86Subtarget.h
index b8adb63..4a23f4b 100644
--- a/llvm/lib/Target/X86/X86Subtarget.h
+++ b/llvm/lib/Target/X86/X86Subtarget.h
@@ -407,6 +407,9 @@
   /// features.
   unsigned PreferVectorWidth;
 
+  /// Required vector width from function attribute.
+  unsigned RequiredVectorWidth;
+
   /// True if compiling for 64-bit, false for 16-bit or 32-bit.
   bool In64BitMode;
 
@@ -433,7 +436,8 @@
   ///
   X86Subtarget(const Triple &TT, StringRef CPU, StringRef FS,
                const X86TargetMachine &TM, unsigned StackAlignOverride,
-               unsigned PreferVectorWidthOverride);
+               unsigned PreferVectorWidthOverride,
+               unsigned RequiredVectorWidth);
 
   const X86TargetLowering *getTargetLowering() const override {
     return &TLInfo;
@@ -622,6 +626,7 @@
   bool useRetpolineExternalThunk() const { return UseRetpolineExternalThunk; }
 
   unsigned getPreferVectorWidth() const { return PreferVectorWidth; }
+  unsigned getRequiredVectorWidth() const { return RequiredVectorWidth; }
 
   // Helper functions to determine when we should allow widening to 512-bit
   // during codegen.
@@ -634,6 +639,16 @@
     return hasBWI() && canExtendTo512DQ();
   }
 
+  // If there are no 512-bit vectors and we prefer not to use 512-bit registers,
+  // disable them in the legalizer.
+  bool useAVX512Regs() const {
+    return hasAVX512() && (canExtendTo512DQ() || RequiredVectorWidth > 256);
+  }
+
+  bool useBWIRegs() const {
+    return hasBWI() && useAVX512Regs();
+  }
+
   bool isXRaySupported() const override { return is64Bit(); }
 
   X86ProcFamilyEnum getProcFamily() const { return X86ProcFamily; }
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp
index 5f67949..c79f1b4 100644
--- a/llvm/lib/Target/X86/X86TargetMachine.cpp
+++ b/llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -259,8 +259,7 @@
   // the feature string out later.
   unsigned CPUFSWidth = Key.size();
 
-  // Translate vector width function attribute into subtarget features. This
-  // overrides any CPU specific turning parameter
+  // Extract prefer-vector-width attribute.
   unsigned PreferVectorWidthOverride = 0;
   if (F.hasFnAttribute("prefer-vector-width")) {
     StringRef Val = F.getFnAttribute("prefer-vector-width").getValueAsString();
@@ -272,6 +271,21 @@
     }
   }
 
+  // Extract required-vector-width attribute.
+  unsigned RequiredVectorWidth = UINT32_MAX;
+  if (F.hasFnAttribute("required-vector-width")) {
+    StringRef Val = F.getFnAttribute("required-vector-width").getValueAsString();
+    unsigned Width;
+    if (!Val.getAsInteger(0, Width)) {
+      Key += ",required-vector-width=";
+      Key += Val;
+      RequiredVectorWidth = Width;
+    }
+  }
+
+  // Extracted here so that we make sure there is backing for the StringRef. If
+  // we assigned earlier, it's possible the SmallString reallocated, leaving a
+  // dangling StringRef.
   FS = Key.slice(CPU.size(), CPUFSWidth);
 
   auto &I = SubtargetMap[Key];
@@ -282,7 +296,8 @@
     resetTargetOptions(F);
     I = llvm::make_unique<X86Subtarget>(TargetTriple, CPU, FS, *this,
                                         Options.StackAlignmentOverride,
-                                        PreferVectorWidthOverride);
+                                        PreferVectorWidthOverride,
+                                        RequiredVectorWidth);
   }
   return I.get();
 }