[SimpleLoopUnswitch] Implement handling of prof branch_weights metadata for SwitchInst
Differential Revision: https://reviews.llvm.org/D60606
llvm-svn: 364734
diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
index 9fba159..cb78240 100644
--- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -594,11 +594,13 @@
ExitCaseIndices.push_back(Case.getCaseIndex());
}
BasicBlock *DefaultExitBB = nullptr;
+ SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight =
+ SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0);
if (!L.contains(SI.getDefaultDest()) &&
areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) &&
- !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator()))
+ !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) {
DefaultExitBB = SI.getDefaultDest();
- else if (ExitCaseIndices.empty())
+ } else if (ExitCaseIndices.empty())
return false;
LLVM_DEBUG(dbgs() << " unswitching trivial switch...\n");
@@ -622,8 +624,11 @@
// Store the exit cases into a separate data structure and remove them from
// the switch.
- SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
+ SmallVector<std::tuple<ConstantInt *, BasicBlock *,
+ SwitchInstProfUpdateWrapper::CaseWeightOpt>,
+ 4> ExitCases;
ExitCases.reserve(ExitCaseIndices.size());
+ SwitchInstProfUpdateWrapper SIW(SI);
// We walk the case indices backwards so that we remove the last case first
// and don't disrupt the earlier indices.
for (unsigned Index : reverse(ExitCaseIndices)) {
@@ -633,9 +638,10 @@
if (!ExitL || ExitL->contains(OuterL))
OuterL = ExitL;
// Save the value of this case.
- ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
+ auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex());
+ ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W);
// Delete the unswitched cases.
- SI.removeCase(CaseI);
+ SIW.removeCase(CaseI);
}
if (SE) {
@@ -673,6 +679,7 @@
// Now add the unswitched switch.
auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH);
+ SwitchInstProfUpdateWrapper NewSIW(*NewSI);
// Rewrite the IR for the unswitched basic blocks. This requires two steps.
// First, we split any exit blocks with remaining in-loop predecessors. Then
@@ -700,9 +707,9 @@
}
// Note that we must use a reference in the for loop so that we update the
// container.
- for (auto &CasePair : reverse(ExitCases)) {
+ for (auto &ExitCase : reverse(ExitCases)) {
// Grab a reference to the exit block in the pair so that we can update it.
- BasicBlock *ExitBB = CasePair.second;
+ BasicBlock *ExitBB = std::get<1>(ExitCase);
// If this case is the last edge into the exit block, we can simply reuse it
// as it will no longer be a loop exit. No mapping necessary.
@@ -724,27 +731,39 @@
/*FullUnswitch*/ true);
}
// Update the case pair to point to the split block.
- CasePair.second = SplitExitBB;
+ std::get<1>(ExitCase) = SplitExitBB;
}
// Now add the unswitched cases. We do this in reverse order as we built them
// in reverse order.
- for (auto CasePair : reverse(ExitCases)) {
- ConstantInt *CaseVal = CasePair.first;
- BasicBlock *UnswitchedBB = CasePair.second;
+ for (auto &ExitCase : reverse(ExitCases)) {
+ ConstantInt *CaseVal = std::get<0>(ExitCase);
+ BasicBlock *UnswitchedBB = std::get<1>(ExitCase);
- NewSI->addCase(CaseVal, UnswitchedBB);
+ NewSIW.addCase(CaseVal, UnswitchedBB, std::get<2>(ExitCase));
}
// If the default was unswitched, re-point it and add explicit cases for
// entering the loop.
if (DefaultExitBB) {
- NewSI->setDefaultDest(DefaultExitBB);
+ NewSIW->setDefaultDest(DefaultExitBB);
+ NewSIW.setSuccessorWeight(0, DefaultCaseWeight);
// We removed all the exit cases, so we just copy the cases to the
// unswitched switch.
- for (auto Case : SI.cases())
- NewSI->addCase(Case.getCaseValue(), NewPH);
+ for (const auto &Case : SI.cases())
+ NewSIW.addCase(Case.getCaseValue(), NewPH,
+ SIW.getSuccessorWeight(Case.getSuccessorIndex()));
+ } else if (DefaultCaseWeight) {
+ // New default also absorbs the cases left on SI; add their weights to it.
+ uint64_t SW = *DefaultCaseWeight;
+ for (const auto &Case : SI.cases()) {
+ auto W = SIW.getSuccessorWeight(Case.getSuccessorIndex());
+ assert(W &&
+ "case weight must be defined as default case weight is defined");
+ SW += *W;
+ }
+ NewSIW.setSuccessorWeight(0, SW);
}
// If we ended up with a common successor for every path through the switch
@@ -769,7 +788,7 @@
/*KeepOneInputPHIs*/ true);
}
// Now nuke the switch and replace it with a direct branch.
- SI.eraseFromParent();
+ SIW.eraseFromParent();
BranchInst::Create(CommonSuccBB, BB);
} else if (DefaultExitBB) {
assert(SI.getNumCases() > 0 &&
@@ -779,8 +798,11 @@
// being simple and keeping the number of edges from this switch to
// successors the same, and avoiding any PHI update complexity.
auto LastCaseI = std::prev(SI.case_end());
+
SI.setDefaultDest(LastCaseI->getCaseSuccessor());
- SI.removeCase(LastCaseI);
+ SIW.setSuccessorWeight(
+ 0, SIW.getSuccessorWeight(LastCaseI->getSuccessorIndex()));
+ SIW.removeCase(LastCaseI);
}
// Walk the unswitched exit blocks and the unswitched split blocks and update