AMDGPU: Select v_mad_u64_u32 and v_mad_i64_i32

llvm-svn: 317492
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 70e21a2..d1120f5 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -5962,18 +5962,57 @@
   return 0;
 }
 
+/// Build a MAD_I64_I32 (if \p Signed) or MAD_U64_U32 node over N0, N1, N2
+/// with result types (i64, i1), and return it truncated to \p VT.
+static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
+                           SDValue N0, SDValue N1, SDValue N2, bool Signed) {
+  unsigned Opc = Signed ? AMDGPUISD::MAD_I64_I32 : AMDGPUISD::MAD_U64_U32;
+  SDValue Wide =
+      DAG.getNode(Opc, SL, DAG.getVTList(MVT::i64, MVT::i1), N0, N1, N2);
+  return DAG.getNode(ISD::TRUNCATE, SL, VT, Wide);
+}
+
 SDValue SITargetLowering::performAddCombine(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
-
-  if (VT != MVT::i32)
-    return SDValue();
-
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
+  if ((LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL)
+      && Subtarget->hasMad64_32() &&
+      !VT.isVector() && VT.getScalarSizeInBits() > 32 &&
+      VT.getScalarSizeInBits() <= 64) {
+    if (LHS.getOpcode() != ISD::MUL)
+      std::swap(LHS, RHS);
+
+    SDValue MulLHS = LHS.getOperand(0);
+    SDValue MulRHS = LHS.getOperand(1);
+    SDValue AddRHS = RHS;
+
+    // TODO: Maybe restrict if SGPR inputs.
+    if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
+        numBitsUnsigned(MulRHS, DAG) <= 32) {
+      MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
+      MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
+      AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
+      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
+    }
+
+    if (numBitsSigned(MulLHS, DAG) < 32 && numBitsSigned(MulRHS, DAG) < 32) {
+      MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
+      MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
+      AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
+      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+    }
+
+    return SDValue();
+  }
+
+  if (VT != MVT::i32)
+    return SDValue();
+
   // add x, zext (setcc) => addcarry x, 0, setcc
   // add x, sext (setcc) => subcarry x, 0, setcc
   unsigned Opc = LHS.getOpcode();