PTX: Use .param space for device function return values on SM 2.0+, and attempt
to fix up parameter passing on SM < 2.0
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@140309 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index f936d4b..6337ee9 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -91,9 +91,13 @@
static const char PARAM_PREFIX[] = "__param_";
static const char RETURN_PREFIX[] = "__ret_";
-static const char *getRegisterTypeName(unsigned RegNo) {
-#define TEST_REGCLS(cls, clsstr) \
- if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
+static const char *getRegisterTypeName(unsigned RegNo,
+ const MachineRegisterInfo& MRI) {
+ const TargetRegisterClass *TRC = MRI.getRegClass(RegNo);
+
+#define TEST_REGCLS(cls, clsstr) \
+ if (PTX::cls ## RegisterClass == TRC) return # clsstr;
+
TEST_REGCLS(RegPred, pred);
TEST_REGCLS(RegI16, b16);
TEST_REGCLS(RegI32, b32);
@@ -288,18 +292,18 @@
}
}
- unsigned Index = 1;
+ //unsigned Index = 1;
// Print parameter passing params
- for (PTXMachineFunctionInfo::param_iterator
- i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
- std::string def = "\t.param .b";
- def += utostr(*i);
- def += " __ret_";
- def += utostr(Index);
- Index++;
- def += ";";
- OutStreamer.EmitRawText(Twine(def));
- }
+ //for (PTXMachineFunctionInfo::param_iterator
+ // i = MFI->paramBegin(), e = MFI->paramEnd(); i != e; ++i) {
+ // std::string def = "\t.param .b";
+ // def += utostr(*i);
+ // def += " __ret_";
+ // def += utostr(Index);
+ // Index++;
+ // def += ";";
+ // OutStreamer.EmitRawText(Twine(def));
+ //}
}
void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@@ -436,7 +440,8 @@
void PTXAsmPrinter::printReturnOperand(const MachineInstr *MI, int opNum,
raw_ostream &OS, const char *Modifier) {
- OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
+ //OS << RETURN_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
+ OS << "__ret";
}
void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
@@ -559,6 +564,7 @@
const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
const bool isKernel = MFI->isKernel();
const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+ const MachineRegisterInfo& MRI = MF->getRegInfo();
std::string decl = isKernel ? ".entry" : ".func";
@@ -566,16 +572,22 @@
if (!isKernel) {
decl += " (";
- for (PTXMachineFunctionInfo::ret_iterator
- i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i;
- i != e; ++i) {
- if (i != b) {
- decl += ", ";
+ if (ST.useParamSpaceForDeviceArgs() && MFI->getRetParamSize() != 0) {
+ decl += ".param .b";
+ decl += utostr(MFI->getRetParamSize());
+ decl += " __ret";
+ } else {
+ for (PTXMachineFunctionInfo::ret_iterator
+ i = MFI->retRegBegin(), e = MFI->retRegEnd(), b = i;
+ i != e; ++i) {
+ if (i != b) {
+ decl += ", ";
+ }
+ decl += ".reg .";
+ decl += getRegisterTypeName(*i, MRI);
+ decl += " ";
+ decl += MFI->getRegisterName(*i);
}
- decl += ".reg .";
- decl += getRegisterTypeName(*i);
- decl += " ";
- decl += getRegisterName(*i);
}
decl += ")";
}
@@ -589,23 +601,32 @@
cnt = 0;
// Print parameters
- for (PTXMachineFunctionInfo::reg_iterator
- i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
- i != e; ++i) {
- if (i != b) {
- decl += ", ";
- }
- if (isKernel || ST.useParamSpaceForDeviceArgs()) {
+ if (isKernel || ST.useParamSpaceForDeviceArgs()) {
+ for (PTXMachineFunctionInfo::argparam_iterator
+ i = MFI->argParamBegin(), e = MFI->argParamEnd(), b = i;
+ i != e; ++i) {
+ if (i != b) {
+ decl += ", ";
+ }
+
decl += ".param .b";
decl += utostr(*i);
decl += " ";
decl += PARAM_PREFIX;
decl += utostr(++cnt);
- } else {
+ }
+ } else {
+ for (PTXMachineFunctionInfo::reg_iterator
+ i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
+ i != e; ++i) {
+ if (i != b) {
+ decl += ", ";
+ }
+
decl += ".reg .";
- decl += getRegisterTypeName(*i);
+ decl += getRegisterTypeName(*i, MRI);
decl += " ";
- decl += getRegisterName(*i);
+ decl += MFI->getRegisterName(*i);
}
}
decl += ")";