[InstCombine] Support (sub (sext x), (sext y)) --> (sext (sub x, y)) and (sub (zext x), (zext y)) --> (zext (sub x, y))

Summary:
If the sub cannot overflow in the narrow (pre-extension) type, we can perform it in that narrow type and then apply the sext/zext to the result.

This mirrors the existing narrowing transform for add. Overflow checking for sub is currently weaker than for add, so the test cases are constructed to cover only the cases the sub overflow analysis can currently prove.

Reviewers: spatel

Reviewed By: spatel

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D52075

llvm-svn: 342335
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index acb62b6..910ec83 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1697,6 +1697,9 @@
     return SelectInst::Create(Cmp, Neg, A);
   }
 
+  if (Instruction *Ext = narrowMathIfNoOverflow(I))
+    return Ext;
+
   bool Changed = false;
   if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) {
     Changed = true;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index d29cf93..5d5a9b2 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1451,29 +1451,40 @@
 /// sure the narrow op does not overflow.
 Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) {
   // We need at least one extended operand.
-  Value *LHS = BO.getOperand(0), *RHS = BO.getOperand(1);
+  Value *Op0 = BO.getOperand(0), *Op1 = BO.getOperand(1);
+
+  // If this is a sub, we swap the operands since we always want an extension
+  // on the RHS. The LHS can be an extension or a constant.
+  if (BO.getOpcode() == Instruction::Sub)
+    std::swap(Op0, Op1);
+
   Value *X;
-  bool IsSext = match(LHS, m_SExt(m_Value(X)));
-  if (!IsSext && !match(LHS, m_ZExt(m_Value(X))))
+  bool IsSext = match(Op0, m_SExt(m_Value(X)));
+  if (!IsSext && !match(Op0, m_ZExt(m_Value(X))))
     return nullptr;
 
   // If both operands are the same extension from the same source type and we
   // can eliminate at least one (hasOneUse), this might work.
   CastInst::CastOps CastOpc = IsSext ? Instruction::SExt : Instruction::ZExt;
   Value *Y;
-  if (!(match(RHS, m_ZExtOrSExt(m_Value(Y))) && X->getType() == Y->getType() &&
-        cast<Operator>(RHS)->getOpcode() == CastOpc &&
-        (LHS->hasOneUse() || RHS->hasOneUse()))) {
+  if (!(match(Op1, m_ZExtOrSExt(m_Value(Y))) && X->getType() == Y->getType() &&
+        cast<Operator>(Op1)->getOpcode() == CastOpc &&
+        (Op0->hasOneUse() || Op1->hasOneUse()))) {
     // If that did not match, see if we have a suitable constant operand.
     // Truncating and extending must produce the same constant.
     Constant *WideC;
-    if (!LHS->hasOneUse() || !match(RHS, m_Constant(WideC)))
+    if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
       return nullptr;
     Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType());
     if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC)
       return nullptr;
     Y = NarrowC;
   }
+
+  // Swap back now that we found our operands.
+  if (BO.getOpcode() == Instruction::Sub)
+    std::swap(X, Y);
+
   // Both operands have narrow versions. Last step: the math must not overflow
   // in the narrow width.
   if (!willNotOverflow(BO.getOpcode(), X, Y, BO, IsSext))