[InstCombine] Support folding a subtract with a constant LHS into a phi node

We currently only support folding a subtract into a select but not a PHI. This fixes that.

I had to fix an assumption in FoldOpIntoPhi that the PHI node was always in operand 0. Now we pass it in like we do for FoldOpIntoSelect. But we still require some dancing to find the Constant when we create the BinOp or ConstantExpr. This code is similar to what we do for selects.

Since I touched all call sites, this also renames FoldOpIntoPhi to foldOpIntoPhi to match coding standards.

Differential Revision: https://reviews.llvm.org/D31686

llvm-svn: 300363
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 1077121..174ec80 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1572,6 +1572,11 @@
       if (Instruction *R = FoldOpIntoSelect(I, SI))
         return R;
 
+    // Try to fold constant sub into PHI values.
+    if (PHINode *PN = dyn_cast<PHINode>(Op1))
+      if (Instruction *R = foldOpIntoPhi(I, PN))
+        return R;
+
     // C-(X+C2) --> (C-C2)-X
     Constant *C2;
     if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e08c301..2568313 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -274,12 +274,12 @@
       return NV;
 
   // If we are casting a PHI, then fold the cast into the PHI.
-  if (isa<PHINode>(Src)) {
+  if (auto *PN = dyn_cast<PHINode>(Src)) {
     // Don't do this if it would create a PHI node with an illegal type from a
     // legal type.
     if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
         shouldChangeType(CI.getType(), Src->getType()))
-      if (Instruction *NV = FoldOpIntoPhi(CI))
+      if (Instruction *NV = foldOpIntoPhi(CI, PN))
         return NV;
   }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 5a292ef..bbafa9e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2709,7 +2709,7 @@
     // block.  If in the same block, we're encouraging jump threading.  If
     // not, we are just pessimizing the code by making an i1 phi.
     if (LHSI->getParent() == I.getParent())
-      if (Instruction *NV = FoldOpIntoPhi(I))
+      if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
         return NV;
     break;
   case Instruction::Select: {
@@ -4873,7 +4873,7 @@
         // block.  If in the same block, we're encouraging jump threading.  If
         // not, we are just pessimizing the code by making an i1 phi.
         if (LHSI->getParent() == I.getParent())
-          if (Instruction *NV = FoldOpIntoPhi(I))
+          if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
             return NV;
         break;
       case Instruction::SIToFP:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f163f57..7100006 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -569,7 +569,7 @@
   /// Given a binary operator, cast instruction, or select which has a PHI node
   /// as operand #0, see if we can fold the instruction into the PHI (which is
   /// only possible if all operands to the PHI are constants).
-  Instruction *FoldOpIntoPhi(Instruction &I);
+  Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN);
 
   /// Given an instruction with a select as one operand and a constant as the
   /// other operand, try to fold the binary operator into the select arguments.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index a238f3f..f1ac820 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1455,16 +1455,16 @@
       if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
         if (Instruction *R = FoldOpIntoSelect(I, SI))
           return R;
-      } else if (isa<PHINode>(Op0I)) {
+      } else if (auto *PN = dyn_cast<PHINode>(Op0I)) {
         using namespace llvm::PatternMatch;
         const APInt *Op1Int;
         if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
             (I.getOpcode() == Instruction::URem ||
              !Op1Int->isMinSignedValue())) {
-          // FoldOpIntoPhi will speculate instructions to the end of the PHI's
+          // foldOpIntoPhi will speculate instructions to the end of the PHI's
           // predecessor blocks, so do this only if we know the srem or urem
           // will not fault.
-          if (Instruction *NV = FoldOpIntoPhi(I))
+          if (Instruction *NV = foldOpIntoPhi(I, PN))
             return NV;
         }
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index d857417..85e5b6b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -457,8 +457,8 @@
   }
 
   // The more common cases of a phi with no constant operands or just one
-  // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi()
-  // respectively. FoldOpIntoPhi() wants to do the opposite transform that is
+  // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi()
+  // respectively. foldOpIntoPhi() wants to do the opposite transform that is
   // performed here. It tries to replicate a cast in the phi operand's basic
   // block to expose other folding opportunities. Thus, InstCombine will
   // infinite loop without this check.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index b7099be..693b6c95 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1384,11 +1384,11 @@
   }
 
   // See if we can fold the select into a phi node if the condition is a select.
-  if (isa<PHINode>(SI.getCondition()))
+  if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
     // The true/false values have to be live in the PHI predecessor's blocks.
     if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) &&
         canSelectOperandBeMappingIntoPredBlock(FalseVal, SI))
-      if (Instruction *NV = FoldOpIntoPhi(SI))
+      if (Instruction *NV = foldOpIntoPhi(SI, PN))
         return NV;
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 223fa2a..88ef17b 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -839,8 +839,29 @@
   return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
 }
 
-Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
-  PHINode *PN = cast<PHINode>(I.getOperand(0));
+static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
+                                        InstCombiner *IC) {
+  bool ConstIsRHS = isa<Constant>(I->getOperand(1));
+  Constant *C = cast<Constant>(I->getOperand(ConstIsRHS));
+
+  if (auto *InC = dyn_cast<Constant>(InV)) {
+    if (ConstIsRHS)
+      return ConstantExpr::get(I->getOpcode(), InC, C);
+    return ConstantExpr::get(I->getOpcode(), C, InC);
+  }
+
+  Value *Op0 = InV, *Op1 = C;
+  if (!ConstIsRHS)
+    std::swap(Op0, Op1);
+
+  Value *RI = IC->Builder->CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp");
+  auto *FPInst = dyn_cast<Instruction>(RI);
+  if (FPInst && isa<FPMathOperator>(FPInst))
+    FPInst->copyFastMathFlags(I);
+  return RI;
+}
+
+Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
   unsigned NumPHIValues = PN->getNumIncomingValues();
   if (NumPHIValues == 0)
     return nullptr;
@@ -946,19 +967,9 @@
                                   C, "phitmp");
       NewPN->addIncoming(InV, PN->getIncomingBlock(i));
     }
-  } else if (I.getNumOperands() == 2) {
-    Constant *C = cast<Constant>(I.getOperand(1));
+  } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
     for (unsigned i = 0; i != NumPHIValues; ++i) {
-      Value *InV = nullptr;
-      if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) {
-        InV = ConstantExpr::get(I.getOpcode(), InC, C);
-      } else {
-        InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
-                                   PN->getIncomingValue(i), C, "phitmp");
-        auto *FPInst = dyn_cast<Instruction>(InV);
-        if (FPInst && isa<FPMathOperator>(FPInst))
-          FPInst->copyFastMathFlags(&I);
-      }
+      Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), this);
       NewPN->addIncoming(InV, PN->getIncomingBlock(i));
     }
   } else {
@@ -990,8 +1001,8 @@
   if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
     if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
       return NewSel;
-  } else if (isa<PHINode>(I.getOperand(0))) {
-    if (Instruction *NewPhi = FoldOpIntoPhi(I))
+  } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
+    if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
       return NewPhi;
   }
   return nullptr;