[NVPTX] Added support for half-precision floating point.

Only scalar half-precision operations are supported at the moment.

- Adds general support for the 'half' type in NVPTX.
- fp16 math operations are supported only on sm_53+ GPUs
  (this can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled),
  fp16 operations are promoted to fp32 and results are converted back
  to fp16 for storage.

Differential Revision: https://reviews.llvm.org/D28540

llvm-svn: 291956
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 4f3129c..6548dad 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -42,7 +42,6 @@
            cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
            cl::init(false));
 
-
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -520,6 +519,10 @@
   case ISD::ADDRSPACECAST:
     SelectAddrSpaceCast(N);
     return;
+  case ISD::ConstantFP:
+    if (tryConstantFP16(N))
+      return;
+    break;
   default:
     break;
   }
@@ -541,6 +544,19 @@
   }
 }
 
+// PTX has no syntax for FP16 immediates in .f16 instructions, so an f16
+// ConstantFP is materialized into an .f16 register via LOAD_CONST_F16.
+bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
+  if (N->getValueType(0) != MVT::f16)
+    return false; // Not f16: let the generic selector handle it.
+  SDLoc DL(N);
+  SDValue Imm = CurDAG->getTargetConstantFP(
+      cast<ConstantFPSDNode>(N)->getValueAPF(), DL, MVT::f16);
+  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, DL,
+                                        MVT::f16, Imm));
+  return true;
+}
+
 static unsigned int getCodeAddrSpace(MemSDNode *N) {
   const Value *Src = N->getMemOperand()->getValue();
 
@@ -740,7 +756,9 @@
   if ((LD->getExtensionType() == ISD::SEXTLOAD))
     fromType = NVPTX::PTXLdStInstCode::Signed;
   else if (ScalarVT.isFloatingPoint())
-    fromType = NVPTX::PTXLdStInstCode::Float;
+    // f16 uses .b16 as its storage type.
+    fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+                                             : NVPTX::PTXLdStInstCode::Float;
   else
     fromType = NVPTX::PTXLdStInstCode::Unsigned;
 
@@ -766,6 +784,9 @@
     case MVT::i64:
       Opcode = NVPTX::LD_i64_avar;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::LD_f16_avar;
+      break;
     case MVT::f32:
       Opcode = NVPTX::LD_f32_avar;
       break;
@@ -794,6 +815,9 @@
     case MVT::i64:
       Opcode = NVPTX::LD_i64_asi;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::LD_f16_asi;
+      break;
     case MVT::f32:
       Opcode = NVPTX::LD_f32_asi;
       break;
@@ -823,6 +847,9 @@
       case MVT::i64:
         Opcode = NVPTX::LD_i64_ari_64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LD_f16_ari_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_ari_64;
         break;
@@ -846,6 +873,9 @@
       case MVT::i64:
         Opcode = NVPTX::LD_i64_ari;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LD_f16_ari;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_ari;
         break;
@@ -875,6 +905,9 @@
       case MVT::i64:
         Opcode = NVPTX::LD_i64_areg_64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LD_f16_areg_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_areg_64;
         break;
@@ -898,6 +931,9 @@
       case MVT::i64:
         Opcode = NVPTX::LD_i64_areg;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LD_f16_areg;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_areg;
         break;
@@ -2173,7 +2209,9 @@
   unsigned toTypeWidth = ScalarVT.getSizeInBits();
   unsigned int toType;
   if (ScalarVT.isFloatingPoint())
-    toType = NVPTX::PTXLdStInstCode::Float;
+    // f16 uses .b16 as its storage type.
+    toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+                                           : NVPTX::PTXLdStInstCode::Float;
   else
     toType = NVPTX::PTXLdStInstCode::Unsigned;
 
@@ -2200,6 +2238,9 @@
     case MVT::i64:
       Opcode = NVPTX::ST_i64_avar;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::ST_f16_avar;
+      break;
     case MVT::f32:
       Opcode = NVPTX::ST_f32_avar;
       break;
@@ -2229,6 +2270,9 @@
     case MVT::i64:
       Opcode = NVPTX::ST_i64_asi;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::ST_f16_asi;
+      break;
     case MVT::f32:
       Opcode = NVPTX::ST_f32_asi;
       break;
@@ -2259,6 +2303,9 @@
       case MVT::i64:
         Opcode = NVPTX::ST_i64_ari_64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::ST_f16_ari_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_ari_64;
         break;
@@ -2282,6 +2329,9 @@
       case MVT::i64:
         Opcode = NVPTX::ST_i64_ari;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::ST_f16_ari;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_ari;
         break;
@@ -2312,6 +2362,9 @@
       case MVT::i64:
         Opcode = NVPTX::ST_i64_areg_64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::ST_f16_areg_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_areg_64;
         break;
@@ -2335,6 +2388,9 @@
       case MVT::i64:
         Opcode = NVPTX::ST_i64_areg;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::ST_f16_areg;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_areg;
         break;
@@ -2786,6 +2842,9 @@
     case MVT::i64:
       Opc = NVPTX::LoadParamMemI64;
       break;
+    case MVT::f16:
+      Opc = NVPTX::LoadParamMemF16;
+      break;
     case MVT::f32:
       Opc = NVPTX::LoadParamMemF32;
       break;
@@ -2921,6 +2980,9 @@
     case MVT::i64:
       Opcode = NVPTX::StoreRetvalI64;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::StoreRetvalF16;
+      break;
     case MVT::f32:
       Opcode = NVPTX::StoreRetvalF32;
       break;
@@ -3054,6 +3116,9 @@
       case MVT::i64:
         Opcode = NVPTX::StoreParamI64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::StoreParamF16;
+        break;
       case MVT::f32:
         Opcode = NVPTX::StoreParamF32;
         break;