[InstCombine] Move negation handling into freelyNegateValue()

Followup to D72978. This moves existing negation handling in
InstCombine into freelyNegateValue(), which makes it composable.
In particular, negations of sdiv/zext/sext/ashr/lshr/sub, which were
previously folded only at the root of a negation, can now uniformly
be performed through an intervening shl/trunc as well.

Differential Revision: https://reviews.llvm.org/D73288
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 3f842f9..2f40485 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1734,22 +1734,18 @@
   }
 
   if (Constant *C = dyn_cast<Constant>(Op0)) {
-    bool IsNegate = match(C, m_ZeroInt());
+    // -f(x) -> f(-x) if possible.
+    if (match(C, m_Zero()))
+      if (Value *Neg = freelyNegateValue(Op1))
+        return replaceInstUsesWith(I, Neg);
+
     Value *X;
-    if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
-      // 0 - (zext bool) --> sext bool
+    if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
       // C - (zext bool) --> bool ? C - 1 : C
-      if (IsNegate)
-        return CastInst::CreateSExtOrBitCast(X, I.getType());
       return SelectInst::Create(X, SubOne(C), C);
-    }
-    if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
-      // 0 - (sext bool) --> zext bool
+    if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
       // C - (sext bool) --> bool ? C + 1 : C
-      if (IsNegate)
-        return CastInst::CreateZExtOrBitCast(X, I.getType());
       return SelectInst::Create(X, AddOne(C), C);
-    }
 
     // C - ~X == X + (1+C)
     if (match(Op1, m_Not(m_Value(X))))
@@ -1778,51 +1774,15 @@
 
   const APInt *Op0C;
   if (match(Op0, m_APInt(Op0C))) {
-
-    if (Op0C->isNullValue()) {
-      Value *Op1Wide;
-      match(Op1, m_TruncOrSelf(m_Value(Op1Wide)));
-      bool HadTrunc = Op1Wide != Op1;
-      bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse();
-      unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits();
-
-      Value *X;
-      const APInt *ShAmt;
-      // -(X >>u 31) -> (X >>s 31)
-      if (NoTruncOrTruncIsOneUse &&
-          match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
-          *ShAmt == BitWidth - 1) {
-        Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
-        Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp);
-        NewShift->copyIRFlags(Op1Wide);
-        if (!HadTrunc)
-          return NewShift;
-        Builder.Insert(NewShift);
-        return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
-      }
-      // -(X >>s 31) -> (X >>u 31)
-      if (NoTruncOrTruncIsOneUse &&
-          match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
-          *ShAmt == BitWidth - 1) {
-        Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
-        Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp);
-        NewShift->copyIRFlags(Op1Wide);
-        if (!HadTrunc)
-          return NewShift;
-        Builder.Insert(NewShift);
-        return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
-      }
-
-      if (!HadTrunc && Op1->hasOneUse()) {
-        Value *LHS, *RHS;
-        SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
-        if (SPF == SPF_ABS || SPF == SPF_NABS) {
-          // This is a negate of an ABS/NABS pattern. Just swap the operands
-          // of the select.
-          cast<SelectInst>(Op1)->swapValues();
-          // Don't swap prof metadata, we didn't change the branch behavior.
-          return replaceInstUsesWith(I, Op1);
-        }
+    if (Op0C->isNullValue() && Op1->hasOneUse()) {
+      Value *LHS, *RHS;
+      SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
+      if (SPF == SPF_ABS || SPF == SPF_NABS) {
+        // This is a negate of an ABS/NABS pattern. Just swap the operands
+        // of the select.
+        cast<SelectInst>(Op1)->swapValues();
+        // Don't swap prof metadata, we didn't change the branch behavior.
+        return replaceInstUsesWith(I, Op1);
       }
     }
 
@@ -1957,7 +1917,7 @@
   }
 
   if (Op1->hasOneUse()) {
-    Value *X = nullptr, *Y = nullptr, *Z = nullptr;
+    Value *Y = nullptr, *Z = nullptr;
     Constant *C = nullptr;
 
     // (X - (Y - Z))  -->  (X + (Z - Y)).
@@ -1970,24 +1930,6 @@
       return BinaryOperator::CreateAnd(Op0,
                                   Builder.CreateNot(Y, Y->getName() + ".not"));
 
-    // 0 - (X sdiv C)  -> (X sdiv -C)  provided the negation doesn't overflow.
-    if (match(Op0, m_Zero())) {
-      Constant *Op11C;
-      if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) &&
-          !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() &&
-          Op11C->isNotOneValue()) {
-        Instruction *BO =
-            BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C));
-        BO->setIsExact(cast<BinaryOperator>(Op1)->isExact());
-        return BO;
-      }
-    }
-
-    // 0 - (X << Y)  -> (-X << Y)   when X is freely negatable.
-    if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero()))
-      if (Value *XNeg = freelyNegateValue(X))
-        return BinaryOperator::CreateShl(XNeg, Y);
-
     // Subtracting -1/0 is the same as adding 1/0:
     // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y)
     // 'nuw' is dropped in favor of the canonical form.
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8dbad1a..16ea6a2 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -862,15 +862,82 @@
   if (Value *NegV = dyn_castNegVal(V))
     return NegV;
 
-  if (!V->hasOneUse())
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
     return nullptr;
 
-  Value *A, *B;
-  // 0-(A-B)  =>  B-A
-  if (match(V, m_Sub(m_Value(A), m_Value(B))))
-    return Builder.CreateSub(B, A);
+  unsigned BitWidth = I->getType()->getScalarSizeInBits();
+  switch (I->getOpcode()) {
+  // 0-(zext i1 A)  =>  sext i1 A
+  case Instruction::ZExt:
+    if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1))
+      return Builder.CreateSExtOrBitCast(
+          I->getOperand(0), I->getType(), I->getName() + ".neg");
+    return nullptr;
 
-  return nullptr;
+  // 0-(sext i1 A)  =>  zext i1 A
+  case Instruction::SExt:
+    if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1))
+      return Builder.CreateZExtOrBitCast(
+          I->getOperand(0), I->getType(), I->getName() + ".neg");
+    return nullptr;
+
+  // 0-(A lshr (BW-1))  =>  A ashr (BW-1)
+  case Instruction::LShr:
+    if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1)))
+      return Builder.CreateAShr(
+          I->getOperand(0), I->getOperand(1),
+          I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
+    return nullptr;
+
+  // 0-(A ashr (BW-1))  =>  A lshr (BW-1)
+  case Instruction::AShr:
+    if (match(I->getOperand(1), m_SpecificInt(BitWidth - 1)))
+      return Builder.CreateLShr(
+          I->getOperand(0), I->getOperand(1),
+          I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
+    return nullptr;
+
+  default:
+    break;
+  }
+
+  // TODO: The "sub" pattern below could also be applied without the one-use
+  // restriction. Not allowing it for now in line with existing behavior.
+  if (!I->hasOneUse())
+    return nullptr;
+
+  switch (I->getOpcode()) {
+  // 0-(A-B)  =>  B-A
+  case Instruction::Sub:
+    return Builder.CreateSub(
+        I->getOperand(1), I->getOperand(0), I->getName() + ".neg");
+
+  // 0-(A sdiv C)  =>  A sdiv (0-C)  provided the negation doesn't overflow.
+  case Instruction::SDiv: {
+    Constant *C = dyn_cast<Constant>(I->getOperand(1));
+    if (C && !C->containsUndefElement() && C->isNotMinSignedValue() &&
+        C->isNotOneValue())
+      return Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(C),
+          I->getName() + ".neg", cast<BinaryOperator>(I)->isExact());
+    return nullptr;
+  }
+
+  // 0-(A<<B)  =>  (0-A)<<B
+  case Instruction::Shl:
+    if (Value *NegA = freelyNegateValue(I->getOperand(0)))
+      return Builder.CreateShl(NegA, I->getOperand(1), I->getName() + ".neg");
+    return nullptr;
+
+  // 0-(trunc A)  =>  trunc (0-A)
+  case Instruction::Trunc:
+    if (Value *NegA = freelyNegateValue(I->getOperand(0)))
+      return Builder.CreateTrunc(NegA, I->getType(), I->getName() + ".neg");
+    return nullptr;
+
+  default:
+    return nullptr;
+  }
 }
 
 static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,