[RISCV] Add support for the "interrupt" attribute

- Save/restore only the registers that are used. This includes both
  callee-saved and caller-saved registers (arguments and temporaries),
  for integer and FP registers alike.
- If there is a call in the interrupt handler, save/restore all
  caller-saved registers (arguments and temporaries) and all FP registers.
- Emit the special return instruction (uret, sret or mret) that matches
  the "interrupt" attribute kind ("user", "supervisor" or "machine").

Based on an initial patch by Zhaoshi Zheng.

Reviewers: asb

Reviewed By: asb

Subscribers: rkruppe, the_o, MartinMosbeck, brucehoult, rbar, johnrusso, simoncook, sabuasal, niosHD, kito-cheng, shiva0217, zzheng, edward-jones, mgrang, rogfer01, llvm-commits

Differential Revision: https://reviews.llvm.org/D48411

llvm-svn: 338047
diff --git a/llvm/lib/Target/RISCV/RISCVCallingConv.td b/llvm/lib/Target/RISCV/RISCVCallingConv.td
index d2b17c6..ef14625 100644
--- a/llvm/lib/Target/RISCV/RISCVCallingConv.td
+++ b/llvm/lib/Target/RISCV/RISCVCallingConv.td
@@ -18,3 +18,40 @@
 
 // Needed for implementation of RISCVRegisterInfo::getNoPreservedMask()
 def CSR_NoRegs : CalleeSavedRegs<(add)>;
+
+// An interrupt handler needs to save/restore every register that it uses,
+// both caller-saved and callee-saved registers.
+def CSR_Interrupt : CalleeSavedRegs<(add X1,
+    (sequence "X%u", 3, 9),
+    (sequence "X%u", 10, 11),
+    (sequence "X%u", 12, 17),
+    (sequence "X%u", 18, 27),
+    (sequence "X%u", 28, 31))>;
+
+// Same as CSR_Interrupt, but including all 32-bit FP registers.
+def CSR_XLEN_F32_Interrupt: CalleeSavedRegs<(add X1,
+    (sequence "X%u", 3, 9),
+    (sequence "X%u", 10, 11),
+    (sequence "X%u", 12, 17),
+    (sequence "X%u", 18, 27),
+    (sequence "X%u", 28, 31),
+    (sequence "F%u_32", 0, 7),
+    (sequence "F%u_32", 10, 11),
+    (sequence "F%u_32", 12, 17),
+    (sequence "F%u_32", 28, 31),
+    (sequence "F%u_32", 8, 9),
+    (sequence "F%u_32", 18, 27))>;
+
+// Same as CSR_Interrupt, but including all 64-bit FP registers.
+def CSR_XLEN_F64_Interrupt: CalleeSavedRegs<(add X1,
+    (sequence "X%u", 3, 9),
+    (sequence "X%u", 10, 11),
+    (sequence "X%u", 12, 17),
+    (sequence "X%u", 18, 27),
+    (sequence "X%u", 28, 31),
+    (sequence "F%u_64", 0, 7),
+    (sequence "F%u_64", 10, 11),
+    (sequence "F%u_64", 12, 17),
+    (sequence "F%u_64", 28, 31),
+    (sequence "F%u_64", 8, 9),
+    (sequence "F%u_64", 18, 27))>;
diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
index 3dc1e3d..a816028 100644
--- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -212,6 +212,36 @@
     SavedRegs.set(RISCV::X1);
     SavedRegs.set(RISCV::X8);
   }
+
+  // If this function is an interrupt handler and it contains calls,
+  // unconditionally save all caller-saved registers and
+  // all FP registers, regardless of whether they are used.
+  MachineFrameInfo &MFI = MF.getFrameInfo();
+
+  if (MF.getFunction().hasFnAttribute("interrupt") && MFI.hasCalls()) {
+
+    static const MCPhysReg CSRegs[] = { RISCV::X1,      /* ra */
+      RISCV::X5, RISCV::X6, RISCV::X7,                  /* t0-t2 */
+      RISCV::X10, RISCV::X11,                           /* a0-a1, a2-a7 */
+      RISCV::X12, RISCV::X13, RISCV::X14, RISCV::X15, RISCV::X16, RISCV::X17,
+      RISCV::X28, RISCV::X29, RISCV::X30, RISCV::X31, 0 /* t3-t6 */
+    };
+
+    for (unsigned i = 0; CSRegs[i]; ++i)
+      SavedRegs.set(CSRegs[i]);
+
+    if (MF.getSubtarget<RISCVSubtarget>().hasStdExtD() ||
+        MF.getSubtarget<RISCVSubtarget>().hasStdExtF()) {
+
+      // For an interrupt handler, this list contains all FP registers.
+      const MCPhysReg * Regs = MF.getRegInfo().getCalleeSavedRegs();
+
+      for (unsigned i = 0; Regs[i]; ++i)
+        if (RISCV::FPR32RegClass.contains(Regs[i]) ||
+            RISCV::FPR64RegClass.contains(Regs[i]))
+          SavedRegs.set(Regs[i]);
+    }
+  }
 }
 
 void RISCVFrameLowering::processFunctionBeforeFrameFinalized(
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index dda26a7..87796e5 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -967,6 +967,21 @@
   }
 
   MachineFunction &MF = DAG.getMachineFunction();
+
+  const Function &Func = MF.getFunction();
+  if (Func.hasFnAttribute("interrupt")) {
+    if (!Func.arg_empty())
+      report_fatal_error(
+        "Functions with the interrupt attribute cannot have arguments!");
+
+    StringRef Kind =
+      MF.getFunction().getFnAttribute("interrupt").getValueAsString();
+
+    if (!(Kind == "user" || Kind == "supervisor" || Kind == "machine"))
+      report_fatal_error(
+        "Function interrupt attribute argument not supported!");
+  }
+
   EVT PtrVT = getPointerTy(DAG.getDataLayout());
   MVT XLenVT = Subtarget.getXLenVT();
   unsigned XLenInBytes = Subtarget.getXLen() / 8;
@@ -1515,6 +1530,28 @@
     RetOps.push_back(Glue);
   }
 
+  // Interrupt service routines use different return instructions.
+  const Function &Func = DAG.getMachineFunction().getFunction();
+  if (Func.hasFnAttribute("interrupt")) {
+    if (!Func.getReturnType()->isVoidTy())
+      report_fatal_error(
+          "Functions with the interrupt attribute must have void return type!");
+
+    MachineFunction &MF = DAG.getMachineFunction();
+    StringRef Kind =
+      MF.getFunction().getFnAttribute("interrupt").getValueAsString();
+
+    unsigned RetOpc;
+    if (Kind == "user")
+      RetOpc = RISCVISD::URET_FLAG;
+    else if (Kind == "supervisor")
+      RetOpc = RISCVISD::SRET_FLAG;
+    else
+      RetOpc = RISCVISD::MRET_FLAG;
+
+    return DAG.getNode(RetOpc, DL, MVT::Other, RetOps);
+  }
+
   return DAG.getNode(RISCVISD::RET_FLAG, DL, MVT::Other, RetOps);
 }
 
@@ -1524,6 +1561,12 @@
     break;
   case RISCVISD::RET_FLAG:
     return "RISCVISD::RET_FLAG";
+  case RISCVISD::URET_FLAG:
+    return "RISCVISD::URET_FLAG";
+  case RISCVISD::SRET_FLAG:
+    return "RISCVISD::SRET_FLAG";
+  case RISCVISD::MRET_FLAG:
+    return "RISCVISD::MRET_FLAG";
   case RISCVISD::CALL:
     return "RISCVISD::CALL";
   case RISCVISD::SELECT_CC:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 41b9d27..280adb2 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -25,6 +25,9 @@
 enum NodeType : unsigned {
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
   RET_FLAG,
+  URET_FLAG,
+  SRET_FLAG,
+  MRET_FLAG,
   CALL,
   SELECT_CC,
   BuildPairF64,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index e62308e..327e4a7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -118,7 +118,8 @@
   unsigned Opcode;
 
   if (RISCV::GPRRegClass.hasSubClassEq(RC))
-    Opcode = RISCV::SW;
+    Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
+             RISCV::SW : RISCV::SD;
   else if (RISCV::FPR32RegClass.hasSubClassEq(RC))
     Opcode = RISCV::FSW;
   else if (RISCV::FPR64RegClass.hasSubClassEq(RC))
@@ -144,7 +145,8 @@
   unsigned Opcode;
 
   if (RISCV::GPRRegClass.hasSubClassEq(RC))
-    Opcode = RISCV::LW;
+    Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
+             RISCV::LW : RISCV::LD;
   else if (RISCV::FPR32RegClass.hasSubClassEq(RC))
     Opcode = RISCV::FLW;
   else if (RISCV::FPR64RegClass.hasSubClassEq(RC))
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
index e1c3a06..b51e4e7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td
@@ -36,6 +36,12 @@
                           [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue]>;
 def RetFlag      : SDNode<"RISCVISD::RET_FLAG", SDTNone,
                           [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
+def URetFlag     : SDNode<"RISCVISD::URET_FLAG", SDTNone,
+                          [SDNPHasChain, SDNPOptInGlue]>;
+def SRetFlag     : SDNode<"RISCVISD::SRET_FLAG", SDTNone,
+                          [SDNPHasChain, SDNPOptInGlue]>;
+def MRetFlag     : SDNode<"RISCVISD::MRET_FLAG", SDTNone,
+                          [SDNPHasChain, SDNPOptInGlue]>;
 def SelectCC     : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
                           [SDNPInGlue]>;
 def Tail         : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
@@ -684,6 +690,10 @@
 
 def : Pat<(Call texternalsym:$func), (PseudoCALL texternalsym:$func)>;
 
+def : Pat<(URetFlag), (URET X0, X0)>;
+def : Pat<(SRetFlag), (SRET X0, X0)>;
+def : Pat<(MRetFlag), (MRET X0, X0)>;
+
 let isCall = 1, Defs = [X1] in
 def PseudoCALLIndirect : Pseudo<(outs), (ins GPR:$rs1), [(Call GPR:$rs1)]>,
                          PseudoInstExpansion<(JALR X1, GPR:$rs1, 0)>;
diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
index f72730f..3ed1dec 100644
--- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -33,6 +33,13 @@
 
 const MCPhysReg *
 RISCVRegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const {
+  if (MF->getFunction().hasFnAttribute("interrupt")) {
+    if (MF->getSubtarget<RISCVSubtarget>().hasStdExtD())
+      return CSR_XLEN_F64_Interrupt_SaveList;
+    if (MF->getSubtarget<RISCVSubtarget>().hasStdExtF())
+      return CSR_XLEN_F32_Interrupt_SaveList;
+    return CSR_Interrupt_SaveList;
+  }
   return CSR_SaveList;
 }
 
@@ -108,7 +115,14 @@
 }
 
 const uint32_t *
-RISCVRegisterInfo::getCallPreservedMask(const MachineFunction & /*MF*/,
+RISCVRegisterInfo::getCallPreservedMask(const MachineFunction & MF,
                                         CallingConv::ID /*CC*/) const {
+  if (MF.getFunction().hasFnAttribute("interrupt")) {
+    if (MF.getSubtarget<RISCVSubtarget>().hasStdExtD())
+      return CSR_XLEN_F64_Interrupt_RegMask;
+    if (MF.getSubtarget<RISCVSubtarget>().hasStdExtF())
+      return CSR_XLEN_F32_Interrupt_RegMask;
+    return CSR_Interrupt_RegMask;
+  }
   return CSR_RegMask;
 }