[InstCombine] Limit FMul constant folding for fma simplifications.

As @reames pointed out post-commit, rL371518 introduces additional rounding
in some cases when constant folding the multiplication.
This breaks a guarantee llvm.fma makes and must be avoided.

This patch reapplies rL371518, but splits off the simplifications not
requiring rounding from SimplifyFMulInst as SimplifyFMAFMul.

Reviewers: spatel, lebedev.ri, reames, scanon

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D67434

llvm-svn: 372899
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 67b06ea..4ae052e 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4576,15 +4576,8 @@
   return nullptr;
 }
 
-/// Given the operands for an FMul, see if we can fold the result
-static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
-                               const SimplifyQuery &Q, unsigned MaxRecurse) {
-  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
-    return C;
-
-  if (Constant *C = simplifyFPBinop(Op0, Op1))
-    return C;
-
+static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+                              const SimplifyQuery &Q, unsigned MaxRecurse) {
   // fmul X, 1.0 ==> X
   if (match(Op1, m_FPOne()))
     return Op0;
@@ -4605,6 +4598,19 @@
   return nullptr;
 }
 
+/// Given the operands for an FMul, see if we can fold the result
+static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+                               const SimplifyQuery &Q, unsigned MaxRecurse) {
+  if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+    return C;
+
+  if (Constant *C = simplifyFPBinop(Op0, Op1))
+    return C;
+
+  // Now apply simplifications that do not require rounding.
+  return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
+}
+
 Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                               const SimplifyQuery &Q) {
   return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
@@ -4621,6 +4627,11 @@
   return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
 }
 
+Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+                             const SimplifyQuery &Q) {
+  return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit);
+}
+
 static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
                                const SimplifyQuery &Q, unsigned) {
   if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 5c0e648..51a21e3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2234,6 +2234,15 @@
       return replaceInstUsesWith(*II, Add);
     }
 
+    // Try to simplify the underlying FMul.
+    if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
+                                    II->getFastMathFlags(),
+                                    SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+      FAdd->copyFastMathFlags(II);
+      return FAdd;
+    }
+
     LLVM_FALLTHROUGH;
   }
   case Intrinsic::fma: {
@@ -2258,9 +2267,12 @@
       return II;
     }
 
-    // fma x, 1, z -> fadd x, z
-    if (match(Src1, m_FPOne())) {
-      auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
+    // Try to simplify the underlying FMul. We can only apply simplifications
+    // that do not require rounding.
+    if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
+                                   II->getFastMathFlags(),
+                                   SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
       FAdd->copyFastMathFlags(II);
       return FAdd;
     }