PGO: preserve branch-weight metadata when simplifying a switch with a single
case to a conditional branch and when removing dead cases.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@163942 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp
index dc6b506..2d89516 100644
--- a/lib/Transforms/Utils/Local.cpp
+++ b/lib/Transforms/Utils/Local.cpp
@@ -200,8 +200,20 @@
             "cond");
 
         // Insert the new branch.
-        Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(),
-                             SI->getDefaultDest());
+        BranchInst *NewBr = Builder.CreateCondBr(Cond,
+                                FirstCase.getCaseSuccessor(),
+                                SI->getDefaultDest());
+        MDNode* MD = SI->getMetadata(LLVMContext::MD_prof);
+        if (MD && MD->getNumOperands() == 3) {
+          ConstantInt *SICase = dyn_cast<ConstantInt>(MD->getOperand(2));
+          ConstantInt *SIDef = dyn_cast<ConstantInt>(MD->getOperand(1));
+          assert(SICase && SIDef);
+          // TrueWeight is SI's single-case weight; FalseWeight, the default's.
+          NewBr->setMetadata(LLVMContext::MD_prof,
+                 MDBuilder(BB->getContext()).
+                 createBranchWeights(SICase->getValue().getZExtValue(),
+                                     SIDef->getValue().getZExtValue()));
+        }
 
         // Delete the old switch.
         SI->eraseFromParent();
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 551df00..a9d74cd 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -667,13 +667,32 @@
     DEBUG(dbgs() << "Threading pred instr: " << *Pred->getTerminator()
                  << "Through successor TI: " << *TI);
 
+    // Collect branch weights into a vector; slot 0 is the default's weight.
+    SmallVector<uint32_t, 8> Weights;
+    MDNode* MD = SI->getMetadata(LLVMContext::MD_prof);
+    bool HasWeight = MD && (MD->getNumOperands() == 2 + SI->getNumCases());
+    if (HasWeight)
+      for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
+           ++MD_i) {
+        ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i));
+        assert(CI);
+        Weights.push_back(CI->getValue().getZExtValue());
+      }
     for (SwitchInst::CaseIt i = SI->case_end(), e = SI->case_begin(); i != e;) {
       --i;
       if (DeadCases.count(i.getCaseValue())) {
+        if (HasWeight) {
+          std::swap(Weights[i.getCaseIndex()+1], Weights.back());
+          Weights.pop_back();
+        }
         i.getCaseSuccessor()->removePredecessor(TI->getParent());
         SI->removeCase(i);
       }
     }
+    if (HasWeight)
+      SI->setMetadata(LLVMContext::MD_prof,
+                      MDBuilder(SI->getParent()->getContext()).
+                      createBranchWeights(Weights));
 
     DEBUG(dbgs() << "Leaving: " << *TI << "\n");
     return true;