Lower consecutive select instructions correctly.

Summary: If consecutive select instructions are lowered separately in CGP, it will introduce redundant condition checks and branches that cannot be removed by later optimization phases. This patch lowers all consecutive select instructions at the same time to avoid inefficient code as demonstrated in https://llvm.org/bugs/show_bug.cgi?id=29095

Reviewers: davidxl

Subscribers: vsk, llvm-commits

Differential Revision: https://reviews.llvm.org/D24147

llvm-svn: 281252
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 3bdf60c..fc27f0e 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -4578,10 +4578,45 @@
   return false;
 }
 
+/// If \p isTrue is true, return the true value of \p SI, otherwise return
+/// the false value of \p SI. While that value is itself a select in
+/// \p Selects, look through it (taking the same true/false side) until the
+/// value is no longer defined by a select in \p Selects.
+static Value *getTrueOrFalseValue(
+    SelectInst *SI, bool isTrue,
+    const SmallPtrSet<const Instruction *, 2> &Selects) {
+  Value *V = isTrue ? SI->getTrueValue() : SI->getFalseValue();
+
+  for (auto *DefSI = dyn_cast<SelectInst>(V); DefSI && Selects.count(DefSI);
+       DefSI = dyn_cast<SelectInst>(V)) {
+    assert(DefSI->getCondition() == SI->getCondition() &&
+           "The condition of DefSI does not match with SI");
+    V = (isTrue ? DefSI->getTrueValue() : DefSI->getFalseValue());
+  }
+  return V;
+}
 
 /// If we have a SelectInst that will likely profit from branch prediction,
 /// turn it into a branch.
 bool CodeGenPrepare::optimizeSelectInst(SelectInst *SI) {
+  // Find all consecutive select instructions that share the same condition.
+  SmallVector<SelectInst *, 2> ASI;
+  ASI.push_back(SI);
+  for (BasicBlock::iterator It = ++BasicBlock::iterator(SI);
+       It != SI->getParent()->end(); ++It) {
+    SelectInst *I = dyn_cast<SelectInst>(&*It);
+    if (I && SI->getCondition() == I->getCondition()) {
+      ASI.push_back(I);
+    } else {
+      break;
+    }
+  }
+
+  SelectInst *LastSI = ASI.back();
+  // Increment the current iterator to skip all the rest of select instructions
+  // because they will be either "not lowered" or "all lowered" to branch.
+  CurInstIterator = std::next(LastSI->getIterator());
+
   bool VectorCond = !SI->getCondition()->getType()->isIntegerTy(1);
 
   // Can we convert the 'select' to CF ?
@@ -4628,7 +4663,7 @@
 
   // First, we split the block containing the select into 2 blocks.
   BasicBlock *StartBlock = SI->getParent();
-  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(SI));
+  BasicBlock::iterator SplitPt = ++(BasicBlock::iterator(LastSI));
   BasicBlock *EndBlock = StartBlock->splitBasicBlock(SplitPt, "select.end");
 
   // Delete the unconditional branch that was just created by the split.
@@ -4638,22 +4673,30 @@
   // At least one will become an actual new basic block.
   BasicBlock *TrueBlock = nullptr;
   BasicBlock *FalseBlock = nullptr;
+  BranchInst *TrueBranch = nullptr;
+  BranchInst *FalseBranch = nullptr;
 
   // Sink expensive instructions into the conditional blocks to avoid executing
   // them speculatively.
-  if (sinkSelectOperand(TTI, SI->getTrueValue())) {
-    TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
-                                   EndBlock->getParent(), EndBlock);
-    auto *TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
-    auto *TrueInst = cast<Instruction>(SI->getTrueValue());
-    TrueInst->moveBefore(TrueBranch);
-  }
-  if (sinkSelectOperand(TTI, SI->getFalseValue())) {
-    FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
-                                    EndBlock->getParent(), EndBlock);
-    auto *FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
-    auto *FalseInst = cast<Instruction>(SI->getFalseValue());
-    FalseInst->moveBefore(FalseBranch);
+  for (SelectInst *SI : ASI) {
+    if (sinkSelectOperand(TTI, SI->getTrueValue())) {
+      if (TrueBlock == nullptr) {
+        TrueBlock = BasicBlock::Create(SI->getContext(), "select.true.sink",
+                                       EndBlock->getParent(), EndBlock);
+        TrueBranch = BranchInst::Create(EndBlock, TrueBlock);
+      }
+      auto *TrueInst = cast<Instruction>(SI->getTrueValue());
+      TrueInst->moveBefore(TrueBranch);
+    }
+    if (sinkSelectOperand(TTI, SI->getFalseValue())) {
+      if (FalseBlock == nullptr) {
+        FalseBlock = BasicBlock::Create(SI->getContext(), "select.false.sink",
+                                        EndBlock->getParent(), EndBlock);
+        FalseBranch = BranchInst::Create(EndBlock, FalseBlock);
+      }
+      auto *FalseInst = cast<Instruction>(SI->getFalseValue());
+      FalseInst->moveBefore(FalseBranch);
+    }
   }
 
   // If there was nothing to sink, then arbitrarily choose the 'false' side
@@ -4687,18 +4730,27 @@
   }
   IRBuilder<>(SI).CreateCondBr(SI->getCondition(), TT, FT, SI);
 
-  // The select itself is replaced with a PHI Node.
-  PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
-  PN->takeName(SI);
-  PN->addIncoming(SI->getTrueValue(), TrueBlock);
-  PN->addIncoming(SI->getFalseValue(), FalseBlock);
+  SmallPtrSet<const Instruction *, 2> INS;
+  INS.insert(ASI.begin(), ASI.end());
+  // Use reverse iterator because later select may use the value of the
+  // earlier select, and we need to propagate value through earlier select
+  // to get the PHI operand.
+  for (auto It = ASI.rbegin(); It != ASI.rend(); ++It) {
+    SelectInst *SI = *It;
+    // The select itself is replaced with a PHI Node.
+    PHINode *PN = PHINode::Create(SI->getType(), 2, "", &EndBlock->front());
+    PN->takeName(SI);
+    PN->addIncoming(getTrueOrFalseValue(SI, true, INS), TrueBlock);
+    PN->addIncoming(getTrueOrFalseValue(SI, false, INS), FalseBlock);
 
-  SI->replaceAllUsesWith(PN);
-  SI->eraseFromParent();
+    SI->replaceAllUsesWith(PN);
+    SI->eraseFromParent();
+    INS.erase(SI);
+    ++NumSelectsExpanded;
+  }
 
   // Instruct OptimizeBlock to skip to the next block.
   CurInstIterator = StartBlock->end();
-  ++NumSelectsExpanded;
   return true;
 }