[X86] Support of no_caller_saved_registers attribute

This patch implements the LLVM part of the no_caller_saved_registers attribute, as introduced in GCC by this commit: https://gcc.gnu.org/git/?p=gcc.git;a=commit;h=5ed3cc7b66af4758f7849ed6f65f4365be8223be.
To implement the attribute, we use the dynamic CSR mechanism to remove the registers used for returned/passed arguments from the function's regmask/CSR list.

Differential Revision: https://reviews.llvm.org/D31876

llvm-svn: 302020
diff --git a/llvm/lib/Target/X86/X86FastISel.cpp b/llvm/lib/Target/X86/X86FastISel.cpp
index fd11b67..ebd179e 100644
--- a/llvm/lib/Target/X86/X86FastISel.cpp
+++ b/llvm/lib/Target/X86/X86FastISel.cpp
@@ -3181,6 +3181,15 @@
   bool Is64Bit        = Subtarget->is64Bit();
   bool IsWin64        = Subtarget->isCallingConvWin64(CC);
 
+  const CallInst *CI =
+      CLI.CS ? dyn_cast<CallInst>(CLI.CS->getInstruction()) : nullptr;
+  const Function *CalledFn = CI ? CI->getCalledFunction() : nullptr;
+
+  // Functions with no_caller_saved_registers need special handling.
+  if ((CI && CI->hasFnAttr("no_caller_saved_registers")) ||
+      (CalledFn && CalledFn->hasFnAttribute("no_caller_saved_registers")))
+    return false;
+
   // Handle only C, fastcc, and webkit_js calling conventions for now.
   switch (CC) {
   default: return false;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 25dd1b6..19857c2 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2180,6 +2180,12 @@
   MachineFunction &MF = DAG.getMachineFunction();
   X86MachineFunctionInfo *FuncInfo = MF.getInfo<X86MachineFunctionInfo>();
 
+  // In some cases we need to disable registers from the default CSR list.
+  // For example, when they are used for argument passing.
+  bool ShouldDisableCalleeSavedRegister =
+      CallConv == CallingConv::X86_RegCall ||
+      MF.getFunction()->hasFnAttribute("no_caller_saved_registers");
+
   if (CallConv == CallingConv::X86_INTR && !Outs.empty())
     report_fatal_error("X86 interrupts may not return any value");
 
@@ -2201,7 +2207,7 @@
     assert(VA.isRegLoc() && "Can only return in registers!");
 
     // Add the register to the CalleeSaveDisableRegs list.
-    if (CallConv == CallingConv::X86_RegCall)
+    if (ShouldDisableCalleeSavedRegister)
       MF.getRegInfo().disableCalleeSavedRegister(VA.getLocReg());
 
     SDValue ValToCopy = OutVals[OutsIndex];
@@ -2280,7 +2286,7 @@
              "Expecting two registers after Pass64BitArgInRegs");
 
       // Add the second register to the CalleeSaveDisableRegs list.
-      if (CallConv == CallingConv::X86_RegCall)
+      if (ShouldDisableCalleeSavedRegister)
         MF.getRegInfo().disableCalleeSavedRegister(RVLocs[I].getLocReg());
     } else {
       RegsToPass.push_back(std::make_pair(VA.getLocReg(), ValToCopy));
@@ -2340,7 +2346,7 @@
         DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
 
     // Add the returned register to the CalleeSaveDisableRegs list.
-    if (CallConv == CallingConv::X86_RegCall)
+    if (ShouldDisableCalleeSavedRegister)
       MF.getRegInfo().disableCalleeSavedRegister(RetValReg);
   }
 
@@ -2540,7 +2546,7 @@
 
     // In some calling conventions we need to remove the used registers
     // from the register mask.
-    if (RegMask && CallConv == CallingConv::X86_RegCall) {
+    if (RegMask) {
       for (MCSubRegIterator SubRegs(VA.getLocReg(), TRI, /*IncludeSelf=*/true);
            SubRegs.isValid(); ++SubRegs)
         RegMask[*SubRegs / 32] &= ~(1u << (*SubRegs % 32));
@@ -3237,7 +3243,8 @@
     }
   }
 
-  if (CallConv == CallingConv::X86_RegCall) {
+  if (CallConv == CallingConv::X86_RegCall ||
+      Fn->hasFnAttribute("no_caller_saved_registers")) {
     const MachineRegisterInfo &MRI = MF.getRegInfo();
     for (const auto &Pair : make_range(MRI.livein_begin(), MRI.livein_end()))
       MF.getRegInfo().disableCalleeSavedRegister(Pair.first);
@@ -3329,6 +3336,11 @@
   bool IsSibcall      = false;
   X86MachineFunctionInfo *X86Info = MF.getInfo<X86MachineFunctionInfo>();
   auto Attr = MF.getFunction()->getFnAttribute("disable-tail-calls");
+  const CallInst *CI =
+      CLI.CS ? dyn_cast<CallInst>(CLI.CS->getInstruction()) : nullptr;
+  const Function *Fn = CI ? CI->getCalledFunction() : nullptr;
+  bool HasNCSR = (CI && CI->hasFnAttr("no_caller_saved_registers")) ||
+                 (Fn && Fn->hasFnAttribute("no_caller_saved_registers"));
 
   if (CallConv == CallingConv::X86_INTR)
     report_fatal_error("X86 interrupts may not be called directly");
@@ -3741,7 +3753,11 @@
                                   RegsToPass[i].second.getValueType()));
 
   // Add a register mask operand representing the call-preserved registers.
-  const uint32_t *Mask = RegInfo->getCallPreservedMask(MF, CallConv);
+  // If HasNCSR is asserted (attribute NoCallerSavedRegisters exists) then we
+  // set X86_INTR calling convention because it has the same CSR mask
+  // (same preserved registers).
+  const uint32_t *Mask = RegInfo->getCallPreservedMask(
+      MF, HasNCSR ? CallingConv::X86_INTR : CallConv);
   assert(Mask && "Missing call preserved mask for calling convention");
 
   // If this is an invoke in a 32-bit function using a funclet-based
@@ -3764,7 +3780,7 @@
 
   // In some calling conventions we need to remove the used physical registers
   // from the reg mask.
-  if (CallConv == CallingConv::X86_RegCall) {
+  if (CallConv == CallingConv::X86_RegCall || HasNCSR) {
     const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
 
     // Allocate a new Reg Mask and copy Mask.
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
index 1f16f3c..cf2ceef 100644
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -276,7 +276,14 @@
   bool HasAVX512 = Subtarget.hasAVX512();
   bool CallsEHReturn = MF->callsEHReturn();
 
-  switch (MF->getFunction()->getCallingConv()) {
+  CallingConv::ID CC = MF->getFunction()->getCallingConv();
+
+  // If attribute NoCallerSavedRegisters exists then we set X86_INTR calling
+  // convention because it has the same CSR list.
+  if (MF->getFunction()->hasFnAttribute("no_caller_saved_registers"))
+    CC = CallingConv::X86_INTR;
+
+  switch (CC) {
   case CallingConv::GHC:
   case CallingConv::HiPE:
     return CSR_NoRegs_SaveList;