[InstCombine] Invert `add A, sext(B) --> sub A, zext(B)` canonicalization (to `sub A, zext B -> add A, sext B`)

Summary:
D68408 proposes to greatly improve our negation sinking abilities.
But the current canonicalization produces `sub A, zext(B)`,
which that patch will consider non-canonical: it will try to sink the negation,
undoing the existing canonicalization.
So unless we explicitly stop producing the previous canonical form,
we will have two conflicting folds and will end up looping endlessly.

This patch inverts the canonicalization and adds back the obvious fold
that we'd otherwise miss:
* `sub [nsw] Op0, sext/zext (bool Y) -> add [nsw] Op0, zext/sext (bool Y)`
  https://rise4fun.com/Alive/xx4
* `sext(bool) + C -> bool ? C - 1 : C`
  https://rise4fun.com/Alive/fBl
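
The two folds above can also be sanity-checked by brute force on i8.
This is only a sketch, not part of the patch: the helper names (`zextBool`,
`sextBool`, `wrapAdd`, `wrapSub`, etc.) are hypothetical, and i8 values are
modelled as `uint8_t` bit patterns on the assumption that this matches
wrapping `add`/`sub` closely enough for illustration. It also shows why
`nsw` can be kept while `nuw` has to be dropped.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model: an i8 is a uint8_t bit pattern, so arithmetic wraps mod 2^8.
using I8 = std::uint8_t;

I8 zextBool(bool B) { return B ? 0x01 : 0x00; } // zext i1 -> i8: true is 1
I8 sextBool(bool B) { return B ? 0xFF : 0x00; } // sext i1 -> i8: true is -1
I8 wrapAdd(I8 L, I8 R) { return static_cast<I8>(L + R); }
I8 wrapSub(I8 L, I8 R) { return static_cast<I8>(L - R); }
int asSigned(I8 V) { return V < 128 ? V : V - 256; } // two's complement value

// sub A, zext(B) == add A, sext(B) for every i8 A and every bool B.
bool subZextMatchesAddSext() {
  for (int A = 0; A < 256; ++A)
    for (int B = 0; B < 2; ++B)
      if (wrapSub(I8(A), zextBool(B)) != wrapAdd(I8(A), sextBool(B)))
        return false;
  return true;
}

// sext(B) + C == (B ? C - 1 : C) for every i8 C and every bool B.
bool sextPlusConstMatchesSelect() {
  for (int C = 0; C < 256; ++C)
    for (int B = 0; B < 2; ++B) {
      I8 Folded = B ? wrapSub(I8(C), 1) : I8(C);
      if (wrapAdd(sextBool(B), I8(C)) != Folded)
        return false;
    }
  return true;
}

// Signed overflow happens in the sub form exactly when it happens in the add
// form, which is why 'nsw' can be carried over.
bool nswCarriesOver() {
  for (int A = 0; A < 256; ++A)
    for (int B = 0; B < 2; ++B) {
      int SubWide = asSigned(I8(A)) - asSigned(zextBool(B));
      int AddWide = asSigned(I8(A)) + asSigned(sextBool(B));
      bool SubOverflows = SubWide < -128 || SubWide > 127;
      bool AddOverflows = AddWide < -128 || AddWide > 127;
      if (SubOverflows != AddOverflows)
        return false;
    }
  return true;
}

// Counterexample for 'nuw' with A = 1, B = true: the sub does not wrap,
// but the unsigned add of the all-ones value does.
bool nuwDoesNotCarryOver() {
  int SubWide = 1 - zextBool(true);
  int AddWide = 1 + sextBool(true);
  return SubWide >= 0 && AddWide > 255;
}

// Convenience entry point: aborts via assert if any property fails.
void checkAll() {
  assert(subZextMatchesAddSext());
  assert(sextPlusConstMatchesSelect());
  assert(nswCarriesOver());
  assert(nuwDoesNotCarryOver());
}
```

Calling `checkAll()` from a `main` is enough to run all four checks. The Alive
links above remain the authoritative proofs; this only covers one bit width.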

The `@ossfuzz_9880()` / `@lshr_out_of_range()` / `@ashr_out_of_range()` test cases
(oss-fuzz 4871) are no longer folded as much, though those aren't really worrying.

Reviewers: spatel, efriedma, t.p.northover, hfinkel

Reviewed By: spatel

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D71064
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 5b71f9d..ad42f3d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -890,6 +890,10 @@
   if (match(Op0, m_ZExt(m_Value(X))) &&
       X->getType()->getScalarSizeInBits() == 1)
     return SelectInst::Create(X, AddOne(Op1C), Op1);
+  // sext(bool) + C -> bool ? C - 1 : C
+  if (match(Op0, m_SExt(m_Value(X))) &&
+      X->getType()->getScalarSizeInBits() == 1)
+    return SelectInst::Create(X, SubOne(Op1C), Op1);
 
   // ~X + C --> (C-1) - X
   if (match(Op0, m_Not(m_Value(X))))
@@ -1288,12 +1292,6 @@
     return BinaryOperator::CreateSub(RHS, A);
   }
 
-  // Canonicalize sext to zext for better value tracking potential.
-  // add A, sext(B) --> sub A, zext(B)
-  if (match(&I, m_c_Add(m_Value(A), m_OneUse(m_SExt(m_Value(B))))) &&
-      B->getType()->isIntOrIntVectorTy(1))
-    return BinaryOperator::CreateSub(A, Builder.CreateZExt(B, Ty));
-
   // A + -B  -->  A - B
   if (match(RHS, m_Neg(m_Value(B))))
     return BinaryOperator::CreateSub(LHS, B);
@@ -1923,6 +1921,14 @@
       Add->setHasNoSignedWrap(I.hasNoSignedWrap());
       return Add;
     }
+    // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y)
+    // 'nuw' is dropped in favor of the canonical form.
+    if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) {
+      Value *Sext = Builder.CreateSExt(Y, I.getType());
+      BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext);
+      Add->setHasNoSignedWrap(I.hasNoSignedWrap());
+      return Add;
+    }
 
     // X - A*-B -> X + A*B
     // X - -A*B -> X + A*B