[SCEV] Cache results during GetMinTrailingZeros query
Differential Revision: https://reviews.llvm.org/D29759
llvm-svn: 295060
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4374735..5d5b9b3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4437,8 +4437,7 @@
return getGEPExpr(GEP, IndexExprs);
}
-uint32_t
-ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
return C->getAPInt().countTrailingZeros();
@@ -4448,14 +4447,16 @@
if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
- return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
- getTypeSizeInBits(E->getType()) : OpRes;
+ return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+ ? getTypeSizeInBits(E->getType())
+ : OpRes;
}
if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
- return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
- getTypeSizeInBits(E->getType()) : OpRes;
+ return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+ ? getTypeSizeInBits(E->getType())
+ : OpRes;
}
if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
@@ -4472,8 +4473,8 @@
uint32_t BitWidth = getTypeSizeInBits(M->getType());
for (unsigned i = 1, e = M->getNumOperands();
SumOpRes != BitWidth && i != e; ++i)
- SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
- BitWidth);
+ SumOpRes =
+ std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
return SumOpRes;
}
@@ -4514,6 +4515,17 @@
return 0;
}
+/// Memoizing wrapper: return the cached result for S, or compute and cache it.
+uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+  auto Cached = MinTrailingZerosCache.find(S);
+  if (Cached != MinTrailingZerosCache.end())
+    return Cached->second;
+  uint32_t Zeros = GetMinTrailingZerosImpl(S);
+  auto Inserted = MinTrailingZerosCache.insert({S, Zeros});
+  assert(Inserted.second && "Should insert a new key");
+  return Inserted.first->second;
+}
+
/// Helper method to assign a range to V from metadata present in the IR.
static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
if (Instruction *I = dyn_cast<Instruction>(V))
@@ -9510,6 +9522,7 @@
ValueExprMap(std::move(Arg.ValueExprMap)),
PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
+ MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
std::move(Arg.PredicatedBackedgeTakenCounts)),
@@ -9915,6 +9928,7 @@
SignedRanges.erase(S);
ExprValueMap.erase(S);
HasRecMap.erase(S);
+ MinTrailingZerosCache.erase(S);
auto RemoveSCEVFromBackedgeMap =
[S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {