[InstCombine] Negator - sink sinkable negations
Summary:
As we have discussed previously (e.g. in D63992 / D64090 / [[ https://bugs.llvm.org/show_bug.cgi?id=42457 | PR42457 ]]), `sub` instruction
can almost be considered non-canonical. While we do convert `sub %x, C` -> `add %x, -C`,
we rarely do that for non-constants. But we should.
Here, I propose to interpret `sub %x, %y` as `add (sub 0, %y), %x` IFF the negation can be sunk into the `%y`.
This has some potential to cause endless combine loops (either around PHI's, or if there are some opposite transforms).
For the former, there's the `-instcombine-negator-max-depth` option to mitigate it, should this expose any such issues.
For the latter, if there are still any such opposing folds, we'd need to remove the colliding fold.
In any case, reproducers welcomed!
Reviewers: spatel, nikic, efriedma, xbolva00
Reviewed By: spatel
Subscribers: xbolva00, mgorny, hiraditya, reames, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D68408
diff --git a/llvm/lib/Transforms/InstCombine/CMakeLists.txt b/llvm/lib/Transforms/InstCombine/CMakeLists.txt
index 2f19882..1a34f22 100644
--- a/llvm/lib/Transforms/InstCombine/CMakeLists.txt
+++ b/llvm/lib/Transforms/InstCombine/CMakeLists.txt
@@ -12,6 +12,7 @@
InstCombineCompares.cpp
InstCombineLoadStoreAlloca.cpp
InstCombineMulDivRem.cpp
+ InstCombineNegator.cpp
InstCombinePHI.cpp
InstCombineSelect.cpp
InstCombineShifts.cpp
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 7ca287f..16666fe 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1682,12 +1682,10 @@
if (Instruction *X = foldVectorBinop(I))
return X;
- // (A*B)-(A*C) -> A*(B-C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
- return replaceInstUsesWith(I, V);
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
// If this is a 'B = x-(-A)', change to B = x+A.
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ // We deal with this without involving Negator to preserve NSW flag.
if (Value *V = dyn_castNegVal(Op1)) {
BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
@@ -1704,6 +1702,45 @@
return Res;
}
+ auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * {
+ if (Instruction *Ext = narrowMathIfNoOverflow(I))
+ return Ext;
+
+ bool Changed = false;
+ if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) {
+ Changed = true;
+ I.setHasNoSignedWrap(true);
+ }
+ if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) {
+ Changed = true;
+ I.setHasNoUnsignedWrap(true);
+ }
+
+ return Changed ? &I : nullptr;
+ };
+
+ // First, let's try to interpret `sub a, b` as `add a, (sub 0, b)`,
+ // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't
+ // a pure negation used by a select that looks like abs/nabs.
+ bool IsNegation = match(Op0, m_ZeroInt());
+ if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) {
+ const Instruction *UI = dyn_cast<Instruction>(U);
+ if (!UI)
+ return false;
+ return match(UI,
+ m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
+ match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
+ })) {
+ if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this))
+ return BinaryOperator::CreateAdd(NegOp1, Op0);
+ }
+ if (IsNegation)
+ return TryToNarrowDeduceFlags(); // Should have been handled in Negator!
+
+ // (A*B)-(A*C) -> A*(B-C) etc
+ if (Value *V = SimplifyUsingDistributiveLaws(I))
+ return replaceInstUsesWith(I, V);
+
if (I.getType()->isIntOrIntVectorTy(1))
return BinaryOperator::CreateXor(Op0, Op1);
@@ -1720,22 +1757,7 @@
if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes()))))
return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X);
- // Y - (X + 1) --> ~X + Y
- if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One()))))
- return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0);
-
- // Y - ~X --> (X + 1) + Y
- if (match(Op1, m_OneUse(m_Not(m_Value(X))))) {
- return BinaryOperator::CreateAdd(
- Builder.CreateAdd(Op0, ConstantInt::get(I.getType(), 1)), X);
- }
-
if (Constant *C = dyn_cast<Constant>(Op0)) {
- // -f(x) -> f(-x) if possible.
- if (match(C, m_Zero()))
- if (Value *Neg = freelyNegateValue(Op1))
- return replaceInstUsesWith(I, Neg);
-
Value *X;
if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
// C - (zext bool) --> bool ? C - 1 : C
@@ -1770,26 +1792,12 @@
}
const APInt *Op0C;
- if (match(Op0, m_APInt(Op0C))) {
- if (Op0C->isNullValue() && Op1->hasOneUse()) {
- Value *LHS, *RHS;
- SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
- if (SPF == SPF_ABS || SPF == SPF_NABS) {
- // This is a negate of an ABS/NABS pattern. Just swap the operands
- // of the select.
- cast<SelectInst>(Op1)->swapValues();
- // Don't swap prof metadata, we didn't change the branch behavior.
- return replaceInstUsesWith(I, Op1);
- }
- }
-
+ if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) {
// Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
// zero.
- if (Op0C->isMask()) {
- KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
- if ((*Op0C | RHSKnown.Zero).isAllOnesValue())
- return BinaryOperator::CreateXor(Op1, Op0);
- }
+ KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
+ if ((*Op0C | RHSKnown.Zero).isAllOnesValue())
+ return BinaryOperator::CreateXor(Op1, Op0);
}
{
@@ -1919,49 +1927,6 @@
return BinaryOperator::CreateAnd(
Op0, Builder.CreateNot(Y, Y->getName() + ".not"));
- if (Op1->hasOneUse()) {
- Value *Y = nullptr, *Z = nullptr;
- Constant *C = nullptr;
-
- // (X - (Y - Z)) --> (X + (Z - Y)).
- if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
- return BinaryOperator::CreateAdd(Op0,
- Builder.CreateSub(Z, Y, Op1->getName()));
-
- // Subtracting -1/0 is the same as adding 1/0:
- // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y)
- // 'nuw' is dropped in favor of the canonical form.
- if (match(Op1, m_SExt(m_Value(Y))) &&
- Y->getType()->getScalarSizeInBits() == 1) {
- Value *Zext = Builder.CreateZExt(Y, I.getType());
- BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext);
- Add->setHasNoSignedWrap(I.hasNoSignedWrap());
- return Add;
- }
- // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y)
- // 'nuw' is dropped in favor of the canonical form.
- if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) {
- Value *Sext = Builder.CreateSExt(Y, I.getType());
- BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext);
- Add->setHasNoSignedWrap(I.hasNoSignedWrap());
- return Add;
- }
-
- // X - A*-B -> X + A*B
- // X - -A*B -> X + A*B
- Value *A, *B;
- if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))
- return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B));
-
- // X - A*C -> X + A*-C
- // No need to handle commuted multiply because multiply handling will
- // ensure constant will be move to the right hand side.
- if (match(Op1, m_Mul(m_Value(A), m_Constant(C))) && !isa<ConstantExpr>(C)) {
- Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(C));
- return BinaryOperator::CreateAdd(Op0, NewMul);
- }
- }
-
{
// ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A
// ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A
@@ -2036,20 +2001,7 @@
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
return V;
- if (Instruction *Ext = narrowMathIfNoOverflow(I))
- return Ext;
-
- bool Changed = false;
- if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) {
- Changed = true;
- I.setHasNoSignedWrap(true);
- }
- if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) {
- Changed = true;
- I.setHasNoUnsignedWrap(true);
- }
-
- return Changed ? &I : nullptr;
+ return TryToNarrowDeduceFlags();
}
/// This eliminates floating-point negation in either 'fneg(X)' or
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 6544fd4..a908349 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -16,6 +16,7 @@
#define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetFolder.h"
@@ -471,7 +472,6 @@
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
bool shouldChangeType(Type *From, Type *To) const;
Value *dyn_castNegVal(Value *V) const;
- Value *freelyNegateValue(Value *V);
Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset,
SmallVectorImpl<Value *> &NewIndices);
@@ -513,7 +513,7 @@
Instruction *simplifyMaskedStore(IntrinsicInst &II);
Instruction *simplifyMaskedGather(IntrinsicInst &II);
Instruction *simplifyMaskedScatter(IntrinsicInst &II);
-
+
/// Transform (zext icmp) to bitwise / integer operations in order to
/// eliminate it.
///
@@ -1014,6 +1014,55 @@
Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
};
+namespace {
+
+// As a default, let's assume that we want to be aggressive,
+// and attempt to traverse with no limits in an attempt to sink negation.
+static constexpr unsigned NegatorDefaultMaxDepth = ~0U;
+
+// Let's guesstimate that most often we will end up visiting/producing
+// fairly small number of new instructions.
+static constexpr unsigned NegatorMaxNodesSSO = 16;
+
+} // namespace
+
+class Negator final {
+ /// Top-to-bottom, def-to-use negated instruction tree we produced.
+ SmallVector<Instruction *, NegatorMaxNodesSSO> NewInstructions;
+
+ using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
+ BuilderTy Builder;
+
+ const bool IsTrulyNegation;
+
+ Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation);
+
+#if LLVM_ENABLE_STATS
+ unsigned NumValuesVisitedInThisNegator = 0;
+ ~Negator();
+#endif
+
+ using Result = std::pair<ArrayRef<Instruction *> /*NewInstructions*/,
+ Value * /*NegatedRoot*/>;
+
+ LLVM_NODISCARD Value *visit(Value *V, unsigned Depth);
+
+ /// Recurse depth-first and attempt to sink the negation.
+ /// FIXME: use worklist?
+ LLVM_NODISCARD Optional<Result> run(Value *Root);
+
+ Negator(const Negator &) = delete;
+ Negator(Negator &&) = delete;
+ Negator &operator=(const Negator &) = delete;
+ Negator &operator=(Negator &&) = delete;
+
+public:
+  /// Attempt to negate \p Root. Returns nullptr if negation can't be performed,
+ /// otherwise returns negated value.
+ LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root,
+ InstCombiner &IC);
+};
+
} // end namespace llvm
#undef DEBUG_TYPE
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
new file mode 100644
index 0000000..2655ef3
--- /dev/null
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -0,0 +1,377 @@
+//===- InstCombineNegator.cpp -----------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements sinking of negation into expression trees,
+// as long as that can be done without increasing instruction count.
+//
+//===----------------------------------------------------------------------===//
+
+#include "InstCombineInternal.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/TargetFolder.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/DebugCounter.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <functional>
+#include <tuple>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "instcombine"
+
+STATISTIC(NegatorTotalNegationsAttempted,
+ "Negator: Number of negations attempted to be sinked");
+STATISTIC(NegatorNumTreesNegated,
+ "Negator: Number of negations successfully sinked");
+STATISTIC(NegatorMaxDepthVisited, "Negator: Maximal traversal depth ever "
+ "reached while attempting to sink negation");
+STATISTIC(NegatorTimesDepthLimitReached,
+ "Negator: How many times did the traversal depth limit was reached "
+ "during sinking");
+STATISTIC(
+ NegatorNumValuesVisited,
+ "Negator: Total number of values visited during attempts to sink negation");
+STATISTIC(NegatorMaxTotalValuesVisited,
+ "Negator: Maximal number of values ever visited while attempting to "
+ "sink negation");
+STATISTIC(NegatorNumInstructionsCreatedTotal,
+ "Negator: Number of new negated instructions created, total");
+STATISTIC(NegatorMaxInstructionsCreated,
+ "Negator: Maximal number of new instructions created during negation "
+ "attempt");
+STATISTIC(NegatorNumInstructionsNegatedSuccess,
+ "Negator: Number of new negated instructions created in successful "
+ "negation sinking attempts");
+
+DEBUG_COUNTER(NegatorCounter, "instcombine-negator",
+ "Controls Negator transformations in InstCombine pass");
+
+static cl::opt<bool>
+ NegatorEnabled("instcombine-negator-enabled", cl::init(true),
+ cl::desc("Should we attempt to sink negations?"));
+
+static cl::opt<unsigned>
+ NegatorMaxDepth("instcombine-negator-max-depth",
+ cl::init(NegatorDefaultMaxDepth),
+ cl::desc("What is the maximal lookup depth when trying to "
+ "check for viability of negation sinking."));
+
+Negator::Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation_)
+ : Builder(C, TargetFolder(DL),
+ IRBuilderCallbackInserter([&](Instruction *I) {
+ ++NegatorNumInstructionsCreatedTotal;
+ NewInstructions.push_back(I);
+ })),
+ IsTrulyNegation(IsTrulyNegation_) {}
+
+#if LLVM_ENABLE_STATS
+Negator::~Negator() {
+ NegatorMaxTotalValuesVisited.updateMax(NumValuesVisitedInThisNegator);
+}
+#endif
+
+// FIXME: can this be reworked into a worklist-based algorithm while preserving
+// the depth-first, early bailout traversal?
+LLVM_NODISCARD Value *Negator::visit(Value *V, unsigned Depth) {
+ NegatorMaxDepthVisited.updateMax(Depth);
+ ++NegatorNumValuesVisited;
+
+#if LLVM_ENABLE_STATS
+ ++NumValuesVisitedInThisNegator;
+#endif
+
+ // In i1, negation can simply be ignored.
+ if (V->getType()->isIntOrIntVectorTy(1))
+ return V;
+
+ Value *X;
+
+ // -(-(X)) -> X.
+ if (match(V, m_Neg(m_Value(X))))
+ return X;
+
+ // Integral constants can be freely negated.
+ if (match(V, m_AnyIntegralConstant()))
+ return ConstantExpr::getNeg(cast<Constant>(V), /*HasNUW=*/false,
+ /*HasNSW=*/false);
+
+ // If we have a non-instruction, then give up.
+ if (!isa<Instruction>(V))
+ return nullptr;
+
+ // If we have started with a true negation (i.e. `sub 0, %y`), then if we've
+ // got instruction that does not require recursive reasoning, we can still
+ // negate it even if it has other uses, without increasing instruction count.
+ if (!V->hasOneUse() && !IsTrulyNegation)
+ return nullptr;
+
+ auto *I = cast<Instruction>(V);
+ unsigned BitWidth = I->getType()->getScalarSizeInBits();
+
+ // We must preserve the insertion point and debug info that is set in the
+ // builder at the time this function is called.
+ InstCombiner::BuilderTy::InsertPointGuard Guard(Builder);
+ // And since we are trying to negate instruction I, that tells us about the
+ // insertion point and the debug info that we need to keep.
+ Builder.SetInsertPoint(I);
+
+ // In some cases we can give the answer without further recursion.
+ switch (I->getOpcode()) {
+ case Instruction::Sub:
+ // `sub` is always negatible.
+ return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
+ I->getName() + ".neg");
+ case Instruction::Add:
+ // `inc` is always negatible.
+ if (match(I->getOperand(1), m_One()))
+ return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg");
+ break;
+ case Instruction::Xor:
+ // `not` is always negatible.
+ if (match(I, m_Not(m_Value(X))))
+ return Builder.CreateAdd(X, ConstantInt::get(X->getType(), 1),
+ I->getName() + ".neg");
+ break;
+ case Instruction::AShr:
+ case Instruction::LShr: {
+ // Right-shift sign bit smear is negatible.
+ const APInt *Op1Val;
+ if (match(I->getOperand(1), m_APInt(Op1Val)) && *Op1Val == BitWidth - 1) {
+ Value *BO = I->getOpcode() == Instruction::AShr
+ ? Builder.CreateLShr(I->getOperand(0), I->getOperand(1))
+ : Builder.CreateAShr(I->getOperand(0), I->getOperand(1));
+ if (auto *NewInstr = dyn_cast<Instruction>(BO)) {
+ NewInstr->copyIRFlags(I);
+ NewInstr->setName(I->getName() + ".neg");
+ }
+ return BO;
+ }
+ break;
+ }
+ case Instruction::SDiv:
+ // `sdiv` is negatible if divisor is not undef/INT_MIN/1.
+ // While this is normally not behind a use-check,
+ // let's consider division to be special since it's costly.
+ if (!I->hasOneUse())
+ break;
+ if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) {
+ if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() &&
+ Op1C->isNotOneValue()) {
+ Value *BO =
+ Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C),
+ I->getName() + ".neg");
+ if (auto *NewInstr = dyn_cast<Instruction>(BO))
+ NewInstr->setIsExact(I->isExact());
+ return BO;
+ }
+ }
+ break;
+ case Instruction::SExt:
+ case Instruction::ZExt:
+ // `*ext` of i1 is always negatible
+ if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1))
+ return I->getOpcode() == Instruction::SExt
+ ? Builder.CreateZExt(I->getOperand(0), I->getType(),
+ I->getName() + ".neg")
+ : Builder.CreateSExt(I->getOperand(0), I->getType(),
+ I->getName() + ".neg");
+ break;
+ default:
+ break; // Other instructions require recursive reasoning.
+ }
+
+ // Rest of the logic is recursive, and if either the current instruction
+ // has other uses or if it's time to give up then it's time.
+ if (!V->hasOneUse())
+ return nullptr;
+ if (Depth > NegatorMaxDepth) {
+ LLVM_DEBUG(dbgs() << "Negator: reached maximal allowed traversal depth in "
+ << *V << ". Giving up.\n");
+ ++NegatorTimesDepthLimitReached;
+ return nullptr;
+ }
+
+ switch (I->getOpcode()) {
+ case Instruction::PHI: {
+ // `phi` is negatible if all the incoming values are negatible.
+ PHINode *PHI = cast<PHINode>(I);
+ SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands());
+ for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) {
+ if (!(std::get<1>(I) = visit(std::get<0>(I), Depth + 1))) // Early return.
+ return nullptr;
+ }
+ // All incoming values are indeed negatible. Create negated PHI node.
+ PHINode *NegatedPHI = Builder.CreatePHI(
+ PHI->getType(), PHI->getNumOperands(), PHI->getName() + ".neg");
+ for (auto I : zip(NegatedIncomingValues, PHI->blocks()))
+ NegatedPHI->addIncoming(std::get<0>(I), std::get<1>(I));
+ return NegatedPHI;
+ }
+ case Instruction::Select: {
+ {
+ // `abs`/`nabs` is always negatible.
+ Value *LHS, *RHS;
+ SelectPatternFlavor SPF =
+ matchSelectPattern(I, LHS, RHS, /*CastOp=*/nullptr, Depth).Flavor;
+ if (SPF == SPF_ABS || SPF == SPF_NABS) {
+ auto *NewSelect = cast<SelectInst>(I->clone());
+ // Just swap the operands of the select.
+ NewSelect->swapValues();
+ // Don't swap prof metadata, we didn't change the branch behavior.
+ NewSelect->setName(I->getName() + ".neg");
+ Builder.Insert(NewSelect);
+ return NewSelect;
+ }
+ }
+ // `select` is negatible if both hands of `select` are negatible.
+ Value *NegOp1 = visit(I->getOperand(1), Depth + 1);
+ if (!NegOp1) // Early return.
+ return nullptr;
+ Value *NegOp2 = visit(I->getOperand(2), Depth + 1);
+ if (!NegOp2)
+ return nullptr;
+ // Do preserve the metadata!
+ return Builder.CreateSelect(I->getOperand(0), NegOp1, NegOp2,
+ I->getName() + ".neg", /*MDFrom=*/I);
+ }
+ case Instruction::Trunc: {
+ // `trunc` is negatible if its operand is negatible.
+ Value *NegOp = visit(I->getOperand(0), Depth + 1);
+ if (!NegOp) // Early return.
+ return nullptr;
+ return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg");
+ }
+ case Instruction::Shl: {
+ // `shl` is negatible if the first operand is negatible.
+ Value *NegOp0 = visit(I->getOperand(0), Depth + 1);
+ if (!NegOp0) // Early return.
+ return nullptr;
+ return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg");
+ }
+ case Instruction::Add: {
+ // `add` is negatible if both of its operands are negatible.
+ Value *NegOp0 = visit(I->getOperand(0), Depth + 1);
+ if (!NegOp0) // Early return.
+ return nullptr;
+ Value *NegOp1 = visit(I->getOperand(1), Depth + 1);
+ if (!NegOp1)
+ return nullptr;
+ return Builder.CreateAdd(NegOp0, NegOp1, I->getName() + ".neg");
+ }
+ case Instruction::Xor:
+ // `xor` is negatible if one of its operands is invertible.
+ // FIXME: InstCombineInverter? But how to connect Inverter and Negator?
+ if (auto *C = dyn_cast<Constant>(I->getOperand(1))) {
+ Value *Xor = Builder.CreateXor(I->getOperand(0), ConstantExpr::getNot(C));
+ return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1),
+ I->getName() + ".neg");
+ }
+ return nullptr;
+ case Instruction::Mul: {
+ // `mul` is negatible if one of its operands is negatible.
+ Value *NegatedOp, *OtherOp;
+ // First try the second operand, in case it's a constant it will be best to
+ // just invert it instead of sinking the `neg` deeper.
+ if (Value *NegOp1 = visit(I->getOperand(1), Depth + 1)) {
+ NegatedOp = NegOp1;
+ OtherOp = I->getOperand(0);
+ } else if (Value *NegOp0 = visit(I->getOperand(0), Depth + 1)) {
+ NegatedOp = NegOp0;
+ OtherOp = I->getOperand(1);
+ } else
+ // Can't negate either of them.
+ return nullptr;
+ return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg");
+ }
+ default:
+ return nullptr; // Don't know, likely not negatible for free.
+ }
+
+ llvm_unreachable("Can't get here. We always return from switch.");
+};
+
+LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) {
+ Value *Negated = visit(Root, /*Depth=*/0);
+ if (!Negated) {
+ // We must cleanup newly-inserted instructions, to avoid any potential
+ // endless combine looping.
+ llvm::for_each(llvm::reverse(NewInstructions),
+ [&](Instruction *I) { I->eraseFromParent(); });
+ return llvm::None;
+ }
+ return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated);
+};
+
+LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root,
+ InstCombiner &IC) {
+ ++NegatorTotalNegationsAttempted;
+ LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root
+ << "\n");
+
+ if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter))
+ return nullptr;
+
+ Negator N(Root->getContext(), IC.getDataLayout(), LHSIsZero);
+ Optional<Result> Res = N.run(Root);
+ if (!Res) { // Negation failed.
+ LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root
+ << "\n");
+ return nullptr;
+ }
+
+ LLVM_DEBUG(dbgs() << "Negator: successfully sunk negation into " << *Root
+ << "\n NEW: " << *Res->second << "\n");
+ ++NegatorNumTreesNegated;
+
+ // We must temporarily unset the 'current' insertion point and DebugLoc of the
+ // InstCombine's IRBuilder so that it won't interfere with the ones we have
+ // already specified when producing negated instructions.
+ InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder);
+ IC.Builder.ClearInsertionPoint();
+ IC.Builder.SetCurrentDebugLocation(DebugLoc());
+
+ // And finally, we must add newly-created instructions into the InstCombine's
+ // worklist (in a proper order!) so it can attempt to combine them.
+ LLVM_DEBUG(dbgs() << "Negator: Propagating " << Res->first.size()
+ << " instrs to InstCombine\n");
+ NegatorMaxInstructionsCreated.updateMax(Res->first.size());
+ NegatorNumInstructionsNegatedSuccess += Res->first.size();
+
+ // They are in def-use order, so nothing fancy, just insert them in order.
+ llvm::for_each(Res->first,
+ [&](Instruction *I) { IC.Builder.Insert(I, I->getName()); });
+
+ // And return the new root.
+ return Res->second;
+};
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 4d286f1..88cb8d5 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -853,120 +853,6 @@
return nullptr;
}
-/// Get negated V (that is 0-V) without increasing instruction count,
-/// assuming that the original V will become unused.
-Value *InstCombiner::freelyNegateValue(Value *V) {
- if (Value *NegV = dyn_castNegVal(V))
- return NegV;
-
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return nullptr;
-
- unsigned BitWidth = I->getType()->getScalarSizeInBits();
- switch (I->getOpcode()) {
- // 0-(zext i1 A) => sext i1 A
- case Instruction::ZExt:
- if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1))
- return Builder.CreateSExtOrBitCast(
- I->getOperand(0), I->getType(), I->getName() + ".neg");
- return nullptr;
-
- // 0-(sext i1 A) => zext i1 A
- case Instruction::SExt:
- if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1))
- return Builder.CreateZExtOrBitCast(
- I->getOperand(0), I->getType(), I->getName() + ".neg");
- return nullptr;
-
- // 0-(A lshr (BW-1)) => A ashr (BW-1)
- case Instruction::LShr:
- if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1)))
- return Builder.CreateAShr(
- I->getOperand(0), I->getOperand(1),
- I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
- return nullptr;
-
- // 0-(A ashr (BW-1)) => A lshr (BW-1)
- case Instruction::AShr:
- if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1)))
- return Builder.CreateLShr(
- I->getOperand(0), I->getOperand(1),
- I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
- return nullptr;
-
- // Negation is equivalent to bitwise-not + 1.
- case Instruction::Xor: {
- // Special case for negate of 'not' - replace with increment:
- // 0 - (~A) => ((A ^ -1) ^ -1) + 1 => A + 1
- Value *A;
- if (match(I, m_Not(m_Value(A))))
- return Builder.CreateAdd(A, ConstantInt::get(A->getType(), 1),
- I->getName() + ".neg");
-
- // General case xor (not a 'not') requires creating a new xor, so this has a
- // one-use limitation:
- // 0 - (A ^ C) => ((A ^ C) ^ -1) + 1 => A ^ ~C + 1
- Constant *C;
- if (match(I, m_OneUse(m_Xor(m_Value(A), m_Constant(C))))) {
- Value *Xor = Builder.CreateXor(A, ConstantExpr::getNot(C));
- return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1),
- I->getName() + ".neg");
- }
- return nullptr;
- }
-
- default:
- break;
- }
-
- // TODO: The "sub" pattern below could also be applied without the one-use
- // restriction. Not allowing it for now in line with existing behavior.
- if (!I->hasOneUse())
- return nullptr;
-
- switch (I->getOpcode()) {
- // 0-(A-B) => B-A
- case Instruction::Sub:
- return Builder.CreateSub(
- I->getOperand(1), I->getOperand(0), I->getName() + ".neg");
-
- // 0-(A sdiv C) => A sdiv (0-C) provided the negation doesn't overflow.
- case Instruction::SDiv: {
- Constant *C = dyn_cast<Constant>(I->getOperand(1));
- if (C && !C->containsUndefElement() && C->isNotMinSignedValue() &&
- C->isNotOneValue())
- return Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(C),
- I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
- return nullptr;
- }
-
- // 0-(A<<B) => (0-A)<<B
- case Instruction::Shl:
- if (Value *NegA = freelyNegateValue(I->getOperand(0)))
- return Builder.CreateShl(NegA, I->getOperand(1), I->getName() + ".neg");
- return nullptr;
-
- // 0-(trunc A) => trunc (0-A)
- case Instruction::Trunc:
- if (Value *NegA = freelyNegateValue(I->getOperand(0)))
- return Builder.CreateTrunc(NegA, I->getType(), I->getName() + ".neg");
- return nullptr;
-
- // 0-(A*B) => (0-A)*B
- // 0-(A*B) => A*(0-B)
- case Instruction::Mul:
- if (Value *NegA = freelyNegateValue(I->getOperand(0)))
- return Builder.CreateMul(NegA, I->getOperand(1), V->getName() + ".neg");
- if (Value *NegB = freelyNegateValue(I->getOperand(1)))
- return Builder.CreateMul(I->getOperand(0), NegB, V->getName() + ".neg");
- return nullptr;
-
- default:
- return nullptr;
- }
-}
-
static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
InstCombiner::BuilderTy &Builder) {
if (auto *Cast = dyn_cast<CastInst>(&I))