Add generic fmad DAG node.

This allows the FMA-forming combines to be shared with
instructions that have the same semantics as a separate
multiply and add.

The node is expanded by default and is only formed post-legalization,
so it shouldn't have much impact on targets that do not want it.

llvm-svn: 230070
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e8d1acf..53867b5 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6938,6 +6938,133 @@
   return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops);
 }
 
+/// Attempt different variants of (fadd (fmul a, b), c) -> FusedOpcode
+/// (ISD::FMA or ISD::FMAD). Unless Aggressive is set, a multiply is only fused
+/// when it has a single use. Returns a null SDValue if no pattern matched.
+static SDValue performFaddFmulCombines(unsigned FusedOpcode, bool Aggressive,
+                                       SDNode *N, const TargetLowering &TLI,
+                                       SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+
+  // fold (fadd (fmul x, y), z) -> (fma x, y, z)
+  if (N0.getOpcode() == ISD::FMUL && (Aggressive || N0->hasOneUse())) {
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       N0.getOperand(0), N0.getOperand(1), N1);
+  }
+
+  // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
+  // Note: Commutes FADD operands.
+  if (N1.getOpcode() == ISD::FMUL && (Aggressive || N1->hasOneUse())) {
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       N1.getOperand(0), N1.getOperand(1), N0);
+  }
+
+  // More folding opportunities when target permits.
+  if (!Aggressive)
+    return SDValue();
+
+  // Match the inner node against FusedOpcode, not a hard-coded ISD::FMA, so
+  // FMAD chains are recognized and an FMA is never rewritten into an FMAD.
+  // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z))
+  if (N0.getOpcode() == FusedOpcode &&
+      N0.getOperand(2).getOpcode() == ISD::FMUL) {
+    SDValue Mul = N0.getOperand(2);
+    SDValue Inner = DAG.getNode(FusedOpcode, SL, VT,
+                                Mul.getOperand(0), Mul.getOperand(1), N1);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       N0.getOperand(0), N0.getOperand(1), Inner);
+  }
+
+  // fold (fadd x, (fma y, z, (fmul u, v))) -> (fma y, z, (fma u, v, x))
+  if (N1.getOpcode() == FusedOpcode &&
+      N1.getOperand(2).getOpcode() == ISD::FMUL) {
+    SDValue Mul = N1.getOperand(2);
+    SDValue Inner = DAG.getNode(FusedOpcode, SL, VT,
+                                Mul.getOperand(0), Mul.getOperand(1), N0);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       N1.getOperand(0), N1.getOperand(1), Inner);
+  }
+
+  return SDValue();
+}
+
+/// Attempt different variants of (fsub (fmul a, b), c) -> FusedOpcode.
+///
+/// FNEG nodes are inserted on the operands that the subtraction negates.
+/// Unless Aggressive is set, a multiply (and any fneg wrapping it) is only
+/// fused when it has a single use; the nested folds that look through an
+/// existing FusedOpcode node are only tried when Aggressive is set.
+/// Returns a null SDValue if no pattern matched.
+static SDValue performFsubFmulCombines(unsigned FusedOpcode, bool Aggressive,
+                                       SDNode *N, const TargetLowering &TLI,
+                                       SelectionDAG &DAG) {
+  SDLoc SL(N);
+  EVT VT = N->getValueType(0);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+
+  // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
+  if (LHS.getOpcode() == ISD::FMUL && (Aggressive || LHS->hasOneUse())) {
+    SDValue NegZ = DAG.getNode(ISD::FNEG, SL, VT, RHS);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       LHS.getOperand(0), LHS.getOperand(1), NegZ);
+  }
+
+  // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
+  // Note: Commutes FSUB operands.
+  if (RHS.getOpcode() == ISD::FMUL && (Aggressive || RHS->hasOneUse())) {
+    SDValue NegY = DAG.getNode(ISD::FNEG, SL, VT, RHS.getOperand(0));
+    return DAG.getNode(FusedOpcode, SL, VT, NegY, RHS.getOperand(1), LHS);
+  }
+
+  // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z))
+  if (LHS.getOpcode() == ISD::FNEG &&
+      LHS.getOperand(0).getOpcode() == ISD::FMUL) {
+    SDValue Mul = LHS.getOperand(0);
+    if (Aggressive || (LHS->hasOneUse() && Mul.hasOneUse())) {
+      return DAG.getNode(FusedOpcode, SL, VT,
+                         DAG.getNode(ISD::FNEG, SL, VT, Mul.getOperand(0)),
+                         Mul.getOperand(1),
+                         DAG.getNode(ISD::FNEG, SL, VT, RHS));
+    }
+  }
+
+  // More folding opportunities when target permits.
+  if (!Aggressive)
+    return SDValue();
+
+  // fold (fsub (fma x, y, (fmul u, v)), z)
+  //   -> (fma x, y, (fma u, v, (fneg z)))
+  if (LHS.getOpcode() == FusedOpcode &&
+      LHS.getOperand(2).getOpcode() == ISD::FMUL) {
+    SDValue Mul = LHS.getOperand(2);
+    SDValue NegZ = DAG.getNode(ISD::FNEG, SL, VT, RHS);
+    SDValue Inner = DAG.getNode(FusedOpcode, SL, VT,
+                                Mul.getOperand(0), Mul.getOperand(1), NegZ);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       LHS.getOperand(0), LHS.getOperand(1), Inner);
+  }
+
+  // fold (fsub x, (fma y, z, (fmul u, v)))
+  //   -> (fma (fneg y), z, (fma (fneg u), v, x))
+  if (RHS.getOpcode() == FusedOpcode &&
+      RHS.getOperand(2).getOpcode() == ISD::FMUL) {
+    SDValue Mul = RHS.getOperand(2);
+    return DAG.getNode(FusedOpcode, SL, VT,
+                       DAG.getNode(ISD::FNEG, SL, VT, RHS.getOperand(0)),
+                       RHS.getOperand(1),
+                       DAG.getNode(FusedOpcode, SL, VT,
+                                   DAG.getNode(ISD::FNEG, SL, VT,
+                                               Mul.getOperand(0)),
+                                   Mul.getOperand(1), LHS));
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFADD(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -7077,23 +7204,27 @@
     }
   } // enable-unsafe-fp-math
 
+  if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) {
+    // Assume if there is an fmad instruction that it should be aggressively
+    // used.
+    if (SDValue Fused = performFaddFmulCombines(ISD::FMAD, true, N, TLI, DAG))
+      return Fused;
+  }
+
   // FADD -> FMA combines:
   if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
       TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) {
 
-    // fold (fadd (fmul x, y), z) -> (fma x, y, z)
-    if (N0.getOpcode() == ISD::FMUL &&
-        (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                         N0.getOperand(0), N0.getOperand(1), N1);
-
-    // fold (fadd x, (fmul y, z)) -> (fma y, z, x)
-    // Note: Commutes FADD operands.
-    if (N1.getOpcode() == ISD::FMUL &&
-        (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                         N1.getOperand(0), N1.getOperand(1), N0);
+    if (!TLI.isOperationLegal(ISD::FMAD, VT)) {
+      // Don't form FMA if we are preferring FMAD.
+      if (SDValue Fused
+          = performFaddFmulCombines(ISD::FMA,
+                                    TLI.enableAggressiveFMAFusion(VT),
+                                    N, TLI, DAG)) {
+        return Fused;
+      }
+    }
 
     // When FP_EXTEND nodes are free on the target, and there is an opportunity
     // to combine into FMA, arrange such nodes accordingly.
@@ -7122,30 +7253,6 @@
                                          N10.getOperand(1)), N0);
       }
     }
-
-    // More folding opportunities when target permits.
-    if (TLI.enableAggressiveFMAFusion(VT)) {
-
-      // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z))
-      if (N0.getOpcode() == ISD::FMA &&
-          N0.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N0.getOperand(0), N0.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N0.getOperand(2).getOperand(0),
-                                       N0.getOperand(2).getOperand(1),
-                                       N1));
-
-      // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x))
-      if (N1->getOpcode() == ISD::FMA &&
-          N1.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N1.getOperand(0), N1.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N1.getOperand(2).getOperand(0),
-                                       N1.getOperand(2).getOperand(1),
-                                       N0));
-    }
   }
 
   return SDValue();
@@ -7207,43 +7314,32 @@
     }
   }
 
+  if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) {
+    // Assume if there is an fmad instruction that it should be aggressively
+    // used.
+    if (SDValue Fused = performFsubFmulCombines(ISD::FMAD, true, N, TLI, DAG))
+      return Fused;
+  }
+
   // FSUB -> FMA combines:
   if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
       TLI.isFMAFasterThanFMulAndFAdd(VT) &&
       (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) {
 
-    // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z))
-    if (N0.getOpcode() == ISD::FMUL &&
-        (N0->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         N0.getOperand(0), N0.getOperand(1),
-                         DAG.getNode(ISD::FNEG, dl, VT, N1));
+    if (!TLI.isOperationLegal(ISD::FMAD, VT)) {
+      // Don't form FMA if we are preferring FMAD.
 
-    // fold (fsub x, (fmul y, z)) -> (fma (fneg y), z, x)
-    // Note: Commutes FSUB operands.
-    if (N1.getOpcode() == ISD::FMUL &&
-        (N1->hasOneUse() || TLI.enableAggressiveFMAFusion(VT)))
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         DAG.getNode(ISD::FNEG, dl, VT,
-                         N1.getOperand(0)),
-                         N1.getOperand(1), N0);
-
-    // fold (fsub (fneg (fmul, x, y)), z) -> (fma (fneg x), y, (fneg z))
-    if (N0.getOpcode() == ISD::FNEG &&
-        N0.getOperand(0).getOpcode() == ISD::FMUL &&
-        ((N0->hasOneUse() && N0.getOperand(0).hasOneUse()) ||
-            TLI.enableAggressiveFMAFusion(VT))) {
-      SDValue N00 = N0.getOperand(0).getOperand(0);
-      SDValue N01 = N0.getOperand(0).getOperand(1);
-      return DAG.getNode(ISD::FMA, dl, VT,
-                         DAG.getNode(ISD::FNEG, dl, VT, N00), N01,
-                         DAG.getNode(ISD::FNEG, dl, VT, N1));
+      if (SDValue Fused
+          = performFsubFmulCombines(ISD::FMA,
+                                    TLI.enableAggressiveFMAFusion(VT),
+                                    N, TLI, DAG)) {
+        return Fused;
+      }
     }
 
     // When FP_EXTEND nodes are free on the target, and there is an opportunity
     // to combine into FMA, arrange such nodes accordingly.
     if (TLI.isFPExtFree(VT)) {
-
       // fold (fsub (fpext (fmul x, y)), z)
       //   -> (fma (fpext x), (fpext y), (fneg z))
       if (N0.getOpcode() == ISD::FP_EXTEND) {
@@ -7308,38 +7404,6 @@
         }
       }
     }
-
-    // More folding opportunities when target permits.
-    if (TLI.enableAggressiveFMAFusion(VT)) {
-
-      // fold (fsub (fma x, y, (fmul u, v)), z)
-      //   -> (fma x, y (fma u, v, (fneg z)))
-      if (N0.getOpcode() == ISD::FMA &&
-          N0.getOperand(2).getOpcode() == ISD::FMUL)
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           N0.getOperand(0), N0.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       N0.getOperand(2).getOperand(0),
-                                       N0.getOperand(2).getOperand(1),
-                                       DAG.getNode(ISD::FNEG, SDLoc(N), VT,
-                                                   N1)));
-
-      // fold (fsub x, (fma y, z, (fmul u, v)))
-      //   -> (fma (fneg y), z, (fma (fneg u), v, x))
-      if (N1.getOpcode() == ISD::FMA &&
-          N1.getOperand(2).getOpcode() == ISD::FMUL) {
-        SDValue N20 = N1.getOperand(2).getOperand(0);
-        SDValue N21 = N1.getOperand(2).getOperand(1);
-        return DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                           DAG.getNode(ISD::FNEG, SDLoc(N), VT,
-                                       N1.getOperand(0)),
-                           N1.getOperand(1),
-                           DAG.getNode(ISD::FMA, SDLoc(N), VT,
-                                       DAG.getNode(ISD::FNEG, SDLoc(N),  VT,
-                                                   N20),
-                                       N21, N0));
-      }
-    }
   }
 
   return SDValue();