HotColdSplit: add back propagation to extend cold regions

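Also fix a bug in forward propagation: the terminator
  const TerminatorInst *TI = It->getTerminator();
was computed once, outside the while loop that updates It, so every iteration
walked the successors of the entry block instead of those of the current block.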

llvm-svn: 342275
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 8542f56..a6ea04e 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -74,6 +74,7 @@
 };
 
 typedef DenseSet<const BasicBlock *> DenseSetBB;
+typedef DenseMap<const BasicBlock *, uint64_t> DenseMapBBInt;
 
 // From: https://reviews.llvm.org/D22558
 // Exit is not part of the region.
@@ -132,35 +133,80 @@
 }
 
 static DenseSetBB getHotBlocks(Function &F) {
-  // First mark all function basic blocks as hot or cold.
-  DenseSet<const BasicBlock *> ColdBlocks;
+  // Returns the blocks of F that are still hot after both passes below.
+  // Seed the cold set with every block for which unlikelyExecuted() is true.
+  DenseSetBB ColdBlocks;
   for (BasicBlock &BB : F)
     if (unlikelyExecuted(BB))
-      ColdBlocks.insert(&BB);
-  // Forward propagation.
-  DenseSetBB AllColdBlocks;
+      ColdBlocks.insert((const BasicBlock *)&BB);
+
+  // Forward propagation: a block is hot when it is reachable from the entry
+  // block through a path that does not contain any cold block.
   SmallVector<const BasicBlock *, 8> WL;
-  DenseSetBB Visited; // Track hot blocks.
+  DenseSetBB HotBlocks;
 
   const BasicBlock *It = &F.front();
-  const TerminatorInst *TI = It->getTerminator();
   if (!ColdBlocks.count(It)) {
-    Visited.insert(It);
-    // Breadth First Search to mark edges not reachable from cold.
+    HotBlocks.insert(It);
+    // LIFO worklist traversal marking the blocks reachable from hot blocks.
     WL.push_back(It);
     while (WL.size() > 0) {
       It = WL.pop_back_val();
-      for (const BasicBlock *Succ : successors(TI)) {
+
+      for (const BasicBlock *Succ : successors(It)) {
         // Do not visit blocks that are cold.
-        if (!ColdBlocks.count(Succ) && !Visited.count(Succ)) {
-          Visited.insert(Succ);
+        if (!ColdBlocks.count(Succ) && !HotBlocks.count(Succ)) {
+          HotBlocks.insert(Succ);
           WL.push_back(Succ);
         }
       }
     }
   }
 
-  return Visited;
+  assert(WL.empty() && "work list should be empty");
+
+  DenseMapBBInt NumHotSuccessors;
+  // Back propagation: a hot block with no hot successor becomes cold as well.
+  // NOTE(review): this also demotes hot blocks with no successors; confirm.
+  for (BasicBlock &BBRef : F) {
+    const BasicBlock *BB = &BBRef;
+    if (HotBlocks.count(BB)) {
+      // Count the successors of this hot block that are not in ColdBlocks.
+      NumHotSuccessors[BB] = 0;
+      for (const BasicBlock *Succ : successors(BB))
+        if (!ColdBlocks.count(Succ))
+          NumHotSuccessors[BB] += 1;
+
+      // Blocks without any hot successor seed the work list: the loop below
+      // moves them from HotBlocks to ColdBlocks and then revisits their
+      // predecessors.
+      if (NumHotSuccessors[BB] == 0)
+        WL.push_back(BB);
+    }
+  }
+
+  while (WL.size() > 0) {
+    It = WL.pop_back_val();
+    if (ColdBlocks.count(It))
+      continue; // Skip blocks that are already in ColdBlocks.
+
+    // Move the block from HotBlocks to ColdBlocks.
+    HotBlocks.erase(It);
+    ColdBlocks.insert(It);
+
+    // Each predecessor that is still hot loses one hot successor.
+    for (const BasicBlock *Pred : predecessors(It)) {
+      if (HotBlocks.count(Pred)) {
+        NumHotSuccessors[Pred] -= 1;
+
+        // If Pred has no more hot successors, add it to the work list.
+        if (NumHotSuccessors[Pred] == 0)
+          WL.push_back(Pred);
+      }
+    }
+  }
+
+  return HotBlocks;
 }
 
 class HotColdSplitting {
@@ -175,7 +221,7 @@
 private:
   bool shouldOutlineFrom(const Function &F) const;
   Function *outlineColdBlocks(Function &F,
-                              const DenseSet<const BasicBlock *> &ColdBlock,
+                              const DenseSetBB &ColdBlock,
                               DominatorTree *DT, PostDomTree *PDT);
   Function *extractColdRegion(const SmallVectorImpl<BasicBlock *> &Region,
                               DominatorTree *DT, BlockFrequencyInfo *BFI,
@@ -250,9 +296,9 @@
                                     OptimizationRemarkEmitter &ORE) {
   LLVM_DEBUG(for (auto *BB : Region)
           llvm::dbgs() << "\nExtracting: " << *BB;);
+
   // TODO: Pass BFI and BPI to update profile information.
-  CodeExtractor CE(Region, DT, /*AggregateArgs*/ false, nullptr, nullptr,
-                   /* AllowVarargs */ false);
+  CodeExtractor CE(Region, DT);
 
   SetVector<Value *> Inputs, Outputs, Sinks;
   CE.findInputsOutputs(Inputs, Outputs, Sinks);
@@ -286,7 +332,7 @@
 
 // Return the function created after outlining, nullptr otherwise.
 Function *HotColdSplitting::outlineColdBlocks(Function &F,
-                                              const  DenseSetBB &HotBlock,
+                                              const DenseSetBB &HotBlocks,
                                               DominatorTree *DT,
                                               PostDomTree *PDT) {
   auto BFI = GetBFI(F);
@@ -296,7 +342,7 @@
   BasicBlock *Begin = DT->getRootNode()->getBlock();
   for (auto I = df_begin(Begin), E = df_end(Begin); I != E; ++I) {
     BasicBlock *BB = *I;
-    if (PSI->isColdBB(BB, BFI) || !HotBlock.count(BB)) {
+    if (PSI->isColdBB(BB, BFI) || !HotBlocks.count(BB)) {
       SmallVector<BasicBlock *, 4> ValidColdRegion, Region;
       auto *BBNode = (*PDT)[BB];
       auto Exit = BBNode->getIDom()->getBlock();