PTX: Use .param space for device function return values on SM 2.0+, and attempt
to fix up parameter passing on SM < 2.0
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140309 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index a05a55b..424c5a1 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -132,6 +132,10 @@
return "PTXISD::LOAD_PARAM";
case PTXISD::STORE_PARAM:
return "PTXISD::STORE_PARAM";
+ case PTXISD::READ_PARAM:
+ return "PTXISD::READ_PARAM";
+ case PTXISD::WRITE_PARAM:
+ return "PTXISD::WRITE_PARAM";
case PTXISD::EXIT:
return "PTXISD::EXIT";
case PTXISD::RET:
@@ -220,7 +224,6 @@
if (MFI->isKernel() || ST.useParamSpaceForDeviceArgs()) {
// We just need to emit the proper LOAD_PARAM ISDs
for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
-
assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
"Kernels cannot take pred operands");
@@ -231,26 +234,140 @@
// Instead of storing a physical register in our argument list, we just
// store the total size of the parameter, in bits. The ASM printer
// knows how to process this.
- MFI->addArgReg(Ins[i].VT.getStoreSizeInBits());
+ MFI->addArgParam(Ins[i].VT.getStoreSizeInBits());
}
}
else {
// For device functions, we use the PTX calling convention to do register
// assignments then create CopyFromReg ISDs for the allocated registers
- SmallVector<CCValAssign, 16> ArgLocs;
- CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs,
- *DAG.getContext());
+ //SmallVector<CCValAssign, 16> ArgLocs;
+ //CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), ArgLocs,
+ // *DAG.getContext());
- CCInfo.AnalyzeFormalArguments(Ins, CC_PTX);
+ //CCInfo.AnalyzeFormalArguments(Ins, CC_PTX);
- for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
+ //for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) {
+ for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
- CCValAssign& VA = ArgLocs[i];
- EVT RegVT = VA.getLocVT();
+ EVT RegVT = Ins[i].VT;
TargetRegisterClass* TRC = 0;
+ int OpCode;
- assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
+ //assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
+
+ // Determine which register class we need
+ if (RegVT == MVT::i1) {
+ TRC = PTX::RegPredRegisterClass;
+ OpCode = PTX::READPARAMPRED;
+ }
+ else if (RegVT == MVT::i16) {
+ TRC = PTX::RegI16RegisterClass;
+ OpCode = PTX::READPARAMI16;
+ }
+ else if (RegVT == MVT::i32) {
+ TRC = PTX::RegI32RegisterClass;
+ OpCode = PTX::READPARAMI32;
+ }
+ else if (RegVT == MVT::i64) {
+ TRC = PTX::RegI64RegisterClass;
+ OpCode = PTX::READPARAMI64;
+ }
+ else if (RegVT == MVT::f32) {
+ TRC = PTX::RegF32RegisterClass;
+ OpCode = PTX::READPARAMF32;
+ }
+ else if (RegVT == MVT::f64) {
+ TRC = PTX::RegF64RegisterClass;
+ OpCode = PTX::READPARAMF64;
+ }
+ else {
+ llvm_unreachable("Unknown parameter type");
+ }
+
+ // Use a unique index in the instruction to prevent instruction folding.
+ // Yes, this is a hack.
+ SDValue Index = DAG.getTargetConstant(i, MVT::i32);
+ unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
+ SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, RegVT, Chain,
+ Index);
+
+ SDValue Flag = ArgValue.getValue(1);
+
+ SDValue Copy = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
+ SDValue RegValue = DAG.getRegister(Reg, RegVT);
+ InVals.push_back(ArgValue);
+
+ MFI->addArgReg(Reg);
+ }
+ }
+
+ return Chain;
+}
+
+SDValue PTXTargetLowering::
+ LowerReturn(SDValue Chain,
+ CallingConv::ID CallConv,
+ bool isVarArg,
+ const SmallVectorImpl<ISD::OutputArg> &Outs,
+ const SmallVectorImpl<SDValue> &OutVals,
+ DebugLoc dl,
+ SelectionDAG &DAG) const {
+ if (isVarArg) llvm_unreachable("PTX does not support varargs");
+
+ switch (CallConv) {
+ default:
+ llvm_unreachable("Unsupported calling convention.");
+ case CallingConv::PTX_Kernel:
+ assert(Outs.size() == 0 && "Kernel must return void.");
+ return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
+ case CallingConv::PTX_Device:
+ assert(Outs.size() <= 1 && "Can at most return one value.");
+ break;
+ }
+
+ MachineFunction& MF = DAG.getMachineFunction();
+ PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+
+ SDValue Flag;
+
+ // Device functions return at most one value. When the subtarget uses the
+ // .param space for device arguments, the value is stored to a return
+ // parameter with STORE_PARAM; otherwise it is copied into a virtual
+ // register and written out with WRITE_PARAM.
+
+ const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
+
+ if (ST.useParamSpaceForDeviceArgs()) {
+ assert(Outs.size() < 2 && "Device functions can return at most one value");
+
+ if (Outs.size() == 1) {
+ unsigned Size = OutVals[0].getValueType().getSizeInBits();
+ SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+ Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+ Index, OutVals[0]);
+
+ //Flag = Chain.getValue(1);
+ MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
+ }
+ } else {
+ //SmallVector<CCValAssign, 16> RVLocs;
+ //CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
+ //getTargetMachine(), RVLocs, *DAG.getContext());
+
+ //CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
+
+ //for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
+ //CCValAssign& VA = RVLocs[i];
+
+ for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+
+ //assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
+
+ //unsigned Reg = VA.getLocReg();
+
+ EVT RegVT = Outs[i].VT;
+ TargetRegisterClass* TRC = 0;
// Determine which register class we need
if (RegVT == MVT::i1) {
@@ -276,72 +393,28 @@
}
unsigned Reg = MF.getRegInfo().createVirtualRegister(TRC);
- MF.getRegInfo().addLiveIn(VA.getLocReg(), Reg);
- SDValue ArgValue = DAG.getCopyFromReg(Chain, dl, Reg, RegVT);
- InVals.push_back(ArgValue);
+ //DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
- MFI->addArgReg(VA.getLocReg());
+ //Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
+ //SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
+
+ // Guarantee that all emitted copies are stuck together,
+ // avoiding something bad
+ //Flag = Chain.getValue(1);
+
+ SDValue Copy = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i]/*, Flag*/);
+ SDValue OutReg = DAG.getRegister(Reg, RegVT);
+
+ Chain = DAG.getNode(PTXISD::WRITE_PARAM, dl, MVT::Other, Copy, OutReg);
+ //Flag = Chain.getValue(1);
+
+ MFI->addRetReg(Reg);
+
+ //MFI->addRetReg(Reg);
}
}
- return Chain;
-}
-
-SDValue PTXTargetLowering::
- LowerReturn(SDValue Chain,
- CallingConv::ID CallConv,
- bool isVarArg,
- const SmallVectorImpl<ISD::OutputArg> &Outs,
- const SmallVectorImpl<SDValue> &OutVals,
- DebugLoc dl,
- SelectionDAG &DAG) const {
- if (isVarArg) llvm_unreachable("PTX does not support varargs");
-
- switch (CallConv) {
- default:
- llvm_unreachable("Unsupported calling convention.");
- case CallingConv::PTX_Kernel:
- assert(Outs.size() == 0 && "Kernel must return void.");
- return DAG.getNode(PTXISD::EXIT, dl, MVT::Other, Chain);
- case CallingConv::PTX_Device:
- //assert(Outs.size() <= 1 && "Can at most return one value.");
- break;
- }
-
- MachineFunction& MF = DAG.getMachineFunction();
- PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
-
- SDValue Flag;
-
- // Even though we could use the .param space for return arguments for
- // device functions if SM >= 2.0 and the number of return arguments is
- // only 1, we just always use registers since this makes the codegen
- // easier.
- SmallVector<CCValAssign, 16> RVLocs;
- CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
- getTargetMachine(), RVLocs, *DAG.getContext());
-
- CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
-
- for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
- CCValAssign& VA = RVLocs[i];
-
- assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
-
- unsigned Reg = VA.getLocReg();
-
- DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
-
- Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
-
- // Guarantee that all emitted copies are stuck together,
- // avoiding something bad
- Flag = Chain.getValue(1);
-
- MFI->addRetReg(Reg);
- }
-
if (Flag.getNode() == 0) {
return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
}