PTX: Use .param space for parameters in device functions for SM >= 2.0

FIXME: DCE is eliminating the final st.param.x instructions; figure out why.
llvm-svn: 133732
diff --git a/llvm/lib/Target/PTX/PTXISelLowering.cpp b/llvm/lib/Target/PTX/PTXISelLowering.cpp
index c3cdaba..782d916 100644
--- a/llvm/lib/Target/PTX/PTXISelLowering.cpp
+++ b/llvm/lib/Target/PTX/PTXISelLowering.cpp
@@ -15,6 +15,7 @@
 #include "PTXISelLowering.h"
 #include "PTXMachineFunctionInfo.h"
 #include "PTXRegisterInfo.h"
+#include "PTXSubtarget.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -106,6 +107,8 @@
       return "PTXISD::COPY_ADDRESS";
     case PTXISD::READ_PARAM:
       return "PTXISD::READ_PARAM";
+    case PTXISD::STORE_PARAM:
+      return "PTXISD::STORE_PARAM";
     case PTXISD::EXIT:
       return "PTXISD::EXIT";
     case PTXISD::RET:
@@ -192,6 +195,7 @@
   if (isVarArg) llvm_unreachable("PTX does not support varargs");
 
   MachineFunction &MF = DAG.getMachineFunction();
+  const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
 
   switch (CallConv) {
@@ -206,11 +210,16 @@
       break;
   }
 
-  if (MFI->isKernel()) {
-    // For kernel functions, we just need to emit the proper READ_PARAM ISDs
+  // We do one of two things here:
+  // IsKernel || SM >= 2.0  ->  Use param space for arguments
+  // !IsKernel && SM < 2.0  ->  Use registers for arguments
+  
+  if (MFI->isKernel() || ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
+    // We just need to emit the proper READ_PARAM ISDs
     for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
 
-      assert(Ins[i].VT != MVT::i1 && "Kernels cannot take pred operands");
+      assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
+             "Kernels cannot take pred operands");
 
       SDValue ArgValue = DAG.getNode(PTXISD::READ_PARAM, dl, Ins[i].VT, Chain,
                                      DAG.getTargetConstant(i, MVT::i32));
@@ -299,31 +308,49 @@
 
   MachineFunction& MF = DAG.getMachineFunction();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
-  SmallVector<CCValAssign, 16> RVLocs;
-  CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
-                 getTargetMachine(), RVLocs, *DAG.getContext());
+  const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
 
   SDValue Flag;
 
-  CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
+  if (ST.getShaderModel() >= PTXSubtarget::PTX_SM_2_0) {
+    // For SM 2.0+, we pass return values back in the param space
+    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+      SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+      SDValue ParamIndex = DAG.getTargetConstant(i, MVT::i32);
+      SDValue Ops[] = { Chain, ParamIndex, OutVals[i], Flag };
+      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, VTs, Ops,
+                          Flag.getNode() ? 4 : 3);
+      Flag = Chain.getValue(1);
+      // Instead of recording a physical register in the return-register list,
+      // we record the store size of the return value, in bits.  The ASM
+      // printer knows how to process this.
+      MFI->addRetReg(Outs[i].VT.getStoreSizeInBits());
+    }
+  } else {
+    // For SM < 2.0, we pass return values back in registers
+    SmallVector<CCValAssign, 16> RVLocs;
+    CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(),
+    getTargetMachine(), RVLocs, *DAG.getContext());
 
-  for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
+    CCInfo.AnalyzeReturn(Outs, RetCC_PTX);
 
-    CCValAssign& VA  = RVLocs[i];
+    for (unsigned i = 0, e = RVLocs.size(); i != e; ++i) {
+      CCValAssign& VA  = RVLocs[i];
 
-    assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
+      assert(VA.isRegLoc() && "CCValAssign must be RegLoc");
 
-    unsigned Reg = VA.getLocReg();
+      unsigned Reg = VA.getLocReg();
 
-    DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
+      DAG.getMachineFunction().getRegInfo().addLiveOut(Reg);
 
-    Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
+      Chain = DAG.getCopyToReg(Chain, dl, Reg, OutVals[i], Flag);
 
-    // Guarantee that all emitted copies are stuck together,
-    // avoiding something bad
-    Flag = Chain.getValue(1);
+      // Thread the glue result through every emitted copy so that they stay
+      // stuck together and nothing can be scheduled in between them
+      Flag = Chain.getValue(1);
 
-    MFI->addRetReg(Reg);
+      MFI->addRetReg(Reg);
+    }
   }
 
   if (Flag.getNode() == 0) {