[JumpThreading] Fix exponential-time algorithm for computing known values.
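ComputeValueKnownInPredecessors keeps a "visited" set to prevent infinite
loops, since a value can be visited more than once. However, the
implementation didn't prevent the algorithm from taking exponential
time: each (value, block) pair was removed from the RecursionSet as soon
as its recursive call returned, so a pair reachable along many different
use-def paths was re-explored once per path. Instead of removing elements
from the RecursionSet one at a time, we should keep the whole set around
until the top-level call to ComputeValueKnownInPredecessors finishes,
then discard it.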
The testcase is synthetic because I had trouble reducing the original
effectively, but it exercises essentially the same pattern.
Instead of giving up (returning false) when a pair is revisited, we could
theoretically cache its result. But I don't think that would help
substantially in practice.
Differential Revision: https://reviews.llvm.org/D54239
llvm-svn: 346562
diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
index 849ff71..7f2d769 100644
--- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp
@@ -574,9 +574,11 @@
/// BB in the result vector.
///
/// This returns true if there were any known values.
-bool JumpThreadingPass::ComputeValueKnownInPredecessors(
+bool JumpThreadingPass::ComputeValueKnownInPredecessorsImpl(
Value *V, BasicBlock *BB, PredValueInfo &Result,
- ConstantPreference Preference, Instruction *CxtI) {
+ ConstantPreference Preference,
+ DenseSet<std::pair<Value *, BasicBlock *>> &RecursionSet,
+ Instruction *CxtI) {
// This method walks up use-def chains recursively. Because of this, we could
// get into an infinite loop going around loops in the use-def chain. To
// prevent this, keep track of what (value, block) pairs we've already visited
@@ -584,10 +586,6 @@
if (!RecursionSet.insert(std::make_pair(V, BB)).second)
return false;
- // An RAII help to remove this pair from the recursion set once the recursion
- // stack pops back out again.
- RecursionSetRemover remover(RecursionSet, std::make_pair(V, BB));
-
// If V is a constant, then it is known in all predecessors.
if (Constant *KC = getKnownConstant(V, Preference)) {
for (BasicBlock *Pred : predecessors(BB))
@@ -657,7 +655,8 @@
Value *Source = CI->getOperand(0);
if (!isa<PHINode>(Source) && !isa<CmpInst>(Source))
return false;
- ComputeValueKnownInPredecessors(Source, BB, Result, Preference, CxtI);
+ ComputeValueKnownInPredecessorsImpl(Source, BB, Result, Preference,
+ RecursionSet, CxtI);
if (Result.empty())
return false;
@@ -677,10 +676,10 @@
I->getOpcode() == Instruction::And) {
PredValueInfoTy LHSVals, RHSVals;
- ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
- WantInteger, CxtI);
- ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals,
- WantInteger, CxtI);
+ ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, LHSVals,
+ WantInteger, RecursionSet, CxtI);
+ ComputeValueKnownInPredecessorsImpl(I->getOperand(1), BB, RHSVals,
+ WantInteger, RecursionSet, CxtI);
if (LHSVals.empty() && RHSVals.empty())
return false;
@@ -715,8 +714,8 @@
if (I->getOpcode() == Instruction::Xor &&
isa<ConstantInt>(I->getOperand(1)) &&
cast<ConstantInt>(I->getOperand(1))->isOne()) {
- ComputeValueKnownInPredecessors(I->getOperand(0), BB, Result,
- WantInteger, CxtI);
+ ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, Result,
+ WantInteger, RecursionSet, CxtI);
if (Result.empty())
return false;
@@ -733,8 +732,8 @@
&& "A binary operator creating a block address?");
if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
PredValueInfoTy LHSVals;
- ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals,
- WantInteger, CxtI);
+ ComputeValueKnownInPredecessorsImpl(BO->getOperand(0), BB, LHSVals,
+ WantInteger, RecursionSet, CxtI);
// Try to use constant folding to simplify the binary operator.
for (const auto &LHSVal : LHSVals) {
@@ -879,8 +878,8 @@
// Try to find a constant value for the LHS of a comparison,
// and evaluate it statically if we can.
PredValueInfoTy LHSVals;
- ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
- WantInteger, CxtI);
+ ComputeValueKnownInPredecessorsImpl(I->getOperand(0), BB, LHSVals,
+ WantInteger, RecursionSet, CxtI);
for (const auto &LHSVal : LHSVals) {
Constant *V = LHSVal.first;
@@ -900,8 +899,8 @@
Constant *FalseVal = getKnownConstant(SI->getFalseValue(), Preference);
PredValueInfoTy Conds;
if ((TrueVal || FalseVal) &&
- ComputeValueKnownInPredecessors(SI->getCondition(), BB, Conds,
- WantInteger, CxtI)) {
+ ComputeValueKnownInPredecessorsImpl(SI->getCondition(), BB, Conds,
+ WantInteger, RecursionSet, CxtI)) {
for (auto &C : Conds) {
Constant *Cond = C.first;