AMDGPU: Custom lower v4i16/v4f16 vector operations

Avoids stack access.

Also handle the extract-high-element pattern from truncate + shift
to avoid a couple of test regressions.

llvm-svn: 332453
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index d00727b..9885546 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -3144,6 +3144,28 @@
     }
   }
 
+  // Equivalent of above for accessing the high element of a vector as an
+  // integer operation.
+  // trunc (srl (bitcast (build_vector x, y)), 16) -> trunc (bitcast y)
+  if (Src.getOpcode() == ISD::SRL) {
+    if (auto K = isConstOrConstSplat(Src.getOperand(1))) {
+      if (2 * K->getZExtValue() == Src.getValueType().getScalarSizeInBits()) {
+        SDValue BV = stripBitcast(Src.getOperand(0));
+        if (BV.getOpcode() == ISD::BUILD_VECTOR &&
+            BV.getValueType().getVectorNumElements() == 2) {
+          SDValue SrcElt = BV.getOperand(1);
+          EVT SrcEltVT = SrcElt.getValueType();
+          if (SrcEltVT.isFloatingPoint()) {
+            SrcElt = DAG.getNode(ISD::BITCAST, SL,
+                                 SrcEltVT.changeTypeToInteger(), SrcElt);
+          }
+
+          return DAG.getNode(ISD::TRUNCATE, SL, VT, SrcElt);
+        }
+      }
+    }
+  }
+
   // Partially shrink 64-bit shifts to 32-bit if reduced to 16-bit.
   //
   // i16 (trunc (srl i64:x, K)), K <= 16 ->
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 6db8339..e8052de 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -137,6 +137,10 @@
     return false;
   }
 
+  static inline SDValue stripBitcast(SDValue Val) {
+    return Val.getOpcode() == ISD::BITCAST ? Val.getOperand(0) : Val;
+  }
+
   static bool allUsesHaveSourceMods(const SDNode *N,
                                     unsigned CostThreshold = 4);
   bool isFAbsFree(EVT VT) const override;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index d34249e..f21dd62 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -288,13 +288,24 @@
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v16i32, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v16f32, Expand);
 
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v4f16, Custom);
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v4i16, Custom);
+
   // Avoid stack access for these.
   // TODO: Generalize to more vector types.
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2f16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f16, Custom);
+
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2i16, Custom);
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f16, Custom);
 
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i16, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4f16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4f16, Custom);
+
   // BUFFER/FLAT_ATOMIC_CMP_SWAP on GCN GPUs needs input marshalling,
   // and output demarshalling
   setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, Custom);
@@ -3333,6 +3344,8 @@
     return lowerINSERT_VECTOR_ELT(Op, DAG);
   case ISD::EXTRACT_VECTOR_ELT:
     return lowerEXTRACT_VECTOR_ELT(Op, DAG);
+  case ISD::BUILD_VECTOR:
+    return lowerBUILD_VECTOR(Op, DAG);
   case ISD::FP_ROUND:
     return lowerFP_ROUND(Op, DAG);
   case ISD::TRAP:
@@ -4157,34 +4170,72 @@
 
 SDValue SITargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
                                                  SelectionDAG &DAG) const {
+  SDValue Vec = Op.getOperand(0);
+  SDValue InsVal = Op.getOperand(1);
   SDValue Idx = Op.getOperand(2);
+  EVT VecVT = Vec.getValueType();
+
+  assert(VecVT.getScalarSizeInBits() == 16);
+
+  unsigned NumElts = VecVT.getVectorNumElements();
+  SDLoc SL(Op);
+  auto KIdx = dyn_cast<ConstantSDNode>(Idx);
+
+  if (NumElts == 4 && KIdx) {
+    SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, MVT::v2i32, Vec);
+
+    SDValue LoHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, BCVec,
+                                 DAG.getConstant(0, SL, MVT::i32));
+    SDValue HiHalf = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, MVT::i32, BCVec,
+                                 DAG.getConstant(1, SL, MVT::i32));
+
+    SDValue LoVec = DAG.getNode(ISD::BITCAST, SL, MVT::v2i16, LoHalf);
+    SDValue HiVec = DAG.getNode(ISD::BITCAST, SL, MVT::v2i16, HiHalf);
+
+    unsigned Idx = KIdx->getZExtValue();
+    bool InsertLo = Idx < 2;
+    SDValue InsHalf = DAG.getNode(ISD::INSERT_VECTOR_ELT, SL, MVT::v2i16,
+      InsertLo ? LoVec : HiVec,
+      DAG.getNode(ISD::BITCAST, SL, MVT::i16, InsVal),
+      DAG.getConstant(InsertLo ? Idx : (Idx - 2), SL, MVT::i32));
+
+    InsHalf = DAG.getNode(ISD::BITCAST, SL, MVT::i32, InsHalf);
+
+    SDValue Concat = InsertLo ?
+      DAG.getBuildVector(MVT::v2i32, SL, { InsHalf, HiHalf }) :
+      DAG.getBuildVector(MVT::v2i32, SL, { LoHalf, InsHalf });
+
+    return DAG.getNode(ISD::BITCAST, SL, VecVT, Concat);
+  }
+
+  assert(NumElts == 2 || NumElts == 4);
+
   if (isa<ConstantSDNode>(Idx))
     return SDValue();
 
+  EVT IntVT = NumElts == 2 ? MVT::i32 : MVT::i64;
+
   // Avoid stack access for dynamic indexing.
-  SDLoc SL(Op);
-  SDValue Vec = Op.getOperand(0);
-  SDValue Val = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Op.getOperand(1));
+  SDValue Val = DAG.getNode(ISD::BITCAST, SL, MVT::i16, InsVal);
 
   // v_bfi_b32 (v_bfm_b32 16, (shl idx, 16)), val, vec
-  SDValue ExtVal = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Val);
+  SDValue ExtVal = DAG.getNode(ISD::ZERO_EXTEND, SL, IntVT, Val);
 
   // Convert vector index to bit-index.
   SDValue ScaledIdx = DAG.getNode(ISD::SHL, SL, MVT::i32, Idx,
                                   DAG.getConstant(4, SL, MVT::i32));
 
-  SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
-
-  SDValue BFM = DAG.getNode(ISD::SHL, SL, MVT::i32,
-                            DAG.getConstant(0xffff, SL, MVT::i32),
+  SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, IntVT, Vec);
+  SDValue BFM = DAG.getNode(ISD::SHL, SL, IntVT,
+                            DAG.getConstant(0xffff, SL, IntVT),
                             ScaledIdx);
 
-  SDValue LHS = DAG.getNode(ISD::AND, SL, MVT::i32, BFM, ExtVal);
-  SDValue RHS = DAG.getNode(ISD::AND, SL, MVT::i32,
-                            DAG.getNOT(SL, BFM, MVT::i32), BCVec);
+  SDValue LHS = DAG.getNode(ISD::AND, SL, IntVT, BFM, ExtVal);
+  SDValue RHS = DAG.getNode(ISD::AND, SL, IntVT,
+                            DAG.getNOT(SL, BFM, IntVT), BCVec);
 
-  SDValue BFI = DAG.getNode(ISD::OR, SL, MVT::i32, LHS, RHS);
-  return DAG.getNode(ISD::BITCAST, SL, Op.getValueType(), BFI);
+  SDValue BFI = DAG.getNode(ISD::OR, SL, IntVT, LHS, RHS);
+  return DAG.getNode(ISD::BITCAST, SL, VecVT, BFI);
 }
 
 SDValue SITargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -4194,6 +4245,9 @@
   EVT ResultVT = Op.getValueType();
   SDValue Vec = Op.getOperand(0);
   SDValue Idx = Op.getOperand(1);
+  EVT VecVT = Vec.getValueType();
+  unsigned NumElts = VecVT.getVectorNumElements();
+  assert(VecVT.getScalarSizeInBits() == 16 && (NumElts == 2 || NumElts == 4));
 
   DAGCombinerInfo DCI(DAG, AfterLegalizeVectorOps, true, nullptr);
 
@@ -4204,19 +4258,43 @@
   if (SDValue Combined = performExtractVectorEltCombine(Op.getNode(), DCI))
     return Combined;
 
+  EVT IntVT = NumElts == 2 ? MVT::i32 : MVT::i64;
   SDValue Four = DAG.getConstant(4, SL, MVT::i32);
 
   // Convert vector index to bit-index (* 16)
   SDValue ScaledIdx = DAG.getNode(ISD::SHL, SL, MVT::i32, Idx, Four);
 
-  SDValue BC = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Vec);
-  SDValue Elt = DAG.getNode(ISD::SRL, SL, MVT::i32, BC, ScaledIdx);
+  SDValue BC = DAG.getNode(ISD::BITCAST, SL, IntVT, Vec);
+  SDValue Elt = DAG.getNode(ISD::SRL, SL, IntVT, BC, ScaledIdx);
 
-  SDValue Result = Elt;
-  if (ResultVT.bitsLT(MVT::i32))
-    Result = DAG.getNode(ISD::TRUNCATE, SL, MVT::i16, Result);
+  if (ResultVT == MVT::f16) {
+    SDValue Result = DAG.getNode(ISD::TRUNCATE, SL, MVT::i16, Elt);
+    return DAG.getNode(ISD::BITCAST, SL, ResultVT, Result);
+  }
 
-  return DAG.getNode(ISD::BITCAST, SL, ResultVT, Result);
+  return DAG.getAnyExtOrTrunc(Elt, SL, ResultVT);
+}
+
+SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc SL(Op);
+  EVT VT = Op.getValueType();
+  assert(VT == MVT::v4i16 || VT == MVT::v4f16);
+
+  EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
+
+  // Turn into a pair of packed build_vectors.
+  // TODO: Special case for constants that can be materialized with s_mov_b64.
+  SDValue Lo = DAG.getBuildVector(HalfVT, SL,
+                                  { Op.getOperand(0), Op.getOperand(1) });
+  SDValue Hi = DAG.getBuildVector(HalfVT, SL,
+                                  { Op.getOperand(2), Op.getOperand(3) });
+
+  SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Lo);
+  SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, MVT::i32, Hi);
+
+  SDValue Blend = DAG.getBuildVector(MVT::v2i32, SL, { CastLo, CastHi });
+  return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
 }
 
 bool
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
index fba383d..3a99994 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -84,6 +84,7 @@
   SDValue lowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerTRAP(SDValue Op, SelectionDAG &DAG) const;
 
   SDNode *adjustWritemask(MachineSDNode *&N, SelectionDAG &DAG) const;