InstCombine: fabs(x) * fabs(x) -> x * x

llvm-svn: 259295
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 160792b..161d3eb 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -612,12 +612,23 @@
     }
   }
 
-  // sqrt(X) * sqrt(X) -> X
-  if (AllowReassociate && (Op0 == Op1))
-    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0))
-      if (II->getIntrinsicID() == Intrinsic::sqrt)
+  if (Op0 == Op1) {
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
+      // sqrt(X) * sqrt(X) -> X
+      if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt)
         return ReplaceInstUsesWith(I, II->getOperand(0));
 
+      // fabs(X) * fabs(X) -> X * X
+      if (II->getIntrinsicID() == Intrinsic::fabs) {
+        Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0),
+                                                          II->getOperand(0),
+                                                          I.getName());
+        FMulVal->copyFastMathFlags(&I);
+        return FMulVal;
+      }
+    }
+  }
+
   // Under unsafe algebra do:
   // X * log2(0.5*Y) = X*log2(Y) - X
   if (AllowReassociate) {