Extend ScalarEvolution's executesAtLeastOnce logic to be able to
continue past the first conditional branch when looking for a
relevant test. This helps it avoid using MAX expressions in
loop trip counts in more cases.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@54697 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 00a4475..786212e 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -2709,66 +2709,68 @@
SCEV *LHS, SCEV *RHS) {
BasicBlock *Preheader = L->getLoopPreheader();
BasicBlock *PreheaderDest = L->getHeader();
- if (Preheader == 0) return false;
- BranchInst *LoopEntryPredicate =
- dyn_cast<BranchInst>(Preheader->getTerminator());
- if (!LoopEntryPredicate) return false;
+ // Starting at the preheader, climb up the predecessor chain, as long as
+ // there are unique predecessors, looking for a conditional branch that
+ // protects the loop.
+ //
+ // This is a conservative approximation of a climb of the
+ // control-dependence predecessors.
- // This might be a critical edge broken out. If the loop preheader ends in
- // an unconditional branch to the loop, check to see if the preheader has a
- // single predecessor, and if so, look for its terminator.
- while (LoopEntryPredicate->isUnconditional()) {
- PreheaderDest = Preheader;
- Preheader = Preheader->getSinglePredecessor();
- if (!Preheader) return false; // Multiple preds.
-
- LoopEntryPredicate =
+ for (; Preheader; PreheaderDest = Preheader,
+ Preheader = Preheader->getSinglePredecessor()) {
+
+ BranchInst *LoopEntryPredicate =
dyn_cast<BranchInst>(Preheader->getTerminator());
- if (!LoopEntryPredicate) return false;
+ if (!LoopEntryPredicate ||
+ LoopEntryPredicate->isUnconditional())
+ continue;
+
+ ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
+ if (!ICI) continue;
+
+ // Now that we found a conditional branch that dominates the loop, check to
+ // see if it is the comparison we are looking for.
+ Value *PreCondLHS = ICI->getOperand(0);
+ Value *PreCondRHS = ICI->getOperand(1);
+ ICmpInst::Predicate Cond;
+ if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
+ Cond = ICI->getPredicate();
+ else
+ Cond = ICI->getInversePredicate();
+
+ switch (Cond) {
+ case ICmpInst::ICMP_UGT:
+ if (isSigned) continue;
+ std::swap(PreCondLHS, PreCondRHS);
+ Cond = ICmpInst::ICMP_ULT;
+ break;
+ case ICmpInst::ICMP_SGT:
+ if (!isSigned) continue;
+ std::swap(PreCondLHS, PreCondRHS);
+ Cond = ICmpInst::ICMP_SLT;
+ break;
+ case ICmpInst::ICMP_ULT:
+ if (isSigned) continue;
+ break;
+ case ICmpInst::ICMP_SLT:
+ if (!isSigned) continue;
+ break;
+ default:
+ continue;
+ }
+
+ if (!PreCondLHS->getType()->isInteger()) continue;
+
+ SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
+ SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
+ if ((LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
+ (LHS == SE.getNotSCEV(PreCondRHSSCEV) &&
+ RHS == SE.getNotSCEV(PreCondLHSSCEV)))
+ return true;
}
- ICmpInst *ICI = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
- if (!ICI) return false;
-
- // Now that we found a conditional branch that dominates the loop, check to
- // see if it is the comparison we are looking for.
- Value *PreCondLHS = ICI->getOperand(0);
- Value *PreCondRHS = ICI->getOperand(1);
- ICmpInst::Predicate Cond;
- if (LoopEntryPredicate->getSuccessor(0) == PreheaderDest)
- Cond = ICI->getPredicate();
- else
- Cond = ICI->getInversePredicate();
-
- switch (Cond) {
- case ICmpInst::ICMP_UGT:
- if (isSigned) return false;
- std::swap(PreCondLHS, PreCondRHS);
- Cond = ICmpInst::ICMP_ULT;
- break;
- case ICmpInst::ICMP_SGT:
- if (!isSigned) return false;
- std::swap(PreCondLHS, PreCondRHS);
- Cond = ICmpInst::ICMP_SLT;
- break;
- case ICmpInst::ICMP_ULT:
- if (isSigned) return false;
- break;
- case ICmpInst::ICMP_SLT:
- if (!isSigned) return false;
- break;
- default:
- return false;
- }
-
- if (!PreCondLHS->getType()->isInteger()) return false;
-
- SCEVHandle PreCondLHSSCEV = getSCEV(PreCondLHS);
- SCEVHandle PreCondRHSSCEV = getSCEV(PreCondRHS);
- return (LHS == PreCondLHSSCEV && RHS == PreCondRHSSCEV) ||
- (LHS == SE.getNotSCEV(PreCondRHSSCEV) &&
- RHS == SE.getNotSCEV(PreCondLHSSCEV));
+ return false;
}
/// HowManyLessThans - Return the number of times a backedge containing the