Teach DAG combine a number of tricks to simplify FMA expressions in fast-math mode.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163051 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cd528c5..ee6c2a3 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5988,6 +5988,7 @@
   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
   EVT VT = N->getValueType(0);
+  DebugLoc dl = N->getDebugLoc();
 
   if (N0CFP && N0CFP->isExactlyValue(1.0))
     return DAG.getNode(ISD::FADD, N->getDebugLoc(), VT, N1, N2);
@@ -5998,6 +5999,58 @@
   if (N0CFP && !N1CFP)
     return DAG.getNode(ISD::FMA, N->getDebugLoc(), VT, N1, N0, N2);
 
+  // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
+  if (DAG.getTarget().Options.UnsafeFPMath && N1CFP &&
+      N2.getOpcode() == ISD::FMUL &&
+      N0 == N2.getOperand(0) &&
+      N2.getOperand(1).getOpcode() == ISD::ConstantFP) {
+    return DAG.getNode(ISD::FMUL, dl, VT, N0,
+                       DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1)));
+  }
+
+
+  // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
+  if (DAG.getTarget().Options.UnsafeFPMath &&
+      N0.getOpcode() == ISD::FMUL && N1CFP &&
+      N0.getOperand(1).getOpcode() == ISD::ConstantFP) {
+    return DAG.getNode(ISD::FMA, dl, VT,
+                       N0.getOperand(0),
+                       DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1)),
+                       N2);
+  }
+
+  // (fma x, 1, y) -> (fadd x, y)
+  // (fma x, -1, y) -> (fadd y, (fneg x))
+  if (N1CFP) {
+    if (N1CFP->isExactlyValue(1.0))
+      return DAG.getNode(ISD::FADD, dl, VT, N0, N2);
+
+    if (N1CFP->isExactlyValue(-1.0) &&
+        (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
+      SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0);
+      AddToWorkList(RHSNeg.getNode());
+      return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg);
+    }
+  }
+
+  // (fma x, c, x) -> (fmul x, (c+1))
+  if (DAG.getTarget().Options.UnsafeFPMath && N1CFP && N0 == N2) {
+    return DAG.getNode(ISD::FMUL, dl, VT,
+                       N0,
+                       DAG.getNode(ISD::FADD, dl, VT,
+                                   N1, DAG.getConstantFP(1.0, VT)));
+  }
+
+  // (fma x, c, (fneg x)) -> (fmul x, (c-1))
+  if (DAG.getTarget().Options.UnsafeFPMath && N1CFP &&
+      N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
+    return DAG.getNode(ISD::FMUL, dl, VT,
+                       N0,
+                       DAG.getNode(ISD::FADD, dl, VT,
+                                   N1, DAG.getConstantFP(-1.0, VT)));
+  }
+
+
   return SDValue();
 }
 
@@ -6367,6 +6420,17 @@
     }
   }
 
+  // (fneg (fmul x, c)) -> (fmul x, (fneg c))
+  if (N0.getOpcode() == ISD::FMUL) {
+    ConstantFPSDNode *CFP1 = dyn_cast<ConstantFPSDNode>(N0.getOperand(1));
+    if (CFP1) {
+      return DAG.getNode(ISD::FMUL, N->getDebugLoc(), VT,
+                         N0.getOperand(0),
+                         DAG.getNode(ISD::FNEG, N->getDebugLoc(), VT,
+                                     N0.getOperand(1)));
+    }
+  }
+
   return SDValue();
 }