Extend initial support for primitive types in PTX backend
- Allow i16, i32, i64, float, and double types, using the native .u16,
.u32, .u64, .f32, and .f64 PTX types.
- Allow loading/storing of all primitive types.
- Allow primitive types to be passed as parameters.
- Allow selection of PTX Version and Shader Model as sub-target attributes.
- Merge integer/floating-point test cases for load/store.
- Use .u32 instead of .s32 to conform to output from the NVIDIA nvcc compiler.
Patch by Justin Holewinski
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@126824 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index 25f26fa..35eeadc 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -24,6 +24,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/CodeGen/AsmPrinter.h"
#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/MC/MCStreamer.h"
#include "llvm/MC/MCSymbol.h"
#include "llvm/Target/Mangler.h"
@@ -37,13 +38,6 @@
using namespace llvm;
-static cl::opt<std::string>
-OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4"));
-
-static cl::opt<std::string>
-OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"),
- cl::init("sm_10"));
-
namespace {
class PTXAsmPrinter : public AsmPrinter {
public:
@@ -82,11 +76,14 @@
static const char PARAM_PREFIX[] = "__param_";
static const char *getRegisterTypeName(unsigned RegNo) {
-#define TEST_REGCLS(cls, clsstr) \
+#define TEST_REGCLS(cls, clsstr) \
if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
- TEST_REGCLS(RRegf32, f32);
- TEST_REGCLS(RRegs32, s32);
TEST_REGCLS(Preds, pred);
+ TEST_REGCLS(RRegu16, u16);
+ TEST_REGCLS(RRegu32, u32);
+ TEST_REGCLS(RRegu64, u64);
+ TEST_REGCLS(RRegf32, f32);
+ TEST_REGCLS(RRegf64, f64);
#undef TEST_REGCLS
llvm_unreachable("Not in any register class!");
@@ -121,7 +118,14 @@
switch (type->getTypeID()) {
default: llvm_unreachable("Unknown type");
case Type::FloatTyID: return ".f32";
- case Type::IntegerTyID: return ".s32"; // TODO: Handle 64-bit types.
+ case Type::DoubleTyID: return ".f64";
+ case Type::IntegerTyID:
+ switch (type->getPrimitiveSizeInBits()) {
+ default: llvm_unreachable("Unknown integer bit-width");
+ case 16: return ".u16";
+ case 32: return ".u32";
+ case 64: return ".u64";
+ }
case Type::ArrayTyID:
case Type::PointerTyID:
type = dyn_cast<const SequentialType>(type)->getElementType();
@@ -162,8 +166,11 @@
void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
{
- OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion));
- OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget));
+ const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+
+ OutStreamer.EmitRawText(Twine("\t.version " + ST.getPTXVersionString()));
+ OutStreamer.EmitRawText(Twine("\t.target " + ST.getTargetString() +
+ (ST.supportsDouble() ? "" : ", map_f64_to_f32")));
OutStreamer.AddBlankLine();
// declare global variables
@@ -236,11 +243,24 @@
break;
case MachineOperand::MO_FPImmediate:
APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt();
- if (constFP.getZExtValue() > 0) {
- OS << "0F" << constFP.toString(16, false);
+ bool isFloat = MO.getFPImm()->getType()->getTypeID() == Type::FloatTyID;
+ // Emit 0F for 32-bit floats and 0D for 64-bit doubles.
+ if (isFloat) {
+ OS << "0F";
}
else {
- OS << "0F00000000";
+ OS << "0D";
+ }
+ // Emit the encoded floating-point value.
+ if (constFP.getZExtValue() > 0) {
+ OS << constFP.toString(16, false);
+ }
+ else {
+ OS << "00000000";
+ // If we have a double-precision zero, pad to 8 bytes.
+ if (!isFloat) {
+ OS << "00000000";
+ }
}
break;
}
@@ -338,12 +358,18 @@
if (!MFI->argRegEmpty()) {
decl += " (";
if (isKernel) {
- for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
- if (i != 0)
+ unsigned cnt = 0;
+ //for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
+ for(PTXMachineFunctionInfo::reg_iterator
+ i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; i != e; ++i) {
+ reg = *i;
+ assert(reg != PTX::NoRegister && "Not a valid register!");
+ if (i != b)
decl += ", ";
- decl += ".param .s32 "; // TODO: add types
+ decl += ".param .u32"; // TODO: Parse type from register map
+ decl += " ";
decl += PARAM_PREFIX;
- decl += utostr(i + 1);
+ decl += utostr(++cnt);
}
} else {
for (PTXMachineFunctionInfo::reg_iterator