AMDGPU: Check fast math flags in fadd/fsub combines

llvm-svn: 290312
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 52cc042..9a0002d 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -3871,7 +3871,11 @@
   return SDValue();
 }
 
-unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG, EVT VT) const {
+unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG,
+                                          const SDNode *N0,
+                                          const SDNode *N1) const {
+  EVT VT = N0->getValueType(0);
+
   // Only do this if we are not trying to support denormals. v_mad_f32 does not
   // support denormals ever.
   if ((VT == MVT::f32 && !Subtarget->hasFP32Denormals()) ||
@@ -3879,7 +3883,10 @@
     return ISD::FMAD;
 
   const TargetOptions &Options = DAG.getTarget().Options;
-  if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
+  if ((Options.AllowFPOpFusion == FPOpFusion::Fast ||
+       Options.UnsafeFPMath ||
+       (cast<BinaryWithFlagsSDNode>(N0)->Flags.hasUnsafeAlgebra() &&
+        cast<BinaryWithFlagsSDNode>(N1)->Flags.hasUnsafeAlgebra())) &&
       isFMAFasterThanFMulAndFAdd(VT)) {
     return ISD::FMA;
   }
@@ -3907,7 +3914,7 @@
   if (LHS.getOpcode() == ISD::FADD) {
     SDValue A = LHS.getOperand(0);
     if (A == LHS.getOperand(1)) {
-      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      unsigned FusedOp = getFusedOpcode(DAG, N, LHS.getNode());
       if (FusedOp != 0) {
         const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
         return DAG.getNode(FusedOp, SL, VT, Two, A, RHS);
@@ -3919,7 +3926,7 @@
   if (RHS.getOpcode() == ISD::FADD) {
     SDValue A = RHS.getOperand(0);
     if (A == RHS.getOperand(1)) {
-      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      unsigned FusedOp = getFusedOpcode(DAG, N, RHS.getNode());
       if (FusedOp != 0) {
         const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
         return DAG.getNode(FusedOp, SL, VT, Two, A, LHS);
@@ -3951,7 +3958,7 @@
     // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
     SDValue A = LHS.getOperand(0);
     if (A == LHS.getOperand(1)) {
-      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      unsigned FusedOp = getFusedOpcode(DAG, N, LHS.getNode());
       if (FusedOp != 0){
         const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
         SDValue NegRHS = DAG.getNode(ISD::FNEG, SL, VT, RHS);
@@ -3966,7 +3973,7 @@
 
     SDValue A = RHS.getOperand(0);
     if (A == RHS.getOperand(1)) {
-      unsigned FusedOp = getFusedOpcode(DAG, VT);
+      unsigned FusedOp = getFusedOpcode(DAG, N, RHS.getNode());
       if (FusedOp != 0){
         const SDValue NegTwo = DAG.getConstantFP(-2.0, SL, VT);
         return DAG.getNode(FusedOp, SL, VT, NegTwo, A, LHS);