AMDGPU: Form more FMAs if fusion is allowed

Extend the existing fadd/fsub -> fmad combines to produce FMA when
FMAD cannot be used (e.g. denormals are enabled) and FP operation
fusion is allowed.

llvm-svn: 290311
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 1572897..52cc042 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -3871,24 +3871,31 @@
   return SDValue();
 }
 
+unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG, EVT VT) const {
+  // Use FMAD only if we are not trying to support denormals for VT; v_mad_f32
+  // never supports denormals. Otherwise try FMA; 0 means no fused opcode.
+  if ((VT == MVT::f32 && !Subtarget->hasFP32Denormals()) ||
+      (VT == MVT::f16 && !Subtarget->hasFP16Denormals()))
+    return ISD::FMAD;
+
+  const TargetOptions &Options = DAG.getTarget().Options;
+  if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
+      isFMAFasterThanFMulAndFAdd(VT)) {
+    return ISD::FMA;
+  }
+
+  return 0;
+}
+
 SDValue SITargetLowering::performFAddCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
   if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
     return SDValue();
 
-  EVT VT = N->getValueType(0);
-  if (VT == MVT::f64)
-    return SDValue();
-
-    assert(!VT.isVector());
-
-  // Only do this if we are not trying to support denormals. v_mad_f32 does
-  // not support denormals ever.
-  if ((VT == MVT::f32 && Subtarget->hasFP32Denormals()) ||
-      (VT == MVT::f16 && Subtarget->hasFP16Denormals()))
-    return SDValue();
-
   SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+  assert(!VT.isVector());
+
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
@@ -3900,8 +3907,11 @@
   if (LHS.getOpcode() == ISD::FADD) {
     SDValue A = LHS.getOperand(0);
     if (A == LHS.getOperand(1)) {
-      const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
-      return DAG.getNode(ISD::FMAD, SL, VT, Two, A, RHS);
+      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      if (FusedOp != 0) {
+        const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
+        return DAG.getNode(FusedOp, SL, VT, Two, A, RHS);
+      }
     }
   }
 
@@ -3909,8 +3919,11 @@
   if (RHS.getOpcode() == ISD::FADD) {
     SDValue A = RHS.getOperand(0);
     if (A == RHS.getOperand(1)) {
-      const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
-      return DAG.getNode(ISD::FMAD, SL, VT, Two, A, LHS);
+      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      if (FusedOp != 0) {
+        const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
+        return DAG.getNode(FusedOp, SL, VT, Two, A, LHS);
+      }
     }
   }
 
@@ -3932,29 +3945,31 @@
   //
   // Only do this if we are not trying to support denormals. v_mad_f32 does
   // not support denormals ever.
-  if ((VT == MVT::f32 && !Subtarget->hasFP32Denormals()) ||
-      (VT == MVT::f16 && !Subtarget->hasFP16Denormals())) {
-    SDValue LHS = N->getOperand(0);
-    SDValue RHS = N->getOperand(1);
-    if (LHS.getOpcode() == ISD::FADD) {
-      // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
-
-      SDValue A = LHS.getOperand(0);
-      if (A == LHS.getOperand(1)) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  if (LHS.getOpcode() == ISD::FADD) {
+    // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
+    SDValue A = LHS.getOperand(0);
+    if (A == LHS.getOperand(1)) {
+      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      if (FusedOp != 0){
         const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
         SDValue NegRHS = DAG.getNode(ISD::FNEG, SL, VT, RHS);
 
-        return DAG.getNode(ISD::FMAD, SL, VT, Two, A, NegRHS);
+        return DAG.getNode(FusedOp, SL, VT, Two, A, NegRHS);
       }
     }
+  }
 
-    if (RHS.getOpcode() == ISD::FADD) {
-      // (fsub c, (fadd a, a)) -> mad -2.0, a, c
+  if (RHS.getOpcode() == ISD::FADD) {
+    // (fsub c, (fadd a, a)) -> mad -2.0, a, c
 
-      SDValue A = RHS.getOperand(0);
-      if (A == RHS.getOperand(1)) {
+    SDValue A = RHS.getOperand(0);
+    if (A == RHS.getOperand(1)) {
+      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      if (FusedOp != 0){
         const SDValue NegTwo = DAG.getConstantFP(-2.0, SL, VT);
-        return DAG.getNode(ISD::FMAD, SL, VT, NegTwo, A, LHS);
+        return DAG.getNode(FusedOp, SL, VT, NegTwo, A, LHS);
       }
     }
   }