[InstCombine] Missed optimization in math expression: combine sqrt(a) * sqrt(b) into sqrt(a * b)

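Summary: This patch enables the fold sqrt(a) * sqrt(b) -> sqrt(a * b) under -ffast-math (specifically, when the fmul allows reassociation), provided that neither sqrt call has any other uses.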

Reviewers: hfinkel, spatel, davide

Reviewed By: spatel, davide

Subscribers: davide, llvm-commits

Differential Revision: https://reviews.llvm.org/D41322

llvm-svn: 321637
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 541dde6..3860483 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -728,6 +728,23 @@
     }
   }
 
+  // sqrt(a) * sqrt(b) -> sqrt(a * b)
+  if (AllowReassociate &&
+      Op0->hasOneUse() && Op1->hasOneUse()) {
+    Value *Opnd0 = nullptr;
+    Value *Opnd1 = nullptr;
+    if (match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd0))) &&
+        match(Op1, m_Intrinsic<Intrinsic::sqrt>(m_Value(Opnd1)))) {
+      BuilderTy::FastMathFlagGuard Guard(Builder);
+      Builder.setFastMathFlags(I.getFastMathFlags());
+      Value *FMulVal = Builder.CreateFMul(Opnd0, Opnd1);
+      Value *Sqrt = Intrinsic::getDeclaration(I.getModule(), 
+                                              Intrinsic::sqrt, I.getType());
+      Value *SqrtCall = Builder.CreateCall(Sqrt, FMulVal);
+      return replaceInstUsesWith(I, SqrtCall);
+    }
+  }
+
   // Handle symmetric situation in a 2-iteration loop
   Value *Opnd0 = Op0;
   Value *Opnd1 = Op1;