PTX: Generalize handling of .param types

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140375 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index 2d7756e..7996728 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -199,6 +199,7 @@
   MachineFunction &MF = DAG.getMachineFunction();
   const PTXSubtarget& ST = getTargetMachine().getSubtarget<PTXSubtarget>();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = MFI->getParamManager();
 
   switch (CallConv) {
     default:
@@ -221,8 +222,10 @@
       assert((!MFI->isKernel() || Ins[i].VT != MVT::i1) &&
              "Kernels cannot take pred operands");
 
+      unsigned ParamSize = Ins[i].VT.getStoreSizeInBits();
+      unsigned Param = PM.addArgumentParam(ParamSize);
       SDValue ArgValue = DAG.getNode(PTXISD::LOAD_PARAM, dl, Ins[i].VT, Chain,
-                                     DAG.getTargetConstant(i, MVT::i32));
+                                     DAG.getTargetConstant(Param, MVT::i32));
       InVals.push_back(ArgValue);
 
       // Instead of storing a physical register in our argument list, we just
@@ -322,6 +325,7 @@
 
   MachineFunction& MF = DAG.getMachineFunction();
   PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = MFI->getParamManager();
 
   SDValue Flag;
 
@@ -336,13 +340,15 @@
     assert(Outs.size() < 2 && "Device functions can return at most one value");
 
     if (Outs.size() == 1) {
-      unsigned Size = OutVals[0].getValueType().getSizeInBits();
-      SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+      unsigned ParamSize = OutVals[0].getValueType().getSizeInBits();
+      unsigned Param = PM.addReturnParam(ParamSize);
+      SDValue ParamIndex = DAG.getTargetConstant(Param, MVT::i32);
       Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
-                          Index, OutVals[0]);
+                          ParamIndex, OutVals[0]);
+
 
       //Flag = Chain.getValue(1);
-      MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
+      //MFI->setRetParamSize(Outs[0].VT.getStoreSizeInBits());
     }
   } else {
     //SmallVector<CCValAssign, 16> RVLocs;