PGO: preserve branch-weight metadata when simplifying a switch into a sub, an
icmp, and a conditional branch, and also when removing dead cases from a switch.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@164084 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index 1316b85..3365c2f 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -2853,9 +2853,28 @@
   if (!Offset->isNullValue())
     Sub = Builder.CreateAdd(Sub, Offset, Sub->getName()+".off");
   Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
-  Builder.CreateCondBr(
+  BranchInst *NewBI = Builder.CreateCondBr(
       Cmp, SI->case_begin().getCaseSuccessor(), SI->getDefaultDest());
 
+  // Update weight for the newly-created conditional branch.
+  SmallVector<uint64_t, 8> Weights;
+  bool HasWeights = HasBranchWeights(SI);
+  if (HasWeights) {
+    GetBranchWeights(SI, Weights);
+    if (Weights.size() == 1 + SI->getNumCases()) {
+      // Combine all weights for the cases to be the true weight of NewBI.
+      // We assume that the sum of all weights for a Terminator can fit into 32
+      // bits.
+      uint32_t NewTrueWeight = 0;
+      for (unsigned I = 1, E = Weights.size(); I != E; ++I)
+        NewTrueWeight += (uint32_t)Weights[I];
+      NewBI->setMetadata(LLVMContext::MD_prof,
+                         MDBuilder(SI->getContext()).
+                         createBranchWeights(NewTrueWeight,
+                                             (uint32_t)Weights[0]));
+    }
+  }
+
   // Prune obsolete incoming values off the successor's PHI nodes.
   for (BasicBlock::iterator BBI = SI->case_begin().getCaseSuccessor()->begin();
        isa<PHINode>(BBI); ++BBI) {
@@ -2886,15 +2905,33 @@
     }
   }
 
+  SmallVector<uint64_t, 8> Weights;
+  bool HasWeight = HasBranchWeights(SI);
+  if (HasWeight) {
+    GetBranchWeights(SI, Weights);
+    HasWeight = (Weights.size() == 1 + SI->getNumCases());
+  }
+
   // Remove dead cases from the switch.
   for (unsigned I = 0, E = DeadCases.size(); I != E; ++I) {
     SwitchInst::CaseIt Case = SI->findCaseValue(DeadCases[I]);
     assert(Case != SI->case_default() &&
            "Case was not found. Probably mistake in DeadCases forming.");
+    if (HasWeight) {
+      std::swap(Weights[Case.getCaseIndex()+1], Weights.back());
+      Weights.pop_back();
+    }
+
     // Prune unused values from PHI nodes.
     Case.getCaseSuccessor()->removePredecessor(SI->getParent());
     SI->removeCase(Case);
   }
+  if (HasWeight) {
+    SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
+    SI->setMetadata(LLVMContext::MD_prof,
+                    MDBuilder(SI->getParent()->getContext()).
+                    createBranchWeights(MDWeights));
+  }
 
   return !DeadCases.empty();
 }