CodeExtractor: Add ability to preserve profile data.

Add the ability to estimate the entry count of the extracted function and
the branch probabilities of its exit branches when BlockFrequencyInfo and
BranchProbabilityInfo are supplied to the CodeExtractor.

Patch by River Riddle!

Differential Revision: https://reviews.llvm.org/D22744

llvm-svn: 277411
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 8d0bc036..c514c9c 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -17,6 +17,9 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/RegionIterator.h"
@@ -26,9 +29,11 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -119,23 +124,30 @@
   return buildExtractionBlockSet(R.block_begin(), R.block_end());
 }
 
-CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
-  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
-                             bool AggregateArgs)
-  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
 
-CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+                             BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())),
+      NumExitBlocks(~0U) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
-                             bool AggregateArgs)
-  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
-    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
+                             bool AggregateArgs, BlockFrequencyInfo *BFI,
+                             BranchProbabilityInfo *BPI)
+    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+      BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
 
 /// definedInRegion - Return true if the specified value is defined in the
 /// extracted region.
@@ -687,6 +699,51 @@
   }
 }
 
+void CodeExtractor::calculateNewCallTerminatorWeights(
+    BasicBlock *CodeReplacer,
+    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+    BranchProbabilityInfo *BPI) {
+  typedef BlockFrequencyInfoImplBase::Distribution Distribution;
+  typedef BlockFrequencyInfoImplBase::BlockNode BlockNode;
+
+  // Update the branch weights for the exit block.
+  TerminatorInst *TI = CodeReplacer->getTerminator();
+  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+  // Block Frequency distribution with dummy node.
+  Distribution BranchDist;
+
+  // Add each of the frequencies of the successors.
+  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+    BlockNode ExitNode(i);
+    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+    if (ExitFreq != 0)
+      BranchDist.addExit(ExitNode, ExitFreq);
+    else
+      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
+  }
+
+  // Check for no total weight.
+  if (BranchDist.Total == 0)
+    return;
+
+  // Normalize the distribution so that its weights fit in unsigned.
+  BranchDist.normalize();
+
+  // Create normalized branch weights and set the metadata.
+  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+    const auto &Weight = BranchDist.Weights[I];
+
+    // Get the weight and update the current BPI.
+    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+    BranchProbability BP(Weight.Amount, BranchDist.Total);
+    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
+  }
+  TI->setMetadata(
+      LLVMContext::MD_prof,
+      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
 Function *CodeExtractor::extractCodeRegion() {
   if (!isEligible())
     return nullptr;
@@ -697,6 +754,19 @@
   // block in the region.
   BasicBlock *header = *Blocks.begin();
 
+  // Calculate the entry frequency of the new function before we change the root
+  //   block.
+  BlockFrequency EntryFreq;
+  if (BFI) {
+    assert(BPI && "Both BPI and BFI are required to preserve profile info");
+    for (BasicBlock *Pred : predecessors(header)) {
+      if (Blocks.count(Pred))
+        continue;
+      EntryFreq +=
+          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+    }
+  }
+
   // If we have to split PHI nodes or the entry block, do so now.
   severSplitPHINodes(header);
 
@@ -720,12 +790,23 @@
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs);
 
+  // Calculate the exit blocks for the extracted region and the total exit
+  //  weights for each of those blocks.
+  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
   SmallPtrSet<BasicBlock *, 1> ExitBlocks;
-  for (BasicBlock *Block : Blocks)
+  for (BasicBlock *Block : Blocks) {
     for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
-         ++SI)
-      if (!Blocks.count(*SI))
+         ++SI) {
+      if (!Blocks.count(*SI)) {
+        // Update the branch weight for this successor.
+        if (BFI) {
+          BlockFrequency &BF = ExitWeights[*SI];
+          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
+        }
         ExitBlocks.insert(*SI);
+      }
+    }
+  }
   NumExitBlocks = ExitBlocks.size();
 
   // Construct new function based on inputs/outputs & add allocas for all defs.
@@ -734,10 +815,23 @@
                                             codeReplacer, oldFunction,
                                             oldFunction->getParent());
 
+  // Update the entry count of the function.
+  if (BFI) {
+    Optional<uint64_t> EntryCount =
+        BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+    if (EntryCount.hasValue())
+      newFunction->setEntryCount(EntryCount.getValue());
+    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+  }
+
   emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
   moveCodeToFunction(newFunction);
 
+  // Update the branch weights for the exit block.
+  if (BFI && NumExitBlocks > 1)
+    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
   // Loop over all of the PHI nodes in the header block, and change any
   // references to the old incoming edge to be the new incoming edge.
   for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {