[SCEV] Cache results during GetMinTrailingZeros query

Differential Revision: https://reviews.llvm.org/D29759

llvm-svn: 295060
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4374735..5d5b9b3 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4437,8 +4437,7 @@
   return getGEPExpr(GEP, IndexExprs);
 }
 
-uint32_t
-ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
     return C->getAPInt().countTrailingZeros();
 
@@ -4448,14 +4447,16 @@
 
   if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
-    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
-             getTypeSizeInBits(E->getType()) : OpRes;
+    return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+               ? getTypeSizeInBits(E->getType())
+               : OpRes;
   }
 
   if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
     uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
-    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
-             getTypeSizeInBits(E->getType()) : OpRes;
+    return OpRes == getTypeSizeInBits(E->getOperand()->getType())
+               ? getTypeSizeInBits(E->getType())
+               : OpRes;
   }
 
   if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
@@ -4472,8 +4473,8 @@
     uint32_t BitWidth = getTypeSizeInBits(M->getType());
     for (unsigned i = 1, e = M->getNumOperands();
          SumOpRes != BitWidth && i != e; ++i)
-      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
-                          BitWidth);
+      SumOpRes =
+          std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth);
     return SumOpRes;
   }
 
@@ -4514,6 +4515,17 @@
   return 0;
 }
 
+/// Memoizing front end for GetMinTrailingZerosImpl, keyed by SCEV pointer.
+uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
+  auto CachedIt = MinTrailingZerosCache.find(S);
+  if (CachedIt != MinTrailingZerosCache.end())
+    return CachedIt->second;
+  // Impl recurses into this function and may add entries, so insert afresh.
+  auto Inserted = MinTrailingZerosCache.insert({S, GetMinTrailingZerosImpl(S)});
+  assert(Inserted.second && "Should insert a new key");
+  return Inserted.first->second;
+}
+
 /// Helper method to assign a range to V from metadata present in the IR.
 static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
   if (Instruction *I = dyn_cast<Instruction>(V))
@@ -9510,6 +9522,7 @@
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
       WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
+      MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
           std::move(Arg.PredicatedBackedgeTakenCounts)),
@@ -9915,6 +9928,7 @@
   SignedRanges.erase(S);
   ExprValueMap.erase(S);
   HasRecMap.erase(S);
+  MinTrailingZerosCache.erase(S);
 
   auto RemoveSCEVFromBackedgeMap =
       [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {