Invoke getInlineCost for a legality check before inlining functions in SampleProfileLoader.

Summary: SampleProfileLoader inlines hot functions if they were inlined in the profiled binary. However, the inlining needs to be guarded by a legality check; otherwise it could lead to correctness issues.

Reviewers: eraman, davidxl

Reviewed By: eraman

Subscribers: vitalybuka, sanjoy, llvm-commits

Differential Revision: https://reviews.llvm.org/D37779

llvm-svn: 313277
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index f13ddee..ac20d2a 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -28,9 +28,11 @@
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationDiagnosticInfo.h"
 #include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -148,10 +150,12 @@
 public:
   SampleProfileLoader(
       StringRef Name,
-      std::function<AssumptionCache &(Function &)> GetAssumptionCache)
+      std::function<AssumptionCache &(Function &)> GetAssumptionCache,
+      std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo)
       : DT(nullptr), PDT(nullptr), LI(nullptr), GetAC(GetAssumptionCache),
-        Reader(), Samples(nullptr), Filename(Name), ProfileIsValid(false),
-        TotalCollectedSamples(0), ORE(nullptr) {}
+        GetTTI(GetTargetTransformInfo), Reader(), Samples(nullptr),
+        Filename(Name), ProfileIsValid(false), TotalCollectedSamples(0),
+        ORE(nullptr) {}
 
   bool doInitialization(Module &M);
   bool runOnModule(Module &M, ModuleAnalysisManager *AM);
@@ -225,6 +229,7 @@
   std::unique_ptr<LoopInfo> LI;
 
   std::function<AssumptionCache &(Function &)> GetAC;
+  std::function<TargetTransformInfo &(Function &)> GetTTI;
 
   /// \brief Predecessors for each basic block in the CFG.
   BlockEdgeMap Predecessors;
@@ -265,8 +270,11 @@
       : ModulePass(ID), SampleLoader(Name,
                                      [&](Function &F) -> AssumptionCache & {
                                        return ACT->getAssumptionCache(F);
+                                     },
+                                     [&](Function &F) -> TargetTransformInfo & {
+                                       return TTIWP->getTTI(F);
                                      }),
-        ACT(nullptr) {
+        ACT(nullptr), TTIWP(nullptr) {
     initializeSampleProfileLoaderLegacyPassPass(
         *PassRegistry::getPassRegistry());
   }
@@ -281,11 +289,13 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
   }
 
 private:
   SampleProfileLoader SampleLoader;
   AssumptionCacheTracker *ACT;
+  TargetTransformInfoWrapperPass *TTIWP;
 };
 
 /// Return true if the given callsite is hot wrt to its caller.
@@ -747,9 +757,25 @@
             Samples->getTotalSamples() * SampleProfileHotThreshold / 100);
         continue;
       }
+      CallSite CS(DI);
       DebugLoc DLoc = I->getDebugLoc();
       BasicBlock *BB = I->getParent();
-      if (InlineFunction(CallSite(DI), IFI)) {
+      InlineParams Params = getInlineParams();
+      Params.ComputeFullInlineCost = true;
+      // Checks if there is anything in the reachable portion of the callee at
+      // this callsite that makes this inlining potentially illegal. Need to
+      // set ComputeFullInlineCost, otherwise getInlineCost may return early
+      // when cost exceeds threshold without checking all IRs in the callee.
+      // The actual cost does not matter because we only check isNever() to
+      // see if it is legal to inline the callsite.
+      InlineCost Cost = getInlineCost(CS, Params, GetTTI(*CalledFunction), GetAC,
+                                      None, nullptr, nullptr);
+      if (Cost.isNever()) {
+        ORE->emit(OptimizationRemark(DEBUG_TYPE, "Not inline", DLoc, BB)
+                  << "incompatible inlining");
+        continue;
+      }
+      if (InlineFunction(CS, IFI)) {
         LocalChanged = true;
         // The call to InlineFunction erases DI, so we can't pass it here.
         ORE->emit(OptimizationRemark(DEBUG_TYPE, "HotInline", DLoc, BB)
@@ -1418,6 +1444,7 @@
 INITIALIZE_PASS_BEGIN(SampleProfileLoaderLegacyPass, "sample-profile",
                       "Sample Profile loader", false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile",
                     "Sample Profile loader", false, false)
 
@@ -1483,6 +1510,7 @@
 
 bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) {
   ACT = &getAnalysis<AssumptionCacheTracker>();
+  TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>();
   return SampleLoader.runOnModule(M, nullptr);
 }
 
@@ -1512,10 +1540,13 @@
   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
     return FAM.getResult<AssumptionAnalysis>(F);
   };
+  auto GetTTI = [&](Function &F) -> TargetTransformInfo & {
+    return FAM.getResult<TargetIRAnalysis>(F);
+  };
 
   SampleProfileLoader SampleLoader(ProfileFileName.empty() ? SampleProfileFile
                                                            : ProfileFileName,
-                                   GetAssumptionCache);
+                                   GetAssumptionCache, GetTTI);
 
   SampleLoader.doInitialization(M);