[InstCombine] fold fmul/fdiv with fabs operands
fabs(X) * fabs(Y) --> fabs(X * Y)
fabs(X) / fabs(Y) --> fabs(X / Y)
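If both operands of fmul/fdiv have their sign bits cleared, then the result must also have its sign bit cleared. (If the result is NAN, its sign bit is unspecified anyway, so producing a cleared sign bit is still valid.)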
There's a NAN corner-case that prevents removing the more specific fold just
above this one:
fabs(X) * fabs(X) -> X * X
That fold works even with NAN because the sign-bit result of the multiply is
not specified if X is NAN.
We can't remove that and use the more general fold that is proposed here
because once we convert to this:
fabs (X * X)
...it is not legal to simplify the 'fabs' out of that expression when X is NAN.
That's because fabs() guarantees that the sign-bit is always cleared - even
for NAN values.
So this patch can lose information (the knowledge that the 'fabs' is redundant for fabs(X) * fabs(X)), but that is unlikely as long as the more specific fold runs before this one.
Differential Revision: https://reviews.llvm.org/D82277
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index cbbde9a..2d4fdd3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -402,7 +402,7 @@
return Changed ? &I : nullptr;
}
-static Instruction *foldFPSignBitOps(BinaryOperator &I) {
+Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) {
BinaryOperator::BinaryOps Opcode = I.getOpcode();
assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) &&
"Expected fmul or fdiv");
@@ -420,6 +420,19 @@
if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I);
+ // fabs(X) * fabs(Y) --> fabs(X * Y)
+ // fabs(X) / fabs(Y) --> fabs(X / Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) &&
+ match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(I.getFastMathFlags());
+ Value *XY = Builder.CreateBinOp(Opcode, X, Y);
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY);
+ Fabs->takeName(&I);
+ return replaceInstUsesWith(I, Fabs);
+ }
+
return nullptr;
}