[RISCV] Lower the tail pseudoinstruction
This patch lowers the tail pseudoinstruction. The implementation is modeled
after ARM's tail call optimization.
llvm-svn: 333137
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d3d0d03..f572614 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -18,6 +18,7 @@
#include "RISCVRegisterInfo.h"
#include "RISCVSubtarget.h"
#include "RISCVTargetMachine.h"
+#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
@@ -36,6 +37,8 @@
#define DEBUG_TYPE "riscv-lower"
+STATISTIC(NumTailCalls, "Number of tail calls");
+
RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
const RISCVSubtarget &STI)
: TargetLowering(TM), Subtarget(STI) {
@@ -1076,6 +1079,88 @@
return Chain;
}
+/// IsEligibleForTailCallOptimization - Return true only if none of the
+/// checks below rules out lowering the call described by CLI as a tail call.
+/// Note: This is modelled after ARM's IsEligibleForTailCallOptimization.
+bool RISCVTargetLowering::IsEligibleForTailCallOptimization(
+ CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
+ const SmallVector<CCValAssign, 16> &ArgLocs) const {
+
+ auto &Callee = CLI.Callee;
+ auto CalleeCC = CLI.CallConv;
+ auto IsVarArg = CLI.IsVarArg;
+ auto &Outs = CLI.Outs;
+ auto &Caller = MF.getFunction();
+ auto CallerCC = Caller.getCallingConv();
+
+ // Do not tail call opt if the caller sets "disable-tail-calls" to "true".
+ if (Caller.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
+ return false;
+
+ // Exception-handling functions need a special set of instructions to
+ // indicate a return to the hardware. Tail-calling another function would
+ // probably break this.
+ // TODO: The "interrupt" attribute isn't currently defined by RISC-V. This
+ // should be expanded as new function attributes are introduced.
+ if (Caller.hasFnAttribute("interrupt"))
+ return false;
+
+ // Do not tail call opt if the call is variadic (CLI.IsVarArg).
+ if (IsVarArg)
+ return false;
+
+ // Do not tail call opt if the stack is used to pass parameters.
+ if (CCInfo.getNextStackOffset() != 0)
+ return false;
+
+ // Do not tail call opt if any parameters need to be passed indirectly.
+ // Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
+ // passed indirectly. So the address of the value will be passed in a
+ // register, or if not available, then the address is put on the stack. In
+ // order to pass indirectly, space on the stack often needs to be allocated
+ // in order to store the value. In this case the CCInfo.getNextStackOffset()
+ // != 0 check is not enough and we need to check whether any CCValAssign in
+ // ArgLocs has the location info CCValAssign::Indirect.
+ for (auto &VA : ArgLocs)
+ if (VA.getLocInfo() == CCValAssign::Indirect)
+ return false;
+
+ // Do not tail call opt if either caller or callee uses struct return
+ // semantics. For the callee, only the first outgoing argument is inspected.
+ auto IsCallerStructRet = Caller.hasStructRetAttr();
+ auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
+ if (IsCallerStructRet || IsCalleeStructRet)
+ return false;
+
+ // Externally-defined functions with weak linkage should not be
+ // tail-called. The behaviour of branch instructions in this situation (as
+ // used for tail calls) is implementation-defined, so we cannot rely on the
+ // linker replacing the tail call with a return.
+ if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
+ const GlobalValue *GV = G->getGlobal();
+ if (GV->hasExternalWeakLinkage())
+ return false;
+ }
+
+ // The callee has to preserve all registers the caller needs to preserve.
+ const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
+ const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
+ if (CalleeCC != CallerCC) {
+ const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
+ if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
+ return false;
+ }
+
+ // Byval parameters hand the function a pointer directly into the stack area
+ // we want to reuse during a tail call. Working around this *is* possible
+ // but less efficient and uglier in LowerCall.
+ for (auto &Arg : Outs)
+ if (Arg.Flags.isByVal())
+ return false;
+
+ return true;
+}
+
// Lower a call to a callseq_start + CALL + callseq_end chain, and add input
// and output parameter nodes.
SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
@@ -1087,7 +1172,7 @@
SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
SDValue Chain = CLI.Chain;
SDValue Callee = CLI.Callee;
- CLI.IsTailCall = false;
+ bool &IsTailCall = CLI.IsTailCall;
CallingConv::ID CallConv = CLI.CallConv;
bool IsVarArg = CLI.IsVarArg;
EVT PtrVT = getPointerTy(DAG.getDataLayout());
@@ -1100,6 +1185,17 @@
CCState ArgCCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
analyzeOutputArgs(MF, ArgCCInfo, Outs, /*IsRet=*/false, &CLI);
+ // Check if it's really possible to do a tail call.
+ if (IsTailCall)
+ IsTailCall = IsEligibleForTailCallOptimization(ArgCCInfo, CLI, MF,
+ ArgLocs);
+
+ if (IsTailCall)
+ ++NumTailCalls;
+ else if (CLI.CS && CLI.CS.isMustTailCall())
+ report_fatal_error("failed to perform tail call elimination on a call "
+ "site marked musttail");
+
// Get a count of how many bytes are to be pushed on the stack.
unsigned NumBytes = ArgCCInfo.getNextStackOffset();
@@ -1121,12 +1217,13 @@
Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Align,
/*IsVolatile=*/false,
/*AlwaysInline=*/false,
- /*isTailCall=*/false, MachinePointerInfo(),
+ IsTailCall, MachinePointerInfo(),
MachinePointerInfo());
ByValArgs.push_back(FIPtr);
}
- Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
+ if (!IsTailCall)
+ Chain = DAG.getCALLSEQ_START(Chain, NumBytes, 0, CLI.DL);
// Copy argument values to their designated locations.
SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
@@ -1213,6 +1310,8 @@
RegsToPass.push_back(std::make_pair(VA.getLocReg(), ArgValue));
} else {
assert(VA.isMemLoc() && "Argument not register or memory");
+ assert(!IsTailCall && "Tail call not allowed if stack is used "
+ "for passing parameters");
// Work out the address of the stack slot.
if (!StackPtr.getNode())
@@ -1258,11 +1357,13 @@
for (auto &Reg : RegsToPass)
Ops.push_back(DAG.getRegister(Reg.first, Reg.second.getValueType()));
- // Add a register mask operand representing the call-preserved registers.
- const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
- const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
- assert(Mask && "Missing call preserved mask for calling convention");
- Ops.push_back(DAG.getRegisterMask(Mask));
+ if (!IsTailCall) {
+ // Add a register mask operand representing the call-preserved registers.
+ const TargetRegisterInfo *TRI = Subtarget.getRegisterInfo();
+ const uint32_t *Mask = TRI->getCallPreservedMask(MF, CallConv);
+ assert(Mask && "Missing call preserved mask for calling convention");
+ Ops.push_back(DAG.getRegisterMask(Mask));
+ }
// Glue the call to the argument copies, if any.
if (Glue.getNode())
@@ -1270,6 +1371,12 @@
// Emit the call.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
+
+ if (IsTailCall) {
+ MF.getFrameInfo().setHasTailCall();
+ return DAG.getNode(RISCVISD::TAIL, DL, NodeTys, Ops);
+ }
+
Chain = DAG.getNode(RISCVISD::CALL, DL, NodeTys, Ops);
Glue = Chain.getValue(1);
@@ -1425,6 +1532,8 @@
return "RISCVISD::BuildPairF64";
case RISCVISD::SplitF64:
return "RISCVISD::SplitF64";
+ case RISCVISD::TAIL:
+ return "RISCVISD::TAIL";
}
return nullptr;
}