Teach simplifycfg to recompute branch weights when merging some branches, and
to discard them when they cannot be recomputed. More work remains (see the new
TODO in FoldBranchToCommonDest), but it's a start!


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@147286 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 9604f5e..6d6ad66 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -18,6 +18,8 @@
 #include "llvm/GlobalVariable.h"
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
+#include "llvm/LLVMContext.h"
+#include "llvm/Metadata.h"
 #include "llvm/Type.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -1462,6 +1464,26 @@
   return true;
 }
 
+/// ExtractBranchMetadata - Given a conditional BranchInst, read the 32-bit
+/// "branch_weights" for its true/false edges into ProbTrue/ProbFalse and return
+/// true. Returns false, without touching either, on missing/malformed metadata.
+static bool ExtractBranchMetadata(BranchInst *BI,
+                                  APInt &ProbTrue, APInt &ProbFalse) {
+  assert(BI->isConditional() &&
+         "Looking for probabilities on unconditional branch?");
+  MDNode *ProfileData = BI->getMetadata(LLVMContext::MD_prof);
+  if (!ProfileData || ProfileData->getNumOperands() != 3) return false;
+  // Operand 0 is the kind tag; callers reuse it when rebuilding the node.
+  MDString *Kind = dyn_cast<MDString>(ProfileData->getOperand(0));
+  ConstantInt *CITrue = dyn_cast<ConstantInt>(ProfileData->getOperand(1));
+  ConstantInt *CIFalse = dyn_cast<ConstantInt>(ProfileData->getOperand(2));
+  if (!Kind || Kind->getString() != "branch_weights" || !CITrue || !CIFalse ||
+      CITrue->getBitWidth() != 32 || CIFalse->getBitWidth() != 32)
+    return false; // Widths must match for the caller's APInt arithmetic.
+  ProbTrue = CITrue->getValue();
+  ProbFalse = CIFalse->getValue();
+  return true;
+
 /// FoldBranchToCommonDest - If this basic block is simple enough, and if a
 /// predecessor branches to us and one of our successors, fold the block into
 /// the predecessor and use logical operations to pick the right destination.
@@ -1636,6 +1658,51 @@
       PBI->setSuccessor(1, FalseDest);
     }
 
+    // TODO: If BB is reachable from all paths through PredBlock, then we
+    // could replace PBI's branch probabilities with BI's.
+
+    // Merge probability data into PredBlock's branch.
+    APInt A, B, C, D;
+    if (ExtractBranchMetadata(PBI, C, D) && ExtractBranchMetadata(BI, A, B)) {
+      // bbA: br bbB (a% probability), bbC (b% prob.)
+      // bbB: br bbD (c% probability), bbC (d% prob.)
+      // --> bbA: br bbD ((a*c)% prob.), bbC ((b+a*d)% prob.)
+      //
+      // Probabilities aren't stored as ratios directly. Converting to
+      // probability-numerator form, we get:
+      // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
+
+      bool Overflow1 = false, Overflow2 = false, Overflow3 = false;
+      bool Overflow4 = false, Overflow5 = false, Overflow6 = false;
+      APInt ProbTrue = A.umul_ov(C, Overflow1);
+
+      APInt Tmp1 = A.umul_ov(D, Overflow2);
+      APInt Tmp2 = B.umul_ov(C, Overflow3);
+      APInt Tmp3 = B.umul_ov(D, Overflow4);
+      APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5);
+      APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6);
+
+      APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
+      ProbTrue = ProbTrue.udiv(GCD);
+      ProbFalse = ProbFalse.udiv(GCD);
+
+      if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 ||
+          Overflow6) {
+        DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI
+                     << "when merging with: " << *BI);
+        PBI->setMetadata(LLVMContext::MD_prof, NULL);
+      } else {
+        LLVMContext &Context = BI->getContext();
+        Value *Ops[3];
+        Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0);
+        Ops[1] = ConstantInt::get(Context, ProbTrue);
+        Ops[2] = ConstantInt::get(Context, ProbFalse);
+        PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops));
+      }
+    } else {
+      PBI->setMetadata(LLVMContext::MD_prof, NULL);
+    }
+
     // Copy any debug value intrinsics into the end of PredBlock.
     for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I)
       if (isa<DbgInfoIntrinsic>(*I))