AMDGPU: Select v_mad_u64_u32 and v_mad_i64_i32

llvm-svn: 317492
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
index c313e4a..f04efd7 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp
@@ -204,6 +204,7 @@
   void SelectADD_SUB_I64(SDNode *N);
   void SelectUADDO_USUBO(SDNode *N);
   void SelectDIV_SCALE(SDNode *N);
+  void SelectMAD_64_32(SDNode *N);
   void SelectFMA_W_CHAIN(SDNode *N);
   void SelectFMUL_W_CHAIN(SDNode *N);
 
@@ -594,6 +595,11 @@
     SelectDIV_SCALE(N);
     return;
   }
+  case AMDGPUISD::MAD_I64_I32:
+  case AMDGPUISD::MAD_U64_U32: {
+    SelectMAD_64_32(N);
+    return;
+  }
   case ISD::CopyToReg: {
     const SITargetLowering& Lowering =
       *static_cast<const SITargetLowering*>(getTargetLowering());
@@ -814,6 +820,19 @@
   CurDAG->SelectNodeTo(N, Opc, N->getVTList(), Ops);
 }
 
+// Selects MAD_I64_I32 / MAD_U64_U32 to V_MAD_I64_I32 / V_MAD_U64_U32 by hand:
+// tablegen doesn't support matching instructions with multiple outputs.
+void AMDGPUDAGToDAGISel::SelectMAD_64_32(SDNode *N) {
+  SDLoc SL(N);
+  bool Signed = N->getOpcode() == AMDGPUISD::MAD_I64_I32;
+  unsigned Opc = Signed ? AMDGPU::V_MAD_I64_I32 : AMDGPU::V_MAD_U64_U32;
+
+  SDValue Clamp = CurDAG->getTargetConstant(0, SL, MVT::i1); // always 0
+  SDValue Ops[] = { N->getOperand(0), N->getOperand(1), N->getOperand(2),
+                    Clamp };
+  CurDAG->SelectNodeTo(N, Opc, N->getVTList(), Ops);
+}
+
 bool AMDGPUDAGToDAGISel::isDSOffsetLegal(const SDValue &Base, unsigned Offset,
                                          unsigned OffsetBits) const {
   if ((OffsetBits == 16 && !isUInt<16>(Offset)) ||
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index fe2c933..af22d52 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -151,6 +151,22 @@
   return false;
 }
 
+unsigned AMDGPUTargetLowering::numBitsUnsigned(SDValue Op, SelectionDAG &DAG) {
+  KnownBits Known;
+  EVT VT = Op.getValueType();
+  DAG.computeKnownBits(Op, Known);
+  // Upper bound on significant bits: width minus known-zero leading bits.
+  return VT.getSizeInBits() - Known.countMinLeadingZeros();
+}
+
+unsigned AMDGPUTargetLowering::numBitsSigned(SDValue Op, SelectionDAG &DAG) {
+  EVT VT = Op.getValueType();
+
+  // Bits left after dropping every known copy of the sign bit (the sign bit
+  // included), so Op fits in a signed integer of (result + 1) bits.
+  return VT.getSizeInBits() - DAG.ComputeNumSignBits(Op);
+}
+
 AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM,
                                            const AMDGPUSubtarget &STI)
     : TargetLowering(TM), Subtarget(&STI) {
@@ -2615,21 +2631,14 @@
 //===----------------------------------------------------------------------===//
 
 static bool isU24(SDValue Op, SelectionDAG &DAG) {
-  KnownBits Known;
-  EVT VT = Op.getValueType();
-  DAG.computeKnownBits(Op, Known);
-
-  return (VT.getSizeInBits() - Known.countMinLeadingZeros()) <= 24;
+  return AMDGPUTargetLowering::numBitsUnsigned(Op, DAG) <= 24;
 }
 
 static bool isI24(SDValue Op, SelectionDAG &DAG) {
   EVT VT = Op.getValueType();
-
-  // In order for this to be a signed 24-bit value, bit 23, must
-  // be a sign bit.
   return VT.getSizeInBits() >= 24 && // Types less than 24-bit should be treated
                                      // as unsigned 24-bit values.
-         (VT.getSizeInBits() - DAG.ComputeNumSignBits(Op)) < 24;
+    AMDGPUTargetLowering::numBitsSigned(Op, DAG) < 24;
 }
 
 static bool simplifyI24(SDNode *Node24, unsigned OpIdx,
@@ -3946,6 +3955,8 @@
   NODE_NAME_CASE(MUL_LOHI_I24)
   NODE_NAME_CASE(MAD_U24)
   NODE_NAME_CASE(MAD_I24)
+  NODE_NAME_CASE(MAD_I64_I32)
+  NODE_NAME_CASE(MAD_U64_U32)
   NODE_NAME_CASE(TEXTURE_FETCH)
   NODE_NAME_CASE(EXPORT)
   NODE_NAME_CASE(EXPORT_DONE)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index cdb1518..dd3cc0a 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -36,6 +36,8 @@
 
 public:
   static bool isOrEquivalentToAdd(SelectionDAG &DAG, SDValue Op);
+  static unsigned numBitsUnsigned(SDValue Op, SelectionDAG &DAG);
+  static unsigned numBitsSigned(SDValue Op, SelectionDAG &DAG);
 
 protected:
   const AMDGPUSubtarget *Subtarget;
@@ -379,6 +381,8 @@
   MULHI_I24,
   MAD_U24,
   MAD_I24,
+  MAD_U64_U32,
+  MAD_I64_I32,
   MUL_LOHI_I24,
   MUL_LOHI_U24,
   TEXTURE_FETCH,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
index 56a5fa6..6ee529c 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUSubtarget.h
@@ -462,6 +462,10 @@
     return isAmdHsaOS() || isMesaKernel(MF);
   }
 
+  bool hasMad64_32() const {
+    return getGeneration() >= SEA_ISLANDS; // gates the 64-bit MAD add combine
+  }
+
   bool hasFminFmaxLegacy() const {
     return getGeneration() < AMDGPUSubtarget::VOLCANIC_ISLANDS;
   }
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 70e21a2..d1120f5 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -5962,18 +5962,57 @@
   return 0;
 }
 
+static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL,
+                           EVT VT,
+                           SDValue N0, SDValue N1, SDValue N2,
+                           bool Signed) {
+  unsigned MadOpc = Signed ? AMDGPUISD::MAD_I64_I32 : AMDGPUISD::MAD_U64_U32;
+  SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i1); // two results: i64, i1
+  SDValue Mad = DAG.getNode(MadOpc, SL, VTs, N0, N1, N2);
+  return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad); // narrow the MAD to VT
+}
+
 SDValue SITargetLowering::performAddCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
-
-  if (VT != MVT::i32)
-    return SDValue();
-
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
+  if ((LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL)
+      && Subtarget->hasMad64_32() &&
+      !VT.isVector() && VT.getScalarSizeInBits() > 32 &&
+      VT.getScalarSizeInBits() <= 64) {
+    if (LHS.getOpcode() != ISD::MUL)
+      std::swap(LHS, RHS);
+
+    SDValue MulLHS = LHS.getOperand(0);
+    SDValue MulRHS = LHS.getOperand(1);
+    SDValue AddRHS = RHS;
+
+    // TODO: Maybe restrict if SGPR inputs.
+    if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
+        numBitsUnsigned(MulRHS, DAG) <= 32) {
+      MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
+      MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
+      AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
+      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
+    }
+
+    if (numBitsSigned(MulLHS, DAG) < 32 && numBitsSigned(MulRHS, DAG) < 32) {
+      MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
+      MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
+      AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
+      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+    }
+
+    return SDValue();
+  }
+
+  if (VT != MVT::i32)
+    return SDValue();
+
   // add x, zext (setcc) => addcarry x, 0, setcc
   // add x, sext (setcc) => subcarry x, 0, setcc
   unsigned Opc = LHS.getOpcode();