Extend initial support for primitive types in PTX backend

- Allow i16, i32, i64, float, and double types, using the native .u16,
  .u32, .u64, .f32, and .f64 PTX types.
- Allow loading/storing of all primitive types.
- Allow primitive types to be passed as parameters.
- Allow selection of PTX Version and Shader Model as sub-target attributes.
- Merge integer/floating-point test cases for load/store.
- Use .u32 instead of .s32 to conform to the output of NVIDIA's nvcc compiler.

Patch by Justin Holewinski



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@126824 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/PTX/PTXAsmPrinter.cpp b/lib/Target/PTX/PTXAsmPrinter.cpp
index 25f26fa..35eeadc 100644
--- a/lib/Target/PTX/PTXAsmPrinter.cpp
+++ b/lib/Target/PTX/PTXAsmPrinter.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/Target/Mangler.h"
@@ -37,13 +38,6 @@
 
 using namespace llvm;
 
-static cl::opt<std::string>
-OptPTXVersion("ptx-version", cl::desc("Set PTX version"), cl::init("1.4"));
-
-static cl::opt<std::string>
-OptPTXTarget("ptx-target", cl::desc("Set GPU target (comma-separated list)"),
-             cl::init("sm_10"));
-
 namespace {
 class PTXAsmPrinter : public AsmPrinter {
 public:
@@ -82,11 +76,14 @@
 static const char PARAM_PREFIX[] = "__param_";
 
 static const char *getRegisterTypeName(unsigned RegNo) {
-#define TEST_REGCLS(cls, clsstr) \
+#define TEST_REGCLS(cls, clsstr)                \
   if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
-  TEST_REGCLS(RRegf32, f32);
-  TEST_REGCLS(RRegs32, s32);
   TEST_REGCLS(Preds, pred);
+  TEST_REGCLS(RRegu16, u16);
+  TEST_REGCLS(RRegu32, u32);
+  TEST_REGCLS(RRegu64, u64);
+  TEST_REGCLS(RRegf32, f32);
+  TEST_REGCLS(RRegf64, f64);
 #undef TEST_REGCLS
 
   llvm_unreachable("Not in any register class!");
@@ -121,7 +118,14 @@
     switch (type->getTypeID()) {
       default: llvm_unreachable("Unknown type");
       case Type::FloatTyID: return ".f32";
-      case Type::IntegerTyID: return ".s32";    // TODO:  Handle 64-bit types.
+      case Type::DoubleTyID: return ".f64";
+      case Type::IntegerTyID:
+        switch (type->getPrimitiveSizeInBits()) {
+          default: llvm_unreachable("Unknown integer bit-width");
+          case 16: return ".u16";
+          case 32: return ".u32";
+          case 64: return ".u64";
+        }
       case Type::ArrayTyID:
       case Type::PointerTyID:
         type = dyn_cast<const SequentialType>(type)->getElementType();
@@ -162,8 +166,11 @@
 
 void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
 {
-  OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion));
-  OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget));
+  const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
+
+  OutStreamer.EmitRawText(Twine("\t.version " + ST.getPTXVersionString()));
+  OutStreamer.EmitRawText(Twine("\t.target " + ST.getTargetString() +
+                                (ST.supportsDouble() ? "" : ", map_f64_to_f32")));
   OutStreamer.AddBlankLine();
 
   // declare global variables
@@ -236,11 +243,24 @@
       break;
     case MachineOperand::MO_FPImmediate:
       APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt();
-      if (constFP.getZExtValue() > 0) {
-        OS << "0F" << constFP.toString(16, false);
+      bool  isFloat = MO.getFPImm()->getType()->getTypeID() == Type::FloatTyID;
+      // Emit 0F for 32-bit floats and 0D for 64-bit doubles.
+      if (isFloat) {
+        OS << "0F";
       }
       else {
-        OS << "0F00000000";
+        OS << "0D";
+      }
+      // Emit the encoded floating-point value.
+      if (constFP.getZExtValue() > 0) {
+        OS << constFP.toString(16, false);
+      }
+      else {
+        OS << "00000000";
+        // If we have a double-precision zero, pad to 8 bytes.
+        if (!isFloat) {
+          OS << "00000000";
+        }
       }
       break;
   }
@@ -338,12 +358,18 @@
   if (!MFI->argRegEmpty()) {
     decl += " (";
     if (isKernel) {
-      for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
-        if (i != 0)
+      unsigned cnt = 0;
+      //for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
+      for(PTXMachineFunctionInfo::reg_iterator
+          i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i; i != e; ++i) {
+        reg = *i;
+        assert(reg != PTX::NoRegister && "Not a valid register!");
+        if (i != b)
           decl += ", ";
-        decl += ".param .s32 "; // TODO: add types
+        decl += ".param .u32";  // TODO: Parse type from register map
+        decl += " ";
         decl += PARAM_PREFIX;
-        decl += utostr(i + 1);
+        decl += utostr(++cnt);
       }
     } else {
       for (PTXMachineFunctionInfo::reg_iterator