AMDGPU: Make v4i16/v4f16 legal

Some image loads return these types, and it's awkward to work
around them not being legal.
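
For illustration only (not part of this patch): a v4f16 value occupies
a 64-bit register pair, so fneg/fabs on it amount to the same
per-element sign-bit manipulation the existing v2f16 lowering performs
with XOR/AND on i32, applied to both halves. A minimal standalone C++
sketch of the combined effect, with hypothetical function names:

  #include <cstdint>

  // Flip the sign bit of each of the four packed f16 elements; this is
  // the combined effect of lowering fneg.v4f16 as two v2f16 XORs with
  // 0x80008000.
  uint64_t fnegV4F16Bits(uint64_t Packed) {
    return Packed ^ 0x8000800080008000ULL;
  }

  // Clear the sign bit of each packed f16 element, matching the fabs
  // lowering's AND with 0x7fff7fff on each 32-bit half.
  uint64_t fabsV4F16Bits(uint64_t Packed) {
    return Packed & 0x7fff7fff7fff7fffULL;
  }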

llvm-svn: 334835
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index df396e6..1b91d74 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -143,6 +143,8 @@
    // Unless there are also VOP3P operations, not all operations are legal.
     addRegisterClass(MVT::v2i16, &AMDGPU::SReg_32_XM0RegClass);
     addRegisterClass(MVT::v2f16, &AMDGPU::SReg_32_XM0RegClass);
+    addRegisterClass(MVT::v4i16, &AMDGPU::SReg_64RegClass);
+    addRegisterClass(MVT::v4f16, &AMDGPU::SReg_64RegClass);
   }
 
   computeRegisterProperties(STI.getRegisterInfo());
@@ -237,7 +239,7 @@
  // We only support LOAD/STORE and vector manipulation ops for vectors
  // with > 4 elements, plus the 64-bit and packed 16-bit types listed here.
   for (MVT VT : {MVT::v8i32, MVT::v8f32, MVT::v16i32, MVT::v16f32,
-        MVT::v2i64, MVT::v2f64}) {
+        MVT::v2i64, MVT::v2f64, MVT::v4i16, MVT::v4f16}) {
     for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
       switch (Op) {
       case ISD::LOAD:
@@ -260,6 +262,8 @@
     }
   }
 
+  setOperationAction(ISD::FP_EXTEND, MVT::v4f32, Expand);
+
   // TODO: For dynamic 64-bit vector inserts/extracts, should emit a pseudo that
   // is expanded to avoid having two separate loops in case the index is a VGPR.
 
@@ -426,7 +430,7 @@
     if (!Subtarget->hasFP16Denormals())
       setOperationAction(ISD::FMAD, MVT::f16, Legal);
 
-    for (MVT VT : {MVT::v2i16, MVT::v2f16}) {
+    for (MVT VT : {MVT::v2i16, MVT::v2f16, MVT::v4i16, MVT::v4f16}) {
       for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
         switch (Op) {
         case ISD::LOAD:
@@ -488,6 +492,10 @@
     setOperationAction(ISD::SIGN_EXTEND, MVT::v2i32, Expand);
     setOperationAction(ISD::FP_EXTEND, MVT::v2f32, Expand);
 
+    setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Expand);
+    setOperationAction(ISD::ZERO_EXTEND, MVT::v4i32, Expand);
+    setOperationAction(ISD::SIGN_EXTEND, MVT::v4i32, Expand);
+
     if (!Subtarget->hasVOP3PInsts()) {
       setOperationAction(ISD::BUILD_VECTOR, MVT::v2i16, Custom);
       setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -520,8 +528,31 @@
 
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
     setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
+
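+    // No packed (VOP3P) instructions exist for these ops on 4-element
+    // vectors; LowerOperation splits them into two packed v2 halves
+    // (SELECT is instead handled as a generic 64-bit select).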
+    setOperationAction(ISD::SHL, MVT::v4i16, Custom);
+    setOperationAction(ISD::SRA, MVT::v4i16, Custom);
+    setOperationAction(ISD::SRL, MVT::v4i16, Custom);
+    setOperationAction(ISD::ADD, MVT::v4i16, Custom);
+    setOperationAction(ISD::SUB, MVT::v4i16, Custom);
+    setOperationAction(ISD::MUL, MVT::v4i16, Custom);
+
+    setOperationAction(ISD::SMIN, MVT::v4i16, Custom);
+    setOperationAction(ISD::SMAX, MVT::v4i16, Custom);
+    setOperationAction(ISD::UMIN, MVT::v4i16, Custom);
+    setOperationAction(ISD::UMAX, MVT::v4i16, Custom);
+
+    setOperationAction(ISD::FADD, MVT::v4f16, Custom);
+    setOperationAction(ISD::FMUL, MVT::v4f16, Custom);
+    setOperationAction(ISD::FMINNUM, MVT::v4f16, Custom);
+    setOperationAction(ISD::FMAXNUM, MVT::v4f16, Custom);
+
+    setOperationAction(ISD::SELECT, MVT::v4i16, Custom);
+    setOperationAction(ISD::SELECT, MVT::v4f16, Custom);
   }
 
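+  // fneg and fabs of v4f16 are split into v2f16 halves, which in turn
+  // lower to XOR/AND of the sign bits on the packed i32 representation.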
+  setOperationAction(ISD::FNEG, MVT::v4f16, Custom);
+  setOperationAction(ISD::FABS, MVT::v4f16, Custom);
+
   if (Subtarget->has16BitInsts()) {
     setOperationAction(ISD::SELECT, MVT::v2i16, Promote);
     AddPromotedToType(ISD::SELECT, MVT::v2i16, MVT::i32);
@@ -3383,6 +3414,49 @@
 // Custom DAG Lowering Operations
 //===----------------------------------------------------------------------===//
 
+// Work around LegalizeDAG doing the wrong thing and fully scalarizing if the
+// wider vector type is legal.
+SDValue SITargetLowering::splitUnaryVectorOp(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  unsigned Opc = Op.getOpcode();
+  EVT VT = Op.getValueType();
+  assert(VT == MVT::v4f16);
+
+  SDValue Lo, Hi;
+  std::tie(Lo, Hi) = DAG.SplitVectorOperand(Op.getNode(), 0);
+
+  SDLoc SL(Op);
+  SDValue OpLo = DAG.getNode(Opc, SL, Lo.getValueType(), Lo,
+                             Op->getFlags());
+  SDValue OpHi = DAG.getNode(Opc, SL, Hi.getValueType(), Hi,
+                             Op->getFlags());
+
+  return DAG.getNode(ISD::CONCAT_VECTORS, SL, VT, OpLo, OpHi);
+}
+
+// Work around LegalizeDAG doing the wrong thing and fully scalarizing if the
+// wider vector type is legal.
+SDValue SITargetLowering::splitBinaryVectorOp(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  unsigned Opc = Op.getOpcode();
+  EVT VT = Op.getValueType();
+  assert(VT == MVT::v4i16 || VT == MVT::v4f16);
+
+  SDValue Lo0, Hi0;
+  std::tie(Lo0, Hi0) = DAG.SplitVectorOperand(Op.getNode(), 0);
+  SDValue Lo1, Hi1;
+  std::tie(Lo1, Hi1) = DAG.SplitVectorOperand(Op.getNode(), 1);
+
+  SDLoc SL(Op);
+
+  SDValue OpLo = DAG.getNode(Opc, SL, Lo0.getValueType(), Lo0, Lo1,
+                             Op->getFlags());
+  SDValue OpHi = DAG.getNode(Opc, SL, Hi0.getValueType(), Hi0, Hi1,
+                             Op->getFlags());
+
+  return DAG.getNode(ISD::CONCAT_VECTORS, SL, VT, OpLo, OpHi);
+}
+
 SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
   default: return AMDGPUTargetLowering::LowerOperation(Op, DAG);
@@ -3423,6 +3497,24 @@
     return lowerTRAP(Op, DAG);
   case ISD::DEBUGTRAP:
     return lowerDEBUGTRAP(Op, DAG);
+  case ISD::FABS:
+  case ISD::FNEG:
+    return splitUnaryVectorOp(Op, DAG);
+  case ISD::SHL:
+  case ISD::SRA:
+  case ISD::SRL:
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::MUL:
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::UMIN:
+  case ISD::UMAX:
+  case ISD::FMINNUM:
+  case ISD::FMAXNUM:
+  case ISD::FADD:
+  case ISD::FMUL:
+    return splitBinaryVectorOp(Op, DAG);
   }
   return SDValue();
 }
@@ -3630,21 +3722,23 @@
   bool Unpacked = Subtarget->hasUnpackedD16VMem();
   EVT LoadVT = M->getValueType(0);
 
-  EVT UnpackedLoadVT = LoadVT.isVector() ?
-    EVT::getVectorVT(*DAG.getContext(), MVT::i32,
-                     LoadVT.getVectorNumElements()) : LoadVT;
   EVT EquivLoadVT = LoadVT;
-  if (LoadVT.isVector()) {
-    EquivLoadVT = Unpacked ? UnpackedLoadVT :
-                  getEquivalentMemType(*DAG.getContext(), LoadVT);
+  if (Unpacked && LoadVT.isVector()) {
+    EquivLoadVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32,
+                                   LoadVT.getVectorNumElements());
   }
 
   // Change from v4f16/v2f16 to EquivLoadVT.
   SDVTList VTList = DAG.getVTList(EquivLoadVT, MVT::Other);
 
-  SDValue Load = DAG.getMemIntrinsicNode(
-                    IsIntrinsic ? (unsigned)ISD::INTRINSIC_W_CHAIN : Opcode,
-                    DL, VTList, Ops, M->getMemoryVT(), M->getMemOperand());
+  SDValue Load = DAG.getMemIntrinsicNode(
+      IsIntrinsic ? (unsigned)ISD::INTRINSIC_W_CHAIN : Opcode, DL,
+      VTList, Ops, M->getMemoryVT(), M->getMemOperand());
+  if (!Unpacked) // Just adjusted the opcode.
+    return Load;
 
   SDValue Adjusted = adjustLoadValueTypeImpl(Load, LoadVT, DL, DAG, Unpacked);
 
@@ -3734,8 +3828,10 @@
     return;
   }
   case ISD::FNEG: {
+    if (N->getValueType(0) != MVT::v2f16)
+      break;
+
     SDLoc SL(N);
-    assert(N->getValueType(0) == MVT::v2f16);
     SDValue BC = DAG.getNode(ISD::BITCAST, SL, MVT::i32, N->getOperand(0));
 
     SDValue Op = DAG.getNode(ISD::XOR, SL, MVT::i32,
@@ -3745,8 +3841,10 @@
     return;
   }
   case ISD::FABS: {
+    if (N->getValueType(0) != MVT::v2f16)
+      break;
+
     SDLoc SL(N);
-    assert(N->getValueType(0) == MVT::v2f16);
     SDValue BC = DAG.getNode(ISD::BITCAST, SL, MVT::i32, N->getOperand(0));
 
     SDValue Op = DAG.getNode(ISD::AND, SL, MVT::i32,
@@ -4247,6 +4345,23 @@
   SDLoc SL(Op);
   EVT VT = Op.getValueType();
 
+  if (VT == MVT::v4i16 || VT == MVT::v4f16) {
+    EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
+
+    // Turn into pair of packed build_vectors.
+    // TODO: Special case for constants that can be materialized with s_mov_b64.
+    SDValue Lo = DAG.getBuildVector(HalfVT, SL,
+                                    { Op.getOperand(0), Op.getOperand(1) });
+    SDValue Hi = DAG.getBuildVector(HalfVT, SL,
+                                    { Op.getOperand(2), Op.getOperand(3) });
+
+    SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Lo);
+    SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Hi);
+
+    SDValue Blend = DAG.getBuildVector(MVT::v2i32, SL, { CastLo, CastHi });
+    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
+  }
+
   assert(VT == MVT::v2f16 || VT == MVT::v2i16);
 
   SDValue Lo = Op.getOperand(0);
@@ -4913,11 +5028,11 @@
 
   case Intrinsic::amdgcn_image_load:
   case Intrinsic::amdgcn_image_load_mip: {
-    EVT LoadVT = Op.getValueType();
-    if ((Subtarget->hasUnpackedD16VMem() && LoadVT == MVT::v2f16) ||
-        LoadVT == MVT::v4f16) {
-      MemSDNode *M = cast<MemSDNode>(Op);
-      return adjustLoadValueType(getImageOpcode(IntrID), M, DAG);
+    EVT VT = Op.getValueType();
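+    // With unpacked D16 memory instructions, each 16-bit element of the
+    // result comes back in its own 32-bit register and must be repacked.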
+    if (Subtarget->hasUnpackedD16VMem() &&
+        VT.isVector() && VT.getScalarSizeInBits() == 16) {
+      return adjustLoadValueType(getImageOpcode(IntrID), cast<MemSDNode>(Op),
+                                 DAG);
     }
 
     return SDValue();
@@ -5009,8 +5124,9 @@
       return DAG.getMergeValues({ Undef, Op.getOperand(0) }, SDLoc(Op));
     }
 
-    if ((Subtarget->hasUnpackedD16VMem() && Op.getValueType() == MVT::v2f16) ||
-        Op.getValueType() == MVT::v4f16) {
+    if (Subtarget->hasUnpackedD16VMem() &&
+        Op.getValueType().isVector() &&
+        Op.getValueType().getScalarSizeInBits() == 16) {
       return adjustLoadValueType(getImageOpcode(IntrID), cast<MemSDNode>(Op),
                                  DAG);
     }
@@ -5018,21 +5134,14 @@
     return SDValue();
   }
   default:
-    EVT LoadVT = Op.getValueType();
-    if (LoadVT.getScalarSizeInBits() != 16)
-      return SDValue();
-
-    const AMDGPU::D16ImageDimIntrinsic *D16ImageDimIntr =
-      AMDGPU::lookupD16ImageDimIntrinsicByIntr(IntrID);
-    if (D16ImageDimIntr) {
-      bool Unpacked = Subtarget->hasUnpackedD16VMem();
-      MemSDNode *M = cast<MemSDNode>(Op);
-
-      if (isTypeLegal(LoadVT) && (!Unpacked || LoadVT == MVT::f16))
-        return SDValue();
-
-      return adjustLoadValueType(D16ImageDimIntr->D16HelperIntr,
-                                 M, DAG, true);
+    if (Subtarget->hasUnpackedD16VMem() &&
+        Op.getValueType().isVector() &&
+        Op.getValueType().getScalarSizeInBits() == 16) {
+      if (const AMDGPU::D16ImageDimIntrinsic *D16ImageDimIntr =
+            AMDGPU::lookupD16ImageDimIntrinsicByIntr(IntrID)) {
+        return adjustLoadValueType(D16ImageDimIntr->D16HelperIntr,
+                                   cast<MemSDNode>(Op), DAG, true);
+      }
     }
 
     return SDValue();
@@ -5061,13 +5170,8 @@
     return DAG.UnrollVectorOp(ZExt.getNode());
   }
 
-  if (isTypeLegal(StoreVT))
-    return VData;
-
-  // If target supports packed vmem, we just need to workaround
-  // the illegal type by casting to an equivalent one.
-  EVT EquivStoreVT = getEquivalentMemType(*DAG.getContext(), StoreVT);
-  return DAG.getNode(ISD::BITCAST, DL, EquivStoreVT, VData);
+  assert(isTypeLegal(StoreVT));
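+  // Now that v2f16/v4f16 are legal on packed-D16 targets, no bitcast to an
+  // equivalent type is needed; the data is already in a legal type.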
+  return VData;
 }
 
 SDValue SITargetLowering::LowerINTRINSIC_VOID(SDValue Op,
@@ -5261,9 +5365,9 @@
   case Intrinsic::amdgcn_image_store:
   case Intrinsic::amdgcn_image_store_mip: {
     SDValue VData = Op.getOperand(2);
-    if ((Subtarget->hasUnpackedD16VMem() &&
-         VData.getValueType() == MVT::v2f16) ||
-        VData.getValueType() == MVT::v4f16) {
+    EVT VT = VData.getValueType();
+    if (Subtarget->hasUnpackedD16VMem() &&
+        VT.isVector() && VT.getScalarSizeInBits() == 16) {
       SDValue Chain = Op.getOperand(0);
 
       VData = handleD16VData(VData, DAG);
@@ -5293,9 +5397,9 @@
     if (D16ImageDimIntr) {
       SDValue VData = Op.getOperand(2);
       EVT StoreVT = VData.getValueType();
-      if (((StoreVT == MVT::v2f16 || StoreVT == MVT::v4f16) &&
-           Subtarget->hasUnpackedD16VMem()) ||
-          !isTypeLegal(StoreVT)) {
+      if (Subtarget->hasUnpackedD16VMem() &&
+          StoreVT.isVector() &&
+          StoreVT.getScalarSizeInBits() == 16) {
         SmallVector<SDValue, 12> Ops(Op.getNode()->op_values());
 
         Ops[1] = DAG.getConstant(D16ImageDimIntr->D16HelperIntr, DL, MVT::i32);
@@ -5521,8 +5625,8 @@
 }
 
 SDValue SITargetLowering::LowerSELECT(SDValue Op, SelectionDAG &DAG) const {
-  if (Op.getValueType() != MVT::i64)
-    return SDValue();
+  EVT VT = Op.getValueType();
+  assert(VT.getSizeInBits() == 64);
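+  // Any 64-bit value type (i64, and now v4i16/v4f16) is selected as two
+  // 32-bit halves through a v2i32 bitcast.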
 
   SDLoc DL(Op);
   SDValue Cond = Op.getOperand(0);
@@ -5544,7 +5648,7 @@
   SDValue Hi = DAG.getSelect(DL, MVT::i32, Cond, Hi0, Hi1);
 
   SDValue Res = DAG.getBuildVector(MVT::v2i32, DL, {Lo, Hi});
-  return DAG.getNode(ISD::BITCAST, DL, MVT::i64, Res);
+  return DAG.getNode(ISD::BITCAST, DL, VT, Res);
 }
 
 // Catch division cases where we can use shortcuts with rcp and rsq