Allow non-device function calls in PTX when natively handling device-side printf.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144388 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index 3307d91..7f55871 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
@@ -352,40 +353,101 @@
                              SmallVectorImpl<SDValue> &InVals) const {
 
   MachineFunction& MF = DAG.getMachineFunction();
-  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
-  PTXParamManager &PM = MFI->getParamManager();
-
+  PTXMachineFunctionInfo *PTXMFI = MF.getInfo<PTXMachineFunctionInfo>();
+  PTXParamManager &PM = PTXMFI->getParamManager();
+  MachineFrameInfo *MFI = MF.getFrameInfo();
+  
   assert(getTargetMachine().getSubtarget<PTXSubtarget>().callsAreHandled() &&
          "Calls are not handled for the target device");
 
+  // Identify the callee function
+  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
+  const Function *function = cast<Function>(GV);
+  
+  // allow non-device calls only for printf and puts
+  bool isPrintf = function->getName() == "printf" || function->getName() == "puts";	
+  
+  assert((isPrintf || function->getCallingConv() == CallingConv::PTX_Device) &&
+			 "PTX function calls must be to PTX device functions");
+  
+  unsigned outSize = isPrintf ? 2 : Outs.size();
+  
   std::vector<SDValue> Ops;
   // The layout of the ops will be [Chain, #Ins, Ins, Callee, #Outs, Outs]
-  Ops.resize(Outs.size() + Ins.size() + 4);
+  Ops.resize(outSize + Ins.size() + 4);
 
   Ops[0] = Chain;
 
   // Identify the callee function
-  const GlobalValue *GV = cast<GlobalAddressSDNode>(Callee)->getGlobal();
-  assert(cast<Function>(GV)->getCallingConv() == CallingConv::PTX_Device &&
-         "PTX function calls must be to PTX device functions");
   Callee = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
   Ops[Ins.size()+2] = Callee;
 
-  // Generate STORE_PARAM nodes for each function argument.  In PTX, function
-  // arguments are explicitly stored into .param variables and passed as
-  // arguments. There is no register/stack-based calling convention in PTX.
-  Ops[Ins.size()+3] = DAG.getTargetConstant(OutVals.size(), MVT::i32);
-  for (unsigned i = 0; i != OutVals.size(); ++i) {
-    unsigned Size = OutVals[i].getValueType().getSizeInBits();
-    unsigned Param = PM.addLocalParam(Size);
-    const std::string &ParamName = PM.getParamName(Param);
-    SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
-                                                     MVT::Other);
+  // #Outs
+  Ops[Ins.size()+3] = DAG.getTargetConstant(outSize, MVT::i32);
+  
+  if (isPrintf) {
+    // first argument is the address of the global string variable in memory
+    unsigned Param0 = PM.addLocalParam(getPointerTy().getSizeInBits());
+    SDValue ParamValue0 = DAG.getTargetExternalSymbol(PM.getParamName(Param0).c_str(),
+                                                      MVT::Other);
     Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
-                        ParamValue, OutVals[i]);
-    Ops[i+Ins.size()+4] = ParamValue;
-  }
+                        ParamValue0, OutVals[0]);
+    Ops[Ins.size()+4] = ParamValue0;
+      
+    // alignment is the largest size, in bits, among the arguments after the first
+    unsigned alignment = 0;
+    for (unsigned i = 1; i < OutVals.size(); ++i) {
+      alignment = std::max(alignment, 
+    		               OutVals[i].getValueType().getSizeInBits());
+    }
 
+    // size (in bits) is the alignment multiplied by the number of arguments after the first
+    unsigned size = alignment * (OutVals.size() - 1);
+  
+    // second argument is the address of the stack object (unless no arguments)
+    unsigned Param1 = PM.addLocalParam(getPointerTy().getSizeInBits());
+    SDValue ParamValue1 = DAG.getTargetExternalSymbol(PM.getParamName(Param1).c_str(),
+                                                      MVT::Other);
+    Ops[Ins.size()+5] = ParamValue1;
+    
+    if (size > 0)
+    {
+      // create a local stack object to store the arguments
+      unsigned StackObject = MFI->CreateStackObject(size / 8, alignment / 8, false);
+      SDValue FrameIndex = DAG.getFrameIndex(StackObject, getPointerTy());
+	  
+      // store each of the arguments to the stack in turn
+      for (unsigned int i = 1; i != OutVals.size(); i++) {
+        SDValue FrameAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), FrameIndex, DAG.getTargetConstant((i - 1) * 8, getPointerTy()));
+        Chain = DAG.getStore(Chain, dl, OutVals[i], FrameAddr,
+                             MachinePointerInfo(),
+                             false, false, 0);
+      }
+
+      // copy the address of the local frame index to get the address in non-local space
+      SDValue genericAddr = DAG.getNode(PTXISD::COPY_ADDRESS, dl, getPointerTy(), FrameIndex);
+
+      // store this address in the second argument
+      Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain, ParamValue1, genericAddr);
+    }
+  }
+  else
+  {
+	  // Generate STORE_PARAM nodes for each function argument.  In PTX, function
+	  // arguments are explicitly stored into .param variables and passed as
+	  // arguments. There is no register/stack-based calling convention in PTX.
+	  for (unsigned i = 0; i != OutVals.size(); ++i) {
+		unsigned Size = OutVals[i].getValueType().getSizeInBits();
+		unsigned Param = PM.addLocalParam(Size);
+		const std::string &ParamName = PM.getParamName(Param);
+		SDValue ParamValue = DAG.getTargetExternalSymbol(ParamName.c_str(),
+														 MVT::Other);
+		Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+							ParamValue, OutVals[i]);
+		Ops[i+Ins.size()+4] = ParamValue;
+	  }
+  }
+  
   std::vector<SDValue> InParams;
 
   // Generate list of .param variables to hold the return value(s).