[InstCombine] allow fmul-sqrt folds with less than full -ffast-math

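Also, add an IRBuilder method, CreateIntrinsic(), that creates a call to an intrinsic and optionally copies the fast-math flags from a source instruction, to reduce code duplication for clients.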

llvm-svn: 325960
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 0085b82..c9eef2d 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -59,8 +59,11 @@
 
 static CallInst *createCallHelper(Value *Callee, ArrayRef<Value *> Ops,
                                   IRBuilderBase *Builder,
-                                  const Twine& Name="") {
+                                  const Twine &Name = "",
+                                  Instruction *FMFSource = nullptr) {
   CallInst *CI = CallInst::Create(Callee, Ops, Name);
+  if (FMFSource)
+    CI->copyFastMathFlags(FMFSource);
   Builder->GetInsertBlock()->getInstList().insert(Builder->GetInsertPoint(),CI);
   Builder->SetInstDebugLocation(CI);
   return CI;  
@@ -646,7 +649,18 @@
 CallInst *IRBuilderBase::CreateBinaryIntrinsic(Intrinsic::ID ID,
                                                Value *LHS, Value *RHS,
                                                const Twine &Name) {
-  Module *M = BB->getParent()->getParent();
-  Function *Fn =  Intrinsic::getDeclaration(M, ID, { LHS->getType() });
+  Module *M = BB->getModule();
+  Function *Fn = Intrinsic::getDeclaration(M, ID, { LHS->getType() });
   return createCallHelper(Fn, { LHS, RHS }, this, Name);
 }
+
+CallInst *IRBuilderBase::CreateIntrinsic(Intrinsic::ID ID,
+                                         ArrayRef<Value *> Args,
+                                         Instruction *FMFSource,
+                                         const Twine &Name) {
+  assert(!Args.empty() && "Expected at least one argument to intrinsic");
+  Module *M = BB->getModule();
+  Function *Fn = Intrinsic::getDeclaration(M, ID, { Args.front()->getType() });
+  return createCallHelper(Fn, Args, this, Name, FMFSource);
+}
+
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index cb48b93..d5456cc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -662,24 +662,17 @@
     }
   }
 
-  // sqrt(a) * sqrt(b) -> sqrt(a * b)
-  if (AllowReassociate && Op0->hasOneUse() && Op1->hasOneUse()) {
-    Value *Opnd0 = nullptr;
-    Value *Opnd1 = nullptr;
-    if (match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd0))) &&
-        match(Op1, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd1)))) {
-      BuilderTy::FastMathFlagGuard Guard(Builder);
-      Builder.setFastMathFlags(I.getFastMathFlags());
-      Value *FMulVal = Builder.CreateFMul(Opnd0, Opnd1);
-      Value *Sqrt = Intrinsic::getDeclaration(I.getModule(), 
-                                              Intrinsic::sqrt, I.getType());
-      Value *SqrtCall = Builder.CreateCall(Sqrt, FMulVal);
-      return replaceInstUsesWith(I, SqrtCall);
-    }
+  // sqrt(X) * sqrt(Y) -> sqrt(X * Y)
+  Value *X, *Y;
+  if (I.hasAllowReassoc() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(X)))) &&
+      match(Op1, m_OneUse(m_Intrinsic<Intrinsic::sqrt>(m_Value(Y))))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    Value *Sqrt = Builder.CreateIntrinsic(Intrinsic::sqrt, { XY }, &I);
+    return replaceInstUsesWith(I, Sqrt);
   }
 
   // -X * -Y --> X * Y
-  Value *X, *Y;
   if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
     return BinaryOperator::CreateFMulFMF(X, Y, &I);