[NVPTX] Added support for half-precision floating point.
Only scalar half-precision operations are supported at the moment.
- Adds general support for 'half' type in NVPTX.
- fp16 math operations are supported on sm_53+ GPUs only
(can be disabled with --nvptx-no-f16-math).
- Type conversions to/from fp16 are supported on all GPU variants.
- On GPU variants that do not have full fp16 support (or if it's disabled),
fp16 operations are promoted to fp32 and results are converted back
to fp16 for storage.
Differential Revision: https://reviews.llvm.org/D28540
llvm-svn: 291956
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 4f3129c..6548dad 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -42,7 +42,6 @@
cl::desc("NVPTX Specific: Flush f32 subnormals to sign-preserving zero."),
cl::init(false));
-
/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -520,6 +519,10 @@
case ISD::ADDRSPACECAST:
SelectAddrSpaceCast(N);
return;
+ case ISD::ConstantFP:
+ if (tryConstantFP16(N))
+ return;
+ break;
default:
break;
}
@@ -541,6 +544,19 @@
}
}
+// Select an f16 ConstantFP node as LOAD_CONST_F16. There's no way to specify
+// FP16 immediates in .f16 ops, so they are loaded into an .f16 register first.
+bool NVPTXDAGToDAGISel::tryConstantFP16(SDNode *N) {
+ if (N->getValueType(0) != MVT::f16)
+ return false; // Not an f16 constant; N is left untouched.
+ SDValue Val = CurDAG->getTargetConstantFP(
+ cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::f16); // N's value as a target constant operand.
+ SDNode *LoadConstF16 =
+ CurDAG->getMachineNode(NVPTX::LOAD_CONST_F16, SDLoc(N), MVT::f16, Val);
+ ReplaceNode(N, LoadConstF16); // The machine node takes N's place in the DAG.
+ return true; // Lets the ISD::ConstantFP case in the caller return early.
+}
+
static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();
@@ -740,7 +756,9 @@
if ((LD->getExtensionType() == ISD::SEXTLOAD))
fromType = NVPTX::PTXLdStInstCode::Signed;
else if (ScalarVT.isFloatingPoint())
- fromType = NVPTX::PTXLdStInstCode::Float;
+ // f16 uses .b16 as its storage type.
+ fromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
fromType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -766,6 +784,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_avar;
break;
@@ -794,6 +815,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_asi;
break;
@@ -823,6 +847,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari_64;
break;
@@ -846,6 +873,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_ari;
break;
@@ -875,6 +905,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg_64;
break;
@@ -898,6 +931,9 @@
case MVT::i64:
Opcode = NVPTX::LD_i64_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::LD_f16_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::LD_f32_areg;
break;
@@ -2173,7 +2209,9 @@
unsigned toTypeWidth = ScalarVT.getSizeInBits();
unsigned int toType;
if (ScalarVT.isFloatingPoint())
- toType = NVPTX::PTXLdStInstCode::Float;
+ // f16 uses .b16 as its storage type.
+ toType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+ : NVPTX::PTXLdStInstCode::Float;
else
toType = NVPTX::PTXLdStInstCode::Unsigned;
@@ -2200,6 +2238,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_avar;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_avar;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_avar;
break;
@@ -2229,6 +2270,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_asi;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_asi;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_asi;
break;
@@ -2259,6 +2303,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_ari_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_ari_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari_64;
break;
@@ -2282,6 +2329,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_ari;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_ari;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_ari;
break;
@@ -2312,6 +2362,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_areg_64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_areg_64;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg_64;
break;
@@ -2335,6 +2388,9 @@
case MVT::i64:
Opcode = NVPTX::ST_i64_areg;
break;
+ case MVT::f16:
+ Opcode = NVPTX::ST_f16_areg;
+ break;
case MVT::f32:
Opcode = NVPTX::ST_f32_areg;
break;
@@ -2786,6 +2842,9 @@
case MVT::i64:
Opc = NVPTX::LoadParamMemI64;
break;
+ case MVT::f16:
+ Opc = NVPTX::LoadParamMemF16;
+ break;
case MVT::f32:
Opc = NVPTX::LoadParamMemF32;
break;
@@ -2921,6 +2980,9 @@
case MVT::i64:
Opcode = NVPTX::StoreRetvalI64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreRetvalF16;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreRetvalF32;
break;
@@ -3054,6 +3116,9 @@
case MVT::i64:
Opcode = NVPTX::StoreParamI64;
break;
+ case MVT::f16:
+ Opcode = NVPTX::StoreParamF16;
+ break;
case MVT::f32:
Opcode = NVPTX::StoreParamF32;
break;