Lower arguments and return values of PTX device functions


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@116805 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXISelLowering.cpp b/lib/Target/PTX/PTXISelLowering.cpp
index d38abf1..6e68c37 100644
--- a/lib/Target/PTX/PTXISelLowering.cpp
+++ b/lib/Target/PTX/PTXISelLowering.cpp
@@ -11,9 +11,14 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "PTX.h"
 #include "PTXISelLowering.h"
 #include "PTXRegisterInfo.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
+#include <algorithm>
 
@@ -22,7 +25,8 @@
 PTXTargetLowering::PTXTargetLowering(TargetMachine &TM)
   : TargetLowering(TM, new TargetLoweringObjectFileELF()) {
   // Set up the register classes.
-  addRegisterClass(MVT::i1, PTX::PredsRegisterClass);
+  addRegisterClass(MVT::i1,  PTX::PredsRegisterClass);   // Mirrors argmap.
+  addRegisterClass(MVT::i32, PTX::RRegs32RegisterClass); // Mirrors argmap.
 
   // Compute derived properties from the register classes
   computeRegisterProperties();
@@ -40,6 +44,79 @@
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
 
+// Argument register allocation state for LowerFormalArguments: one entry per
+// supported argument type, holding the register class the type's arguments
+// are assigned from and `loc`, the most recently allocated register of that
+// class (reset() rewinds it to the first register of the class).
+// NOTE(review): this is mutable file-scope state, so concurrent calls to
+// LowerFormalArguments would race on it.
+static struct argmap_entry {
+  MVT::SimpleValueType VT;
+  TargetRegisterClass *RC;
+  TargetRegisterClass::iterator loc;
+
+  // This runs during static initialization, where *RC may belong to another
+  // translation unit whose constructors have not run yet, so RC is not
+  // dereferenced here. LowerFormalArguments calls reset() before allocating.
+  argmap_entry(MVT::SimpleValueType _VT, TargetRegisterClass *_RC)
+    : VT(_VT), RC(_RC), loc(0) {}
+
+  void reset() { loc = RC->begin(); }
+
+  // Lets std::find look an entry up by its value type.
+  bool operator==(MVT::SimpleValueType _VT) const { return VT == _VT; }
+} argmap[] = {
+  argmap_entry(MVT::i1,  PTX::PredsRegisterClass),
+  argmap_entry(MVT::i32, PTX::RRegs32RegisterClass)
+};
+
+// Lowers one argument of a PTX_Kernel function. Not implemented yet.
+static SDValue lower_kernel_argument(int i,
+                                     SDValue Chain,
+                                     DebugLoc dl,
+                                     MVT::SimpleValueType VT,
+                                     argmap_entry *entry,
+                                     SelectionDAG &DAG,
+                                     unsigned *argreg) {
+  // TODO
+  llvm_unreachable("Not implemented yet");
+}
+
+// Lowers one argument of a PTX_Device function: takes the next unused
+// register of entry->RC, records it as a live-in paired with a fresh virtual
+// register, stores the physical register to *argreg and returns a
+// CopyFromReg of the virtual register. The argument index i is unused.
+static SDValue lower_device_argument(int i,
+                                     SDValue Chain,
+                                     DebugLoc dl,
+                                     MVT::SimpleValueType VT,
+                                     argmap_entry *entry,
+                                     SelectionDAG &DAG,
+                                     unsigned *argreg) {
+  MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+
+  // Allocation starts from register 1 (the first register of the class is
+  // never handed out). Fail loudly rather than read past the end of the
+  // class when a function has more arguments of this type than registers.
+  if (entry->RC->end() - entry->loc <= 1)
+    report_fatal_error("Too many PTX arguments of the same type");
+  unsigned preg = *++(entry->loc);
+  unsigned vreg = RegInfo.createVirtualRegister(entry->RC);
+  RegInfo.addLiveIn(preg, vreg);
+
+  *argreg = preg;
+  return DAG.getCopyFromReg(Chain, dl, vreg, VT);
+}
+
+// Signature shared by the per-calling-convention argument lowering hooks.
+typedef SDValue (*lower_argument_func)(int i,
+                                       SDValue Chain,
+                                       DebugLoc dl,
+                                       MVT::SimpleValueType VT,
+                                       argmap_entry *entry,
+                                       SelectionDAG &DAG,
+                                       unsigned *argreg);
+
 SDValue PTXTargetLowering::
   LowerFormalArguments(SDValue Chain,
                        CallingConv::ID CallConv,
@@ -48,6 +103,44 @@
                        DebugLoc dl,
                        SelectionDAG &DAG,
                        SmallVectorImpl<SDValue> &InVals) const {
+  // Varargs, other calling conventions and unsupported argument types can all
+  // come from the input IR, so they are reported with report_fatal_error;
+  // llvm_unreachable is meant for "impossible" compiler-internal states.
+  if (isVarArg)
+    report_fatal_error("PTX does not support varargs");
+
+  // Pick the argument lowering hook for this calling convention.
+  lower_argument_func lower_argument = 0;
+  switch (CallConv) {
+    case CallingConv::PTX_Kernel:
+      lower_argument = lower_kernel_argument;
+      break;
+    case CallingConv::PTX_Device:
+      lower_argument = lower_device_argument;
+      break;
+    default:
+      report_fatal_error("Unsupported calling convention");
+  }
+
+  // Reset argmap before allocation.
+  argmap_entry *const argmap_end = argmap + array_lengthof(argmap);
+  for (argmap_entry *it = argmap; it != argmap_end; ++it)
+    it->reset();
+
+  for (int i = 0, e = Ins.size(); i != e; ++i) {
+    MVT::SimpleValueType VT = Ins[i].VT.getSimpleVT().SimpleTy;
+
+    // Look up the register class that arguments of this type are placed in.
+    argmap_entry *entry = std::find(argmap, argmap_end, VT);
+    if (entry == argmap_end)
+      report_fatal_error("Type of argument is not supported");
+
+    // lower_argument also reports the physical register it assigned through
+    // its last parameter; that register is not needed here.
+    unsigned reg;
+    InVals.push_back(lower_argument(i, Chain, dl, VT, entry, DAG, &reg));
+  }
+
   return Chain;
 }
 
@@ -59,7 +148,7 @@
               const SmallVectorImpl<SDValue> &OutVals,
               DebugLoc dl,
               SelectionDAG &DAG) const {
-  assert(!isVarArg && "PTX does not support var args.");
+  if (isVarArg) report_fatal_error("PTX does not support varargs");
 
   switch (CallConv) {
     default:
@@ -74,10 +163,30 @@
 
 // PTX_Device
 
+  // A void return needs no return register.
   if (Outs.size() == 0)
     return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain);
 
-  // TODO: allocate return register
+  // Only a single i32 value can be returned for now, and it is passed back in
+  // PTX::R0. Reject anything else rather than silently dropping values.
+  if (Outs.size() != 1 || Outs[0].VT != MVT::i32)
+    report_fatal_error("Can return only basic types");
+
   SDValue Flag;
+  unsigned reg = PTX::R0;
+
+  // If this is the first return lowered for this function, add the return
+  // register to the function's live-out set.
+  MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
+  if (RegInfo.liveout_empty())
+    RegInfo.addLiveOut(reg);
+
+  // Copy the result value into the return register.
+  Chain = DAG.getCopyToReg(Chain, dl, reg, OutVals[0], Flag);
+
+  // Pass the copy's glue (flag) result to RET so the copy and the return
+  // stay together and nothing is scheduled between them.
+  Flag = Chain.getValue(1);
+
   return DAG.getNode(PTXISD::RET, dl, MVT::Other, Chain, Flag);
 }