[CodeExtractor] Update function's assumption cache after extracting blocks from it
Summary: The assumption cache's self-updating mechanism does not correctly handle the case when blocks are extracted from a function by the CodeExtractor. As a result, the function's assumption cache may have stale references to the llvm.assume calls that were moved to the outlined function. This patch fixes this problem by removing the extracted llvm.assume calls from the function's assumption cache.
Reviewers: hfinkel, vsk, fhahn, davidxl, sanjoy
Reviewed By: hfinkel, vsk
Subscribers: llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D57215
llvm-svn: 353500
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
index 648e2ae..b8def7a 100644
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -173,8 +173,9 @@
   HotColdSplitting(ProfileSummaryInfo *ProfSI,
                    function_ref<BlockFrequencyInfo *(Function &)> GBFI,
                    function_ref<TargetTransformInfo &(Function &)> GTTI,
-                   std::function<OptimizationRemarkEmitter &(Function &)> *GORE)
-      : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE) {}
+                   std::function<OptimizationRemarkEmitter &(Function &)> *GORE,
+                   function_ref<AssumptionCache *(Function &)> LAC)
+      : PSI(ProfSI), GetBFI(GBFI), GetTTI(GTTI), GetORE(GORE), LookupAC(LAC) {}
   bool run(Module &M);
 
 private:
@@ -183,11 +184,13 @@
   bool outlineColdRegions(Function &F, bool HasProfileSummary);
   Function *extractColdRegion(const BlockSequence &Region, DominatorTree &DT,
                               BlockFrequencyInfo *BFI, TargetTransformInfo &TTI,
-                              OptimizationRemarkEmitter &ORE, unsigned Count);
+                              OptimizationRemarkEmitter &ORE,
+                              AssumptionCache *AC, unsigned Count);
   ProfileSummaryInfo *PSI;
   function_ref<BlockFrequencyInfo *(Function &)> GetBFI;
   function_ref<TargetTransformInfo &(Function &)> GetTTI;
   std::function<OptimizationRemarkEmitter &(Function &)> *GetORE;
+  function_ref<AssumptionCache *(Function &)> LookupAC;
 };
 
 class HotColdSplittingLegacyPass : public ModulePass {
@@ -198,10 +201,10 @@
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
-    AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<BlockFrequencyInfoWrapperPass>();
     AU.addRequired<ProfileSummaryInfoWrapperPass>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addUsedIfAvailable<AssumptionCacheTracker>();
   }
 
   bool runOnModule(Module &M) override;
@@ -316,12 +319,13 @@
                                               BlockFrequencyInfo *BFI,
                                               TargetTransformInfo &TTI,
                                               OptimizationRemarkEmitter &ORE,
+                                              AssumptionCache *AC,
                                               unsigned Count) {
   assert(!Region.empty());
 
   // TODO: Pass BFI and BPI to update profile information.
   CodeExtractor CE(Region, &DT, /* AggregateArgs */ false, /* BFI */ nullptr,
-                   /* BPI */ nullptr, /* AllowVarArgs */ false,
+                   /* BPI */ nullptr, AC, /* AllowVarArgs */ false,
                    /* AllowAlloca */ false,
                    /* Suffix */ "cold." + std::to_string(Count));
 
@@ -577,6 +581,7 @@
 
   TargetTransformInfo &TTI = GetTTI(F);
   OptimizationRemarkEmitter &ORE = (*GetORE)(F);
+  AssumptionCache *AC = LookupAC(F);
 
   // Find all cold regions.
   for (BasicBlock *BB : RPOT) {
@@ -638,8 +643,8 @@
           BB->dump();
       });
 
-      Function *Outlined =
-          extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, OutlinedFunctionID);
+      Function *Outlined = extractColdRegion(SubRegion, *DT, BFI, TTI, ORE, AC,
+                                             OutlinedFunctionID);
       if (Outlined) {
         ++OutlinedFunctionID;
         Changed = true;
@@ -698,17 +703,21 @@
     ORE.reset(new OptimizationRemarkEmitter(&F));
     return *ORE.get();
   };
+  auto LookupAC = [this](Function &F) -> AssumptionCache * {
+    if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
+      return ACT->lookupAssumptionCache(F);
+    return nullptr;
+  };
 
-  return HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M);
+  return HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M);
 }
 
 PreservedAnalyses
 HotColdSplittingPass::run(Module &M, ModuleAnalysisManager &AM) {
   auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
 
-  std::function<AssumptionCache &(Function &)> GetAssumptionCache =
-      [&FAM](Function &F) -> AssumptionCache & {
-    return FAM.getResult<AssumptionAnalysis>(F);
+  auto LookupAC = [&FAM](Function &F) -> AssumptionCache * {
+    return FAM.getCachedResult<AssumptionAnalysis>(F);
   };
 
   auto GBFI = [&FAM](Function &F) {
@@ -729,7 +738,7 @@
 
   ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
 
-  if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE).run(M))
+  if (HotColdSplitting(PSI, GBFI, GTTI, &GetORE, LookupAC).run(M))
     return PreservedAnalyses::none();
   return PreservedAnalyses::all();
 }
diff --git a/llvm/lib/Transforms/IPO/LoopExtractor.cpp b/llvm/lib/Transforms/IPO/LoopExtractor.cpp
index 6e7e59a..91c7b5f 100644
--- a/llvm/lib/Transforms/IPO/LoopExtractor.cpp
+++ b/llvm/lib/Transforms/IPO/LoopExtractor.cpp
@@ -14,6 +14,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
@@ -50,6 +51,7 @@
       AU.addRequiredID(LoopSimplifyID);
       AU.addRequired<DominatorTreeWrapperPass>();
       AU.addRequired<LoopInfoWrapperPass>();
+      AU.addUsedIfAvailable<AssumptionCacheTracker>();
     }
   };
 }
@@ -138,7 +140,10 @@
   if (ShouldExtractLoop) {
     if (NumLoops == 0) return Changed;
     --NumLoops;
-    CodeExtractor Extractor(DT, *L);
+    AssumptionCache *AC = nullptr;
+    if (auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>())
+      AC = ACT->lookupAssumptionCache(*L->getHeader()->getParent());
+    CodeExtractor Extractor(DT, *L, false, nullptr, nullptr, AC);
     if (Extractor.extractCodeRegion() != nullptr) {
       Changed = true;
       // After extraction, the loop is replaced by a function call, so
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index f971cee..8339eb4 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -199,10 +199,12 @@
 
   PartialInlinerImpl(
       std::function<AssumptionCache &(Function &)> *GetAC,
+      function_ref<AssumptionCache *(Function &)> LookupAC,
       std::function<TargetTransformInfo &(Function &)> *GTTI,
       Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI,
       ProfileSummaryInfo *ProfSI)
-      : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
+      : GetAssumptionCache(GetAC), LookupAssumptionCache(LookupAC),
+        GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {}
 
   bool run(Module &M);
   // Main part of the transformation that calls helper functions to find
@@ -222,9 +224,11 @@
     // Two constructors, one for single region outlining, the other for
     // multi-region outlining.
     FunctionCloner(Function *F, FunctionOutliningInfo *OI,
-                   OptimizationRemarkEmitter &ORE);
+                   OptimizationRemarkEmitter &ORE,
+                   function_ref<AssumptionCache *(Function &)> LookupAC);
     FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI,
-                   OptimizationRemarkEmitter &ORE);
+                   OptimizationRemarkEmitter &ORE,
+                   function_ref<AssumptionCache *(Function &)> LookupAC);
     ~FunctionCloner();
 
     // Prepare for function outlining: making sure there is only
@@ -260,11 +264,13 @@
     std::unique_ptr<FunctionOutliningMultiRegionInfo> ClonedOMRI = nullptr;
     std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr;
     OptimizationRemarkEmitter &ORE;
+    function_ref<AssumptionCache *(Function &)> LookupAC;
   };
 
 private:
   int NumPartialInlining = 0;
   std::function<AssumptionCache &(Function &)> *GetAssumptionCache;
+  function_ref<AssumptionCache *(Function &)> LookupAssumptionCache;
   std::function<TargetTransformInfo &(Function &)> *GetTTI;
   Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI;
   ProfileSummaryInfo *PSI;
@@ -365,12 +371,17 @@
       return ACT->getAssumptionCache(F);
     };
 
+    auto LookupAssumptionCache = [ACT](Function &F) -> AssumptionCache * {
+      return ACT->lookupAssumptionCache(F);
+    };
+
     std::function<TargetTransformInfo &(Function &)> GetTTI =
         [&TTIWP](Function &F) -> TargetTransformInfo & {
       return TTIWP->getTTI(F);
     };
 
-    return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI)
+    return PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache,
+                              &GetTTI, NoneType::None, PSI)
         .run(M);
   }
 };
@@ -948,8 +959,9 @@
 }
 
 PartialInlinerImpl::FunctionCloner::FunctionCloner(
-    Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE)
-    : OrigFunc(F), ORE(ORE) {
+    Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE,
+    function_ref<AssumptionCache *(Function &)> LookupAC)
+    : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
   ClonedOI = llvm::make_unique<FunctionOutliningInfo>();
 
   // Clone the function, so that we can hack away on it.
@@ -972,8 +984,9 @@
 
 PartialInlinerImpl::FunctionCloner::FunctionCloner(
     Function *F, FunctionOutliningMultiRegionInfo *OI,
-    OptimizationRemarkEmitter &ORE)
-    : OrigFunc(F), ORE(ORE) {
+    OptimizationRemarkEmitter &ORE,
+    function_ref<AssumptionCache *(Function &)> LookupAC)
+    : OrigFunc(F), ORE(ORE), LookupAC(LookupAC) {
   ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>();
 
   // Clone the function, so that we can hack away on it.
@@ -1111,7 +1124,9 @@
     int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region);
 
     CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false,
-                     ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false);
+                     ClonedFuncBFI.get(), &BPI,
+                     LookupAC(*RegionInfo.EntryBlock->getParent()),
+                     /* AllowVarargs */ false);
 
     CE.findInputsOutputs(Inputs, Outputs, Sinks);
 
@@ -1193,7 +1208,7 @@
   // Extract the body of the if.
   Function *OutlinedFunc =
       CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false,
-                    ClonedFuncBFI.get(), &BPI,
+                    ClonedFuncBFI.get(), &BPI, LookupAC(*ClonedFunc),
                     /* AllowVarargs */ true)
           .extractCodeRegion();
 
@@ -1257,7 +1272,7 @@
     std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI =
         computeOutliningColdRegionsInfo(F, ORE);
     if (OMRI) {
-      FunctionCloner Cloner(F, OMRI.get(), ORE);
+      FunctionCloner Cloner(F, OMRI.get(), ORE, LookupAssumptionCache);
 
 #ifndef NDEBUG
       if (TracePartialInlining) {
@@ -1290,7 +1305,7 @@
   if (!OI)
     return {false, nullptr};
 
-  FunctionCloner Cloner(F, OI.get(), ORE);
+  FunctionCloner Cloner(F, OI.get(), ORE, LookupAssumptionCache);
   Cloner.NormalizeReturnBlock();
 
   Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining();
@@ -1484,6 +1499,10 @@
     return FAM.getResult<AssumptionAnalysis>(F);
   };
 
+  auto LookupAssumptionCache = [&FAM](Function &F) -> AssumptionCache * {
+    return FAM.getCachedResult<AssumptionAnalysis>(F);
+  };
+
   std::function<BlockFrequencyInfo &(Function &)> GetBFI =
       [&FAM](Function &F) -> BlockFrequencyInfo & {
     return FAM.getResult<BlockFrequencyAnalysis>(F);
@@ -1496,7 +1515,8 @@
 
   ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M);
 
-  if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI)
+  if (PartialInlinerImpl(&GetAssumptionCache, LookupAssumptionCache, &GetTTI,
+                         {GetBFI}, PSI)
           .run(M))
     return PreservedAnalyses::none();
   return PreservedAnalyses::all();