PGO: preserve branch-weight metadata when simplifying SwitchOnSelect.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@164068 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 5a8d2e4..67cb8f5 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -2213,7 +2213,9 @@
 // Also makes sure not to introduce new successors by assuming that edges to
 // non-successor TrueBBs and FalseBBs aren't reachable.
 static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond,
-                                       BasicBlock *TrueBB, BasicBlock *FalseBB){
+                                       BasicBlock *TrueBB, BasicBlock *FalseBB,
+                                       uint32_t TrueWeight,
+                                       uint32_t FalseWeight){
   // Remove any superfluous successor edges from the CFG.
   // First, figure out which successors to preserve.
   // If TrueBB and FalseBB are equal, only try to preserve one copy of that
@@ -2242,10 +2244,15 @@
       // We were only looking for one successor, and it was present.
       // Create an unconditional branch to it.
       Builder.CreateBr(TrueBB);
-    else
+    else {
       // We found both of the successors we were looking for.
       // Create a conditional branch sharing the condition of the select.
-      Builder.CreateCondBr(Cond, TrueBB, FalseBB);
+      BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB);
+      if (TrueWeight != FalseWeight)
+        NewBI->setMetadata(LLVMContext::MD_prof,
+                           MDBuilder(OldTerm->getContext()).
+                           createBranchWeights(TrueWeight, FalseWeight));
+    }
   } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) {
     // Neither of the selected blocks were successors, so this
     // terminator must be unreachable.
@@ -2282,8 +2289,23 @@
   BasicBlock *TrueBB = SI->findCaseValue(TrueVal).getCaseSuccessor();
   BasicBlock *FalseBB = SI->findCaseValue(FalseVal).getCaseSuccessor();
 
+  // Look up the branch weights for TrueBB and FalseBB, if SI has any.
+  uint32_t TrueWeight = 0, FalseWeight = 0;
+  SmallVector<uint64_t, 8> Weights;
+  bool HasWeights = HasBranchWeights(SI);
+  if (HasWeights) {
+    GetBranchWeights(SI, Weights);
+    if (Weights.size() == 1 + SI->getNumCases()) {
+      TrueWeight = (uint32_t)Weights[SI->findCaseValue(TrueVal).
+                                     getSuccessorIndex()];
+      FalseWeight = (uint32_t)Weights[SI->findCaseValue(FalseVal).
+                                      getSuccessorIndex()];
+    }
+  }
+
   // Perform the actual simplification.
-  return SimplifyTerminatorOnSelect(SI, Condition, TrueBB, FalseBB);
+  return SimplifyTerminatorOnSelect(SI, Condition, TrueBB, FalseBB,
+                                    TrueWeight, FalseWeight);
 }
 
 // SimplifyIndirectBrOnSelect - Replaces
@@ -2303,7 +2325,8 @@
   BasicBlock *FalseBB = FBA->getBasicBlock();
 
   // Perform the actual simplification.
-  return SimplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB);
+  return SimplifyTerminatorOnSelect(IBI, SI->getCondition(), TrueBB, FalseBB,
+                                    0, 0);
 }
 
 /// TryToSimplifyUncondBranchWithICmpInIt - This is called when we find an icmp