PTX: Add initial support for device function calls
- Calls are supported on SM 2.0+ for functions with no return values
llvm-svn: 137125
diff --git a/llvm/lib/Target/PTX/PTXISelLowering.cpp b/llvm/lib/Target/PTX/PTXISelLowering.cpp
index 6fcf710..d52aa2a 100644
--- a/llvm/lib/Target/PTX/PTXISelLowering.cpp
+++ b/llvm/lib/Target/PTX/PTXISelLowering.cpp
@@ -22,6 +22,7 @@
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
@@ -134,6 +135,8 @@
return "PTXISD::EXIT";
case PTXISD::RET:
return "PTXISD::RET";
+ case PTXISD::CALL:
+ return "PTXISD::CALL";
}
}
@@ -345,3 +348,49 @@
return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
}
}
+
+/// Lower an outgoing call into STORE_PARAM nodes followed by a PTXISD::CALL.
+///
+/// For every value in \p OutVals an index is requested from
+/// PTXMachineFunctionInfo::getNextParam() (keyed by the value's size in bits)
+/// and the value is stored through a STORE_PARAM node.  The CALL node's
+/// operands are the final chain, the callee's target global address and those
+/// indices.  \p InVals is never written, so no return values are produced.
+/// \returns the MVT::Other (chain) result of the CALL node.
+SDValue
+PTXTargetLowering::LowerCall(SDValue Chain, SDValue Callee,
+                             CallingConv::ID CallConv, bool isVarArg,
+                             bool &isTailCall,
+                             const SmallVectorImpl<ISD::OutputArg> &Outs,
+                             const SmallVectorImpl<SDValue> &OutVals,
+                             const SmallVectorImpl<ISD::InputArg> &Ins,
+                             DebugLoc dl, SelectionDAG &DAG,
+                             SmallVectorImpl<SDValue> &InVals) const {
+  MachineFunction &MF = DAG.getMachineFunction();
+  PTXMachineFunctionInfo *MFI = MF.getInfo<PTXMachineFunctionInfo>();
+  const PTXSubtarget &ST = getTargetMachine().getSubtarget<PTXSubtarget>();
+  (void)ST; // Only read by the assert below; avoids an unused warning.
+  assert(ST.callsAreHandled() && "Calls are not handled for the target device");
+
+  // The callee must be a GlobalAddressSDNode; check before building anything.
+  GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee);
+  assert(G && "Function must be a GlobalAddressSDNode");
+
+  // Operand layout: [0] chain (patched below), [1] callee, [2..] param indices.
+  // SmallVector owns the storage, so no manual new[]/delete[] is needed.
+  SmallVector<SDValue, 8> Ops;
+  Ops.push_back(Chain);
+  Ops.push_back(DAG.getTargetGlobalAddress(G->getGlobal(), dl,
+                                           getPointerTy()));
+
+  for (unsigned i = 0, e = OutVals.size(); i != e; ++i) {
+    unsigned Size = OutVals[i].getValueType().getSizeInBits();
+    SDValue Index = DAG.getTargetConstant(MFI->getNextParam(Size), MVT::i32);
+    // Thread the chain through each store so the CALL depends on all of them.
+    Chain = DAG.getNode(PTXISD::STORE_PARAM, dl, MVT::Other, Chain,
+                        Index, OutVals[i]);
+    Ops.push_back(Index);
+  }
+
+  Ops[0] = Chain;
+  return DAG.getNode(PTXISD::CALL, dl, MVT::Other, &Ops[0], Ops.size());
+}