[NVPTX] Added support for .f16x2 instructions.

This patch enables support for .f16x2 operations:

Added a new register type, Float16x2.
Added support for .f16x2 instructions.
Added handling of vectorized loads and stores of v2f16 values.
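
For illustration, a minimal CUDA sketch of the kind of source this lowering
targets (the kernel below is made up for this example; __half2 and __hadd2
are the standard CUDA fp16 type and intrinsic). Assuming an sm_53 or newer
target, the packed addition should now be selectable as a single add.f16x2
instruction, and the v2f16 accesses become plain 32-bit ld.b32/st.b32:

  #include <cuda_fp16.h>

  // Each thread adds one pair of half values, packed as __half2 (v2f16).
  __global__ void vadd_h2(const __half2 *a, const __half2 *b,
                          __half2 *c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
      c[i] = __hadd2(a[i], b[i]); // ld.b32 + add.f16x2 + st.b32
  }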

Differential Revision: https://reviews.llvm.org/D30057
Differential Revision: https://reviews.llvm.org/D30310

llvm-svn: 296032
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 2aef67b..7da621c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -84,6 +84,14 @@
     if (tryStore(N))
       return;
     break;
+  case ISD::EXTRACT_VECTOR_ELT:
+    if (tryEXTRACT_VECTOR_ELEMENT(N))
+      return;
+    break;
+  case NVPTXISD::SETP_F16X2:
+    SelectSETP_F16X2(N);
+    return;
+
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
     if (tryLoadVector(N))
@@ -516,6 +524,127 @@
   return true;
 }
 
+// Map an ISD::CondCode value to the corresponding CmpMode expected by
+// NVPTXInstPrinter::printCmpMode()
+static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
+  using NVPTX::PTXCmpMode::CmpMode;
+  unsigned PTXCmpMode = [](ISD::CondCode CC) {
+    switch (CC) {
+    default:
+      llvm_unreachable("Unexpected condition code.");
+    case ISD::SETOEQ:
+      return CmpMode::EQ;
+    case ISD::SETOGT:
+      return CmpMode::GT;
+    case ISD::SETOGE:
+      return CmpMode::GE;
+    case ISD::SETOLT:
+      return CmpMode::LT;
+    case ISD::SETOLE:
+      return CmpMode::LE;
+    case ISD::SETONE:
+      return CmpMode::NE;
+    case ISD::SETO:
+      return CmpMode::NUM;
+    case ISD::SETUO:
+      return CmpMode::NotANumber;
+    case ISD::SETUEQ:
+      return CmpMode::EQU;
+    case ISD::SETUGT:
+      return CmpMode::GTU;
+    case ISD::SETUGE:
+      return CmpMode::GEU;
+    case ISD::SETULT:
+      return CmpMode::LTU;
+    case ISD::SETULE:
+      return CmpMode::LEU;
+    case ISD::SETUNE:
+      return CmpMode::NEU;
+    case ISD::SETEQ:
+      return CmpMode::EQ;
+    case ISD::SETGT:
+      return CmpMode::GT;
+    case ISD::SETGE:
+      return CmpMode::GE;
+    case ISD::SETLT:
+      return CmpMode::LT;
+    case ISD::SETLE:
+      return CmpMode::LE;
+    case ISD::SETNE:
+      return CmpMode::NE;
+    }
+  }(CondCode.get());
+
+  if (FTZ)
+    PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
+
+  return PTXCmpMode;
+}
+
+bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
+  unsigned PTXCmpMode =
+      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+  SDLoc DL(N);
+  SDNode *SetP = CurDAG->getMachineNode(
+      NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
+      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+  ReplaceNode(N, SetP);
+  return true;
+}
+
+// Find all instances of extract_vector_elt that use this v2f16 vector
+// and coalesce them into a scattering move instruction.
+bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
+  SDValue Vector = N->getOperand(0);
+
+  // We only care about f16x2 as it's the only real vector type we
+  // need to deal with.
+  if (Vector.getSimpleValueType() != MVT::v2f16)
+    return false;
+
+  // Find and record all uses of this vector that extract element 0 or 1.
+  SmallVector<SDNode *, 4> E0, E1;
+  for (const auto &U : Vector.getNode()->uses()) {
+    if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
+      continue;
+    if (U->getOperand(0) != Vector)
+      continue;
+    if (const ConstantSDNode *IdxConst =
+            dyn_cast<ConstantSDNode>(U->getOperand(1))) {
+      if (IdxConst->getZExtValue() == 0)
+        E0.push_back(U);
+      else if (IdxConst->getZExtValue() == 1)
+        E1.push_back(U);
+      else
+        llvm_unreachable("Invalid vector index.");
+    }
+  }
+
+  // There's no point scattering f16x2 if we only ever access one
+  // element of it.
+  if (E0.empty() || E1.empty())
+    return false;
+
+  unsigned Op = NVPTX::SplitF16x2;
+  // If the vector has been BITCAST'ed from i32, we can use the original
+  // value directly and avoid a register-to-register move.
+  SDValue Source = Vector;
+  if (Vector->getOpcode() == ISD::BITCAST) {
+    Op = NVPTX::SplitI32toF16x2;
+    Source = Vector->getOperand(0);
+  }
+  // Merge (f16 extractelt(V, 0)) and (f16 extractelt(V, 1)) into
+  // a single (f16, f16) SplitF16x2(V) node.
+  SDNode *ScatterOp =
+      CurDAG->getMachineNode(Op, SDLoc(N), MVT::f16, MVT::f16, Source);
+  for (auto *Node : E0)
+    ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
+  for (auto *Node : E1)
+    ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1));
+
+  return true;
+}
+
 static unsigned int getCodeAddrSpace(MemSDNode *N) {
   const Value *Src = N->getMemOperand()->getValue();
 
@@ -689,29 +818,26 @@
       codeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
     isVolatile = false;
 
-  // Vector Setting
-  MVT SimpleVT = LoadedVT.getSimpleVT();
-  unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
-  if (SimpleVT.isVector()) {
-    unsigned num = SimpleVT.getVectorNumElements();
-    if (num == 2)
-      vecType = NVPTX::PTXLdStInstCode::V2;
-    else if (num == 4)
-      vecType = NVPTX::PTXLdStInstCode::V4;
-    else
-      return false;
-  }
-
   // Type Setting: fromType + fromTypeWidth
   //
   // Sign   : ISD::SEXTLOAD
   // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
   //          type is integer
   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
+  MVT SimpleVT = LoadedVT.getSimpleVT();
   MVT ScalarVT = SimpleVT.getScalarType();
   // Read at least 8 bits (predicates are stored as 8-bit values)
   unsigned fromTypeWidth = std::max(8U, ScalarVT.getSizeInBits());
   unsigned int fromType;
+
+  // Vector Setting
+  unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
+  if (SimpleVT.isVector()) {
+    assert(LoadedVT == MVT::v2f16 && "Unexpected vector type");
+    // v2f16 is loaded using ld.b32
+    fromTypeWidth = 32;
+  }
+
   if ((LD->getExtensionType() == ISD::SEXTLOAD))
     fromType = NVPTX::PTXLdStInstCode::Signed;
   else if (ScalarVT.isFloatingPoint())
@@ -746,6 +872,9 @@
     case MVT::f16:
       Opcode = NVPTX::LD_f16_avar;
       break;
+    case MVT::v2f16:
+      Opcode = NVPTX::LD_f16x2_avar;
+      break;
     case MVT::f32:
       Opcode = NVPTX::LD_f32_avar;
       break;
@@ -777,6 +906,9 @@
     case MVT::f16:
       Opcode = NVPTX::LD_f16_asi;
       break;
+    case MVT::v2f16:
+      Opcode = NVPTX::LD_f16x2_asi;
+      break;
     case MVT::f32:
       Opcode = NVPTX::LD_f32_asi;
       break;
@@ -809,6 +941,9 @@
       case MVT::f16:
         Opcode = NVPTX::LD_f16_ari_64;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::LD_f16x2_ari_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_ari_64;
         break;
@@ -835,6 +970,9 @@
       case MVT::f16:
         Opcode = NVPTX::LD_f16_ari;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::LD_f16x2_ari;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_ari;
         break;
@@ -867,6 +1005,9 @@
       case MVT::f16:
         Opcode = NVPTX::LD_f16_areg_64;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::LD_f16x2_areg_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_areg_64;
         break;
@@ -893,6 +1034,9 @@
       case MVT::f16:
         Opcode = NVPTX::LD_f16_areg;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::LD_f16x2_areg;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LD_f32_areg;
         break;
@@ -968,7 +1112,8 @@
   if (ExtensionType == ISD::SEXTLOAD)
     FromType = NVPTX::PTXLdStInstCode::Signed;
   else if (ScalarVT.isFloatingPoint())
-    FromType = NVPTX::PTXLdStInstCode::Float;
+    FromType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+                                             : NVPTX::PTXLdStInstCode::Float;
   else
     FromType = NVPTX::PTXLdStInstCode::Unsigned;
 
@@ -987,6 +1132,16 @@
 
   EVT EltVT = N->getValueType(0);
 
+  // v8f16 is a special case. PTX doesn't have an ld.v8.f16
+  // instruction. Instead, we split the vector into v2f16 chunks and
+  // load them with ld.v4.b32.
+  if (EltVT == MVT::v2f16) {
+    assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
+    EltVT = MVT::i32;
+    FromType = NVPTX::PTXLdStInstCode::Untyped;
+    FromTypeWidth = 32;
+  }
+
   if (SelectDirectAddr(Op1, Addr)) {
     switch (N->getOpcode()) {
     default:
@@ -1007,6 +1162,9 @@
       case MVT::i64:
         Opcode = NVPTX::LDV_i64_v2_avar;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LDV_f16_v2_avar;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LDV_f32_v2_avar;
         break;
@@ -1028,6 +1186,9 @@
       case MVT::i32:
         Opcode = NVPTX::LDV_i32_v4_avar;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LDV_f16_v4_avar;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LDV_f32_v4_avar;
         break;
@@ -1060,6 +1221,9 @@
       case MVT::i64:
         Opcode = NVPTX::LDV_i64_v2_asi;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LDV_f16_v2_asi;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LDV_f32_v2_asi;
         break;
@@ -1081,6 +1245,9 @@
       case MVT::i32:
         Opcode = NVPTX::LDV_i32_v4_asi;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::LDV_f16_v4_asi;
+        break;
       case MVT::f32:
         Opcode = NVPTX::LDV_f32_v4_asi;
         break;
@@ -1114,6 +1281,9 @@
         case MVT::i64:
           Opcode = NVPTX::LDV_i64_v2_ari_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v2_ari_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v2_ari_64;
           break;
@@ -1135,6 +1305,9 @@
         case MVT::i32:
           Opcode = NVPTX::LDV_i32_v4_ari_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v4_ari_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v4_ari_64;
           break;
@@ -1161,6 +1334,9 @@
         case MVT::i64:
           Opcode = NVPTX::LDV_i64_v2_ari;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v2_ari;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v2_ari;
           break;
@@ -1182,6 +1358,9 @@
         case MVT::i32:
           Opcode = NVPTX::LDV_i32_v4_ari;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v4_ari;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v4_ari;
           break;
@@ -1216,6 +1395,9 @@
         case MVT::i64:
           Opcode = NVPTX::LDV_i64_v2_areg_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v2_areg_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v2_areg_64;
           break;
@@ -1237,6 +1419,9 @@
         case MVT::i32:
           Opcode = NVPTX::LDV_i32_v4_areg_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v4_areg_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v4_areg_64;
           break;
@@ -1263,6 +1448,9 @@
         case MVT::i64:
           Opcode = NVPTX::LDV_i64_v2_areg;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v2_areg;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v2_areg;
           break;
@@ -1284,6 +1472,9 @@
         case MVT::i32:
           Opcode = NVPTX::LDV_i32_v4_areg;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::LDV_f16_v4_areg;
+          break;
         case MVT::f32:
           Opcode = NVPTX::LDV_f32_v4_areg;
           break;
@@ -2151,21 +2342,18 @@
   // Vector Setting
   MVT SimpleVT = StoreVT.getSimpleVT();
   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
-  if (SimpleVT.isVector()) {
-    unsigned num = SimpleVT.getVectorNumElements();
-    if (num == 2)
-      vecType = NVPTX::PTXLdStInstCode::V2;
-    else if (num == 4)
-      vecType = NVPTX::PTXLdStInstCode::V4;
-    else
-      return false;
-  }
 
   // Type Setting: toType + toTypeWidth
   // - for integer type, always use 'u'
   //
   MVT ScalarVT = SimpleVT.getScalarType();
   unsigned toTypeWidth = ScalarVT.getSizeInBits();
+  if (SimpleVT.isVector()) {
+    assert(StoreVT == MVT::v2f16 && "Unexpected vector type");
+    // v2f16 is stored using st.b32
+    toTypeWidth = 32;
+  }
+
   unsigned int toType;
   if (ScalarVT.isFloatingPoint())
     // f16 uses .b16 as its storage type.
@@ -2200,6 +2388,9 @@
     case MVT::f16:
       Opcode = NVPTX::ST_f16_avar;
       break;
+    case MVT::v2f16:
+      Opcode = NVPTX::ST_f16x2_avar;
+      break;
     case MVT::f32:
       Opcode = NVPTX::ST_f32_avar;
       break;
@@ -2232,6 +2423,9 @@
     case MVT::f16:
       Opcode = NVPTX::ST_f16_asi;
       break;
+    case MVT::v2f16:
+      Opcode = NVPTX::ST_f16x2_asi;
+      break;
     case MVT::f32:
       Opcode = NVPTX::ST_f32_asi;
       break;
@@ -2265,6 +2459,9 @@
       case MVT::f16:
         Opcode = NVPTX::ST_f16_ari_64;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::ST_f16x2_ari_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_ari_64;
         break;
@@ -2291,6 +2488,9 @@
       case MVT::f16:
         Opcode = NVPTX::ST_f16_ari;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::ST_f16x2_ari;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_ari;
         break;
@@ -2324,6 +2524,9 @@
       case MVT::f16:
         Opcode = NVPTX::ST_f16_areg_64;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::ST_f16x2_areg_64;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_areg_64;
         break;
@@ -2350,6 +2553,9 @@
       case MVT::f16:
         Opcode = NVPTX::ST_f16_areg;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::ST_f16x2_areg;
+        break;
       case MVT::f32:
         Opcode = NVPTX::ST_f32_areg;
         break;
@@ -2411,7 +2617,8 @@
   unsigned ToTypeWidth = ScalarVT.getSizeInBits();
   unsigned ToType;
   if (ScalarVT.isFloatingPoint())
-    ToType = NVPTX::PTXLdStInstCode::Float;
+    ToType = ScalarVT.SimpleTy == MVT::f16 ? NVPTX::PTXLdStInstCode::Untyped
+                                           : NVPTX::PTXLdStInstCode::Float;
   else
     ToType = NVPTX::PTXLdStInstCode::Unsigned;
 
@@ -2438,6 +2645,16 @@
     return false;
   }
 
+  // v8f16 is a special case. PTX doesn't have an st.v8.f16
+  // instruction. Instead, we split the vector into v2f16 chunks and
+  // store them with st.v4.b32.
+  if (EltVT == MVT::v2f16) {
+    assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected store opcode.");
+    EltVT = MVT::i32;
+    ToType = NVPTX::PTXLdStInstCode::Untyped;
+    ToTypeWidth = 32;
+  }
+
   StOps.push_back(getI32Imm(IsVolatile, DL));
   StOps.push_back(getI32Imm(CodeAddrSpace, DL));
   StOps.push_back(getI32Imm(VecType, DL));
@@ -2464,6 +2681,9 @@
       case MVT::i64:
         Opcode = NVPTX::STV_i64_v2_avar;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::STV_f16_v2_avar;
+        break;
       case MVT::f32:
         Opcode = NVPTX::STV_f32_v2_avar;
         break;
@@ -2513,6 +2733,9 @@
       case MVT::i64:
         Opcode = NVPTX::STV_i64_v2_asi;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::STV_f16_v2_asi;
+        break;
       case MVT::f32:
         Opcode = NVPTX::STV_f32_v2_asi;
         break;
@@ -2534,6 +2757,9 @@
       case MVT::i32:
         Opcode = NVPTX::STV_i32_v4_asi;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::STV_f16_v4_asi;
+        break;
       case MVT::f32:
         Opcode = NVPTX::STV_f32_v4_asi;
         break;
@@ -2564,6 +2790,9 @@
         case MVT::i64:
           Opcode = NVPTX::STV_i64_v2_ari_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v2_ari_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v2_ari_64;
           break;
@@ -2585,6 +2814,9 @@
         case MVT::i32:
           Opcode = NVPTX::STV_i32_v4_ari_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v4_ari_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v4_ari_64;
           break;
@@ -2611,6 +2843,9 @@
         case MVT::i64:
           Opcode = NVPTX::STV_i64_v2_ari;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v2_ari;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v2_ari;
           break;
@@ -2632,6 +2867,9 @@
         case MVT::i32:
           Opcode = NVPTX::STV_i32_v4_ari;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v4_ari;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v4_ari;
           break;
@@ -2662,6 +2900,9 @@
         case MVT::i64:
           Opcode = NVPTX::STV_i64_v2_areg_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v2_areg_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v2_areg_64;
           break;
@@ -2683,6 +2924,9 @@
         case MVT::i32:
           Opcode = NVPTX::STV_i32_v4_areg_64;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v4_areg_64;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v4_areg_64;
           break;
@@ -2709,6 +2953,9 @@
         case MVT::i64:
           Opcode = NVPTX::STV_i64_v2_areg;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v2_areg;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v2_areg;
           break;
@@ -2730,6 +2977,9 @@
         case MVT::i32:
           Opcode = NVPTX::STV_i32_v4_areg;
           break;
+        case MVT::f16:
+          Opcode = NVPTX::STV_f16_v4_areg;
+          break;
         case MVT::f32:
           Opcode = NVPTX::STV_f32_v4_areg;
           break;
@@ -2804,6 +3054,9 @@
     case MVT::f16:
       Opc = NVPTX::LoadParamMemF16;
       break;
+    case MVT::v2f16:
+      Opc = NVPTX::LoadParamMemF16x2;
+      break;
     case MVT::f32:
       Opc = NVPTX::LoadParamMemF32;
       break;
@@ -2831,6 +3084,12 @@
     case MVT::i64:
       Opc = NVPTX::LoadParamMemV2I64;
       break;
+    case MVT::f16:
+      Opc = NVPTX::LoadParamMemV2F16;
+      break;
+    case MVT::v2f16:
+      Opc = NVPTX::LoadParamMemV2F16x2;
+      break;
     case MVT::f32:
       Opc = NVPTX::LoadParamMemV2F32;
       break;
@@ -2855,6 +3114,12 @@
     case MVT::i32:
       Opc = NVPTX::LoadParamMemV4I32;
       break;
+    case MVT::f16:
+      Opc = NVPTX::LoadParamMemV4F16;
+      break;
+    case MVT::v2f16:
+      Opc = NVPTX::LoadParamMemV4F16x2;
+      break;
     case MVT::f32:
       Opc = NVPTX::LoadParamMemV4F32;
       break;
@@ -2942,6 +3207,9 @@
     case MVT::f16:
       Opcode = NVPTX::StoreRetvalF16;
       break;
+    case MVT::v2f16:
+      Opcode = NVPTX::StoreRetvalF16x2;
+      break;
     case MVT::f32:
       Opcode = NVPTX::StoreRetvalF32;
       break;
@@ -2969,6 +3237,12 @@
     case MVT::i64:
       Opcode = NVPTX::StoreRetvalV2I64;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::StoreRetvalV2F16;
+      break;
+    case MVT::v2f16:
+      Opcode = NVPTX::StoreRetvalV2F16x2;
+      break;
     case MVT::f32:
       Opcode = NVPTX::StoreRetvalV2F32;
       break;
@@ -2993,6 +3267,12 @@
     case MVT::i32:
       Opcode = NVPTX::StoreRetvalV4I32;
       break;
+    case MVT::f16:
+      Opcode = NVPTX::StoreRetvalV4F16;
+      break;
+    case MVT::v2f16:
+      Opcode = NVPTX::StoreRetvalV4F16x2;
+      break;
     case MVT::f32:
       Opcode = NVPTX::StoreRetvalV4F32;
       break;
@@ -3000,8 +3280,7 @@
     break;
   }
 
-  SDNode *Ret =
-      CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops);
+  SDNode *Ret = CurDAG->getMachineNode(Opcode, DL, MVT::Other, Ops);
   MachineSDNode::mmo_iterator MemRefs0 = MF->allocateMemRefsArray(1);
   MemRefs0[0] = cast<MemSDNode>(N)->getMemOperand();
   cast<MachineSDNode>(Ret)->setMemRefs(MemRefs0, MemRefs0 + 1);
@@ -3078,6 +3357,9 @@
       case MVT::f16:
         Opcode = NVPTX::StoreParamF16;
         break;
+      case MVT::v2f16:
+        Opcode = NVPTX::StoreParamF16x2;
+        break;
       case MVT::f32:
         Opcode = NVPTX::StoreParamF32;
         break;
@@ -3105,6 +3387,12 @@
       case MVT::i64:
         Opcode = NVPTX::StoreParamV2I64;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::StoreParamV2F16;
+        break;
+      case MVT::v2f16:
+        Opcode = NVPTX::StoreParamV2F16x2;
+        break;
       case MVT::f32:
         Opcode = NVPTX::StoreParamV2F32;
         break;
@@ -3129,6 +3417,12 @@
       case MVT::i32:
         Opcode = NVPTX::StoreParamV4I32;
         break;
+      case MVT::f16:
+        Opcode = NVPTX::StoreParamV4F16;
+        break;
+      case MVT::v2f16:
+        Opcode = NVPTX::StoreParamV4F16x2;
+        break;
       case MVT::f32:
         Opcode = NVPTX::StoreParamV4F32;
         break;