Recommit r354930 "[PGO] Context sensitive PGO (part 1)"

Fixed UBSan failures.

llvm-svn: 355005
diff --git a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
index e77427a..bf91e6d 100644
--- a/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
+++ b/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
@@ -18,6 +18,8 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Triple.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/Attributes.h"
@@ -147,8 +149,8 @@
   static char ID;
 
   InstrProfilingLegacyPass() : ModulePass(ID) {}
-  InstrProfilingLegacyPass(const InstrProfOptions &Options)
-      : ModulePass(ID), InstrProf(Options) {}
+  InstrProfilingLegacyPass(const InstrProfOptions &Options, bool IsCS = false)
+      : ModulePass(ID), InstrProf(Options, IsCS) {}
 
   StringRef getPassName() const override {
     return "Frontend instrumentation-based coverage lowering";
@@ -232,9 +234,9 @@
 public:
   PGOCounterPromoter(
       DenseMap<Loop *, SmallVector<LoadStorePair, 8>> &LoopToCands,
-      Loop &CurLoop, LoopInfo &LI)
+      Loop &CurLoop, LoopInfo &LI, BlockFrequencyInfo *BFI)
       : LoopToCandidates(LoopToCands), ExitBlocks(), InsertPts(), L(CurLoop),
-        LI(LI) {
+        LI(LI), BFI(BFI) {
 
     SmallVector<BasicBlock *, 8> LoopExitBlocks;
     SmallPtrSet<BasicBlock *, 8> BlockSet;
@@ -263,6 +265,20 @@
       SSAUpdater SSA(&NewPHIs);
       Value *InitVal = ConstantInt::get(Cand.first->getType(), 0);
 
+      // If BFI is set, we will use it to guide the promotions.
+      if (BFI) {
+        auto *BB = Cand.first->getParent();
+        auto InstrCount = BFI->getBlockProfileCount(BB);
+        if (!InstrCount)
+          continue;
+        auto PreheaderCount = BFI->getBlockProfileCount(L.getLoopPreheader());
+        // If the average loop trip count is not greater than 1.5, we skip
+        // promotion.
+        if (PreheaderCount &&
+            (PreheaderCount.getValue() * 3) >= (InstrCount.getValue() * 2))
+          continue;
+      }
+
       PGOCounterPromoterHelper Promoter(Cand.first, Cand.second, SSA, InitVal,
                                         L.getLoopPreheader(), ExitBlocks,
                                         InsertPts, LoopToCandidates, LI);
@@ -312,6 +328,11 @@
 
     SmallVector<BasicBlock *, 8> ExitingBlocks;
     LP->getExitingBlocks(ExitingBlocks);
+
+    // If BFI is set, we do more aggressive promotions based on BFI.
+    if (BFI)
+      return (unsigned)-1;
+
     // Not considierered speculative.
     if (ExitingBlocks.size() == 1)
       return MaxNumOfPromotionsPerLoop;
@@ -343,6 +364,7 @@
   SmallVector<Instruction *, 8> InsertPts;
   Loop &L;
   LoopInfo &LI;
+  BlockFrequencyInfo *BFI;
 };
 
 } // end anonymous namespace
@@ -365,8 +387,9 @@
     "Frontend instrumentation-based coverage lowering.", false, false)
 
 ModulePass *
-llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options) {
-  return new InstrProfilingLegacyPass(Options);
+llvm::createInstrProfilingLegacyPass(const InstrProfOptions &Options,
+                                     bool IsCS) {
+  return new InstrProfilingLegacyPass(Options, IsCS);
 }
 
 static InstrProfIncrementInst *castToIncrementInst(Instruction *Instr) {
@@ -415,6 +438,13 @@
   LoopInfo LI(DT);
   DenseMap<Loop *, SmallVector<LoadStorePair, 8>> LoopPromotionCandidates;
 
+  std::unique_ptr<BlockFrequencyInfo> BFI;
+  if (Options.UseBFIInPromotion) {
+    std::unique_ptr<BranchProbabilityInfo> BPI;
+    BPI.reset(new BranchProbabilityInfo(*F, LI, TLI));
+    BFI.reset(new BlockFrequencyInfo(*F, *BPI, LI));
+  }
+
   for (const auto &LoadStore : PromotionCandidates) {
     auto *CounterLoad = LoadStore.first;
     auto *CounterStore = LoadStore.second;
@@ -430,7 +460,7 @@
   // Do a post-order traversal of the loops so that counter updates can be
   // iteratively hoisted outside the loop nest.
   for (auto *Loop : llvm::reverse(Loops)) {
-    PGOCounterPromoter Promoter(LoopPromotionCandidates, *Loop, LI);
+    PGOCounterPromoter Promoter(LoopPromotionCandidates, *Loop, LI, BFI.get());
     Promoter.run(&TotalCountersPromoted);
   }
 }
@@ -681,7 +711,6 @@
   // Don't do this for Darwin.  compiler-rt uses linker magic.
   if (TT.isOSDarwin())
     return false;
-
   // Use linker script magic to get data/cnts/name start/end.
   if (TT.isOSLinux() || TT.isOSFreeBSD() || TT.isOSNetBSD() ||
       TT.isOSFuchsia() || TT.isPS4CPU() || TT.isOSWindows())
@@ -985,8 +1014,12 @@
 }
 
 void InstrProfiling::emitInitialization() {
-  // Create variable for profile name.
-  createProfileFileNameVar(*M, Options.InstrProfileOutput);
+  // Create ProfileFileName variable. Don't do this for the
+  // context-sensitive instrumentation lowering: This lowering is after
+  // LTO/ThinLTO linking. Pass PGOInstrumentationGenCreateVar should
+  // have already created the variable before LTO/ThinLTO linking.
+  if (!IsCS)
+    createProfileFileNameVar(*M, Options.InstrProfileOutput);
   Function *RegisterF = M->getFunction(getInstrProfRegFuncsName());
   if (!RegisterF)
     return;
diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
index bb2e335..d44c2ad 100644
--- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
+++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
@@ -65,6 +65,7 @@
 #include "llvm/Analysis/IndirectCallVisitor.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ProfileSummaryInfo.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
@@ -132,6 +133,19 @@
 STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
 STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
 STATISTIC(NumOfPGOICall, "Number of indirect call value instrumentations.");
+STATISTIC(NumOfCSPGOInstrument, "Number of edges instrumented in CSPGO.");
+STATISTIC(NumOfCSPGOSelectInsts,
+          "Number of select instruction instrumented in CSPGO.");
+STATISTIC(NumOfCSPGOMemIntrinsics,
+          "Number of mem intrinsics instrumented in CSPGO.");
+STATISTIC(NumOfCSPGOEdge, "Number of edges in CSPGO.");
+STATISTIC(NumOfCSPGOBB, "Number of basic-blocks in CSPGO.");
+STATISTIC(NumOfCSPGOSplit, "Number of critical edge splits in CSPGO.");
+STATISTIC(NumOfCSPGOFunc,
+          "Number of functions having valid profile counts in CSPGO.");
+STATISTIC(NumOfCSPGOMismatch,
+          "Number of functions having mismatch profile in CSPGO.");
+STATISTIC(NumOfCSPGOMissing, "Number of functions without profile in CSPGO.");
 
 // Command line option to specify the file to read profile from. This is
 // mainly used for testing.
@@ -383,7 +397,8 @@
 public:
   static char ID;
 
-  PGOInstrumentationGenLegacyPass() : ModulePass(ID) {
+  PGOInstrumentationGenLegacyPass(bool IsCS = false)
+      : ModulePass(ID), IsCS(IsCS) {
     initializePGOInstrumentationGenLegacyPassPass(
         *PassRegistry::getPassRegistry());
   }
@@ -391,6 +406,8 @@
   StringRef getPassName() const override { return "PGOInstrumentationGenPass"; }
 
 private:
+  // Whether this is context-sensitive instrumentation.
+  bool IsCS;
   bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
@@ -403,8 +420,8 @@
   static char ID;
 
   // Provide the profile filename as the parameter.
-  PGOInstrumentationUseLegacyPass(std::string Filename = "")
-      : ModulePass(ID), ProfileFileName(std::move(Filename)) {
+  PGOInstrumentationUseLegacyPass(std::string Filename = "", bool IsCS = false)
+      : ModulePass(ID), ProfileFileName(std::move(Filename)), IsCS(IsCS) {
     if (!PGOTestProfileFile.empty())
       ProfileFileName = PGOTestProfileFile;
     initializePGOInstrumentationUseLegacyPassPass(
@@ -415,14 +432,38 @@
 
 private:
   std::string ProfileFileName;
+  // Whether this is context-sensitive instrumentation use.
+  bool IsCS;
 
   bool runOnModule(Module &M) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<ProfileSummaryInfoWrapperPass>();
     AU.addRequired<BlockFrequencyInfoWrapperPass>();
   }
 };
 
+class PGOInstrumentationGenCreateVarLegacyPass : public ModulePass {
+public:
+  static char ID;
+  StringRef getPassName() const override {
+    return "PGOInstrumentationGenCreateVarPass";
+  }
+  PGOInstrumentationGenCreateVarLegacyPass(std::string CSInstrName = "")
+      : ModulePass(ID), InstrProfileOutput(CSInstrName) {
+    initializePGOInstrumentationGenCreateVarLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+private:
+  bool runOnModule(Module &M) override {
+    createProfileFileNameVar(M, InstrProfileOutput);
+    createIRLevelProfileFlagVar(M, true);
+    return false;
+  }
+  std::string InstrProfileOutput;
+};
+
 } // end anonymous namespace
 
 char PGOInstrumentationGenLegacyPass::ID = 0;
@@ -434,8 +475,8 @@
 INITIALIZE_PASS_END(PGOInstrumentationGenLegacyPass, "pgo-instr-gen",
                     "PGO instrumentation.", false, false)
 
-ModulePass *llvm::createPGOInstrumentationGenLegacyPass() {
-  return new PGOInstrumentationGenLegacyPass();
+ModulePass *llvm::createPGOInstrumentationGenLegacyPass(bool IsCS) {
+  return new PGOInstrumentationGenLegacyPass(IsCS);
 }
 
 char PGOInstrumentationUseLegacyPass::ID = 0;
@@ -444,11 +485,25 @@
                       "Read PGO instrumentation profile.", false, false)
 INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
 INITIALIZE_PASS_END(PGOInstrumentationUseLegacyPass, "pgo-instr-use",
                     "Read PGO instrumentation profile.", false, false)
 
-ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) {
-  return new PGOInstrumentationUseLegacyPass(Filename.str());
+ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename,
+                                                        bool IsCS) {
+  return new PGOInstrumentationUseLegacyPass(Filename.str(), IsCS);
+}
+
+char PGOInstrumentationGenCreateVarLegacyPass::ID = 0;
+
+INITIALIZE_PASS(PGOInstrumentationGenCreateVarLegacyPass,
+                "pgo-instr-gen-create-var",
+                "Create PGO instrumentation version variable for CSPGO.", false,
+                false)
+
+ModulePass *
+llvm::createPGOInstrumentationGenCreateVarLegacyPass(StringRef CSInstrName) {
+  return new PGOInstrumentationGenCreateVarLegacyPass(CSInstrName);
 }
 
 namespace {
@@ -496,6 +551,9 @@
 private:
   Function &F;
 
+  // Whether this is context-sensitive instrumentation.
+  bool IsCS;
+
   // A map that stores the Comdat group in function F.
   std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers;
 
@@ -535,15 +593,23 @@
       Function &Func,
       std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
       bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr,
-      BlockFrequencyInfo *BFI = nullptr)
-      : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1),
-        SIVisitor(Func), MIVisitor(Func), MST(F, BPI, BFI) {
+      BlockFrequencyInfo *BFI = nullptr, bool IsCS = false)
+      : F(Func), IsCS(IsCS), ComdatMembers(ComdatMembers),
+        ValueSites(IPVK_Last + 1), SIVisitor(Func), MIVisitor(Func),
+        MST(F, BPI, BFI) {
     // This should be done before CFG hash computation.
     SIVisitor.countSelects(Func);
     MIVisitor.countMemIntrinsics(Func);
-    NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
-    NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
-    ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func);
+    if (!IsCS) {
+      NumOfPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
+      NumOfPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
+      NumOfPGOBB += MST.BBInfos.size();
+      ValueSites[IPVK_IndirectCallTarget] = findIndirectCalls(Func);
+    } else {
+      NumOfCSPGOSelectInsts += SIVisitor.getNumOfSelectInsts();
+      NumOfCSPGOMemIntrinsics += MIVisitor.getNumOfMemIntrinsics();
+      NumOfCSPGOBB += MST.BBInfos.size();
+    }
     ValueSites[IPVK_MemOPSize] = MIVisitor.findMemIntrinsics(Func);
 
     FuncName = getPGOFuncName(F);
@@ -552,13 +618,12 @@
       renameComdatFunction();
     LLVM_DEBUG(dumpInfo("after CFGMST"));
 
-    NumOfPGOBB += MST.BBInfos.size();
     for (auto &E : MST.AllEdges) {
       if (E->Removed)
         continue;
-      NumOfPGOEdge++;
+      IsCS ? NumOfCSPGOEdge++ : NumOfPGOEdge++;
       if (!E->InMST)
-        NumOfPGOInstrument++;
+        IsCS ? NumOfCSPGOInstrument++ : NumOfPGOInstrument++;
     }
 
     if (CreateGlobalVar)
@@ -597,9 +662,17 @@
     }
   }
   JC.update(Indexes);
+
+  // Hash format for context sensitive profile. Reserve 4 bits for other
+  // information.
   FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 |
                  (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 |
+                 //(uint64_t)ValueSites[IPVK_MemOPSize].size() << 40 |
                  (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
+  // Reserve bit 60-63 for other information purpose.
+  FunctionHash &= 0x0FFFFFFFFFFFFFFF;
+  if (IsCS)
+    NamedInstrProfRecord::setCSFlagInHash(FunctionHash);
   LLVM_DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n"
                     << " CRC = " << JC.getCRC()
                     << ", Selects = " << SIVisitor.getNumOfSelectInsts()
@@ -705,7 +778,7 @@
 
   // For a critical edge, we have to split. Instrument the newly
   // created BB.
-  NumOfPGOSplit++;
+  IsCS ? NumOfCSPGOSplit++ : NumOfPGOSplit++;
   LLVM_DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index
                     << " --> " << getBBInfo(DestBB).Index << "\n");
   unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
@@ -720,12 +793,14 @@
 // Critical edges will be split.
 static void instrumentOneFunc(
     Function &F, Module *M, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI,
-    std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) {
+    std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
+    bool IsCS) {
   // Split indirectbr critical edges here before computing the MST rather than
   // later in getInstrBB() to avoid invalidating it.
   SplitIndirectBrCriticalEdges(F, BPI, BFI);
+
   FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, ComdatMembers, true, BPI,
-                                                   BFI);
+                                                   BFI, IsCS);
   unsigned NumCounters = FuncInfo.getNumCounters();
 
   uint32_t I = 0;
@@ -852,10 +927,10 @@
   PGOUseFunc(Function &Func, Module *Modu,
              std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers,
              BranchProbabilityInfo *BPI = nullptr,
-             BlockFrequencyInfo *BFIin = nullptr)
+             BlockFrequencyInfo *BFIin = nullptr, bool IsCS = false)
       : F(Func), M(Modu), BFI(BFIin),
-        FuncInfo(Func, ComdatMembers, false, BPI, BFIin),
-        FreqAttr(FFA_Normal) {}
+        FuncInfo(Func, ComdatMembers, false, BPI, BFIin, IsCS),
+        FreqAttr(FFA_Normal), IsCS(IsCS) {}
 
   // Read counts for the instrumented BB from profile.
   bool readCounters(IndexedInstrProfReader *PGOReader, bool &AllZeros);
@@ -928,6 +1003,9 @@
   // Function hotness info derived from profile.
   FuncFreqAttr FreqAttr;
 
+  // Whether to use the context-sensitive profile.
+  bool IsCS;
+
   // Find the Instrumented BB and set the value.
   void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile);
 
@@ -1021,23 +1099,31 @@
     handleAllErrors(std::move(E), [&](const InstrProfError &IPE) {
       auto Err = IPE.get();
       bool SkipWarning = false;
+      LLVM_DEBUG(dbgs() << "Error in reading profile for Func "
+                        << FuncInfo.FuncName << ": ");
       if (Err == instrprof_error::unknown_function) {
-        NumOfPGOMissing++;
+        IsCS ? NumOfCSPGOMissing++ : NumOfPGOMissing++;
         SkipWarning = !PGOWarnMissing;
+        LLVM_DEBUG(dbgs() << "unknown function");
       } else if (Err == instrprof_error::hash_mismatch ||
                  Err == instrprof_error::malformed) {
-        NumOfPGOMismatch++;
+        IsCS ? NumOfCSPGOMismatch++ : NumOfPGOMismatch++;
         SkipWarning =
             NoPGOWarnMismatch ||
             (NoPGOWarnMismatchComdat &&
              (F.hasComdat() ||
               F.getLinkage() == GlobalValue::AvailableExternallyLinkage));
+        LLVM_DEBUG(dbgs() << "hash mismatch (skip=" << SkipWarning << ")");
       }
 
+      LLVM_DEBUG(dbgs() << " IsCS=" << IsCS << "\n");
       if (SkipWarning)
         return;
 
-      std::string Msg = IPE.message() + std::string(" ") + F.getName().str();
+      std::string Msg = IPE.message() + std::string(" ") + F.getName().str() +
+                        std::string(" Hash = ") +
+                        std::to_string(FuncInfo.FunctionHash);
+
       Ctx.diagnose(
           DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
     });
@@ -1046,7 +1132,7 @@
   ProfileRecord = std::move(Result.get());
   std::vector<uint64_t> &CountFromProfile = ProfileRecord.Counts;
 
-  NumOfPGOFunc++;
+  IsCS ? NumOfCSPGOFunc++ : NumOfPGOFunc++;
   LLVM_DEBUG(dbgs() << CountFromProfile.size() << " counts\n");
   uint64_t ValueSum = 0;
   for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) {
@@ -1166,7 +1252,8 @@
 // Assign the scaled count values to the BB with multiple out edges.
 void PGOUseFunc::setBranchWeights() {
   // Generate MD_prof metadata for every branch instruction.
-  LLVM_DEBUG(dbgs() << "\nSetting branch weights.\n");
+  LLVM_DEBUG(dbgs() << "\nSetting branch weights for func " << F.getName()
+                    << " IsCS=" << IsCS << "\n");
   for (auto &BB : F) {
     Instruction *TI = BB.getTerminator();
     if (TI->getNumSuccessors() < 2)
@@ -1174,6 +1261,7 @@
     if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) ||
           isa<IndirectBrInst>(TI)))
       continue;
+
     if (getBBInfo(&BB).CountValue == 0)
       continue;
 
@@ -1351,24 +1439,6 @@
   }
 }
 
-// Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime
-// aware this is an ir_level profile so it can set the version flag.
-static void createIRLevelProfileFlagVariable(Module &M) {
-  Type *IntTy64 = Type::getInt64Ty(M.getContext());
-  uint64_t ProfileVersion = (INSTR_PROF_RAW_VERSION | VARIANT_MASK_IR_PROF);
-  auto IRLevelVersionVariable = new GlobalVariable(
-      M, IntTy64, true, GlobalVariable::ExternalLinkage,
-      Constant::getIntegerValue(IntTy64, APInt(64, ProfileVersion)),
-      INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
-  IRLevelVersionVariable->setVisibility(GlobalValue::DefaultVisibility);
-  Triple TT(M.getTargetTriple());
-  if (!TT.supportsCOMDAT())
-    IRLevelVersionVariable->setLinkage(GlobalValue::WeakAnyLinkage);
-  else
-    IRLevelVersionVariable->setComdat(M.getOrInsertComdat(
-        StringRef(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR))));
-}
-
 // Collect the set of members for each Comdat in module M and store
 // in ComdatMembers.
 static void collectComdatMembers(
@@ -1389,8 +1459,11 @@
 
 static bool InstrumentAllFunctions(
     Module &M, function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
-    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
-  createIRLevelProfileFlagVariable(M);
+    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
+  // For the context-sensitive instrumentation, we should have a separate pass
+  // (before LTO/ThinLTO linking) to create these variables.
+  if (!IsCS)
+    createIRLevelProfileFlagVar(M, /* IsCS */ false);
   std::unordered_multimap<Comdat *, GlobalValue *> ComdatMembers;
   collectComdatMembers(M, ComdatMembers);
 
@@ -1399,7 +1472,7 @@
       continue;
     auto *BPI = LookupBPI(F);
     auto *BFI = LookupBFI(F);
-    instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers);
+    instrumentOneFunc(F, &M, BPI, BFI, ComdatMembers, IsCS);
   }
   return true;
 }
@@ -1414,7 +1487,7 @@
   auto LookupBFI = [this](Function &F) {
     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
   };
-  return InstrumentAllFunctions(M, LookupBPI, LookupBFI);
+  return InstrumentAllFunctions(M, LookupBPI, LookupBFI, IsCS);
 }
 
 PreservedAnalyses PGOInstrumentationGen::run(Module &M,
@@ -1428,7 +1501,7 @@
     return &FAM.getResult<BlockFrequencyAnalysis>(F);
   };
 
-  if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI))
+  if (!InstrumentAllFunctions(M, LookupBPI, LookupBFI, IsCS))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
@@ -1437,7 +1510,7 @@
 static bool annotateAllFunctions(
     Module &M, StringRef ProfileFileName, StringRef ProfileRemappingFileName,
     function_ref<BranchProbabilityInfo *(Function &)> LookupBPI,
-    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI) {
+    function_ref<BlockFrequencyInfo *(Function &)> LookupBFI, bool IsCS) {
   LLVM_DEBUG(dbgs() << "Read in profile counters: ");
   auto &Ctx = M.getContext();
   // Read the counter array from file.
@@ -1458,6 +1531,7 @@
                                           StringRef("Cannot get PGOReader")));
     return false;
   }
+
   // TODO: might need to change the warning once the clang option is finalized.
   if (!PGOReader->isIRLevelProfile()) {
     Ctx.diagnose(DiagnosticInfoPGOProfile(
@@ -1477,7 +1551,7 @@
     // Split indirectbr critical edges here before computing the MST rather than
     // later in getInstrBB() to avoid invalidating it.
     SplitIndirectBrCriticalEdges(F, BPI, BFI);
-    PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI);
+    PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI, IsCS);
     bool AllZeros = false;
     if (!Func.readCounters(PGOReader.get(), AllZeros))
       continue;
@@ -1526,6 +1600,7 @@
     }
   }
   M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext()));
+
   // Set function hotness attribute from the profile.
   // We have to apply these attributes at the end because their presence
   // can affect the BranchProbabilityInfo of any callers, resulting in an
@@ -1544,9 +1619,10 @@
 }
 
 PGOInstrumentationUse::PGOInstrumentationUse(std::string Filename,
-                                             std::string RemappingFilename)
+                                             std::string RemappingFilename,
+                                             bool IsCS)
     : ProfileFileName(std::move(Filename)),
-      ProfileRemappingFileName(std::move(RemappingFilename)) {
+      ProfileRemappingFileName(std::move(RemappingFilename)), IsCS(IsCS) {
   if (!PGOTestProfileFile.empty())
     ProfileFileName = PGOTestProfileFile;
   if (!PGOTestProfileRemappingFile.empty())
@@ -1566,7 +1642,7 @@
   };
 
   if (!annotateAllFunctions(M, ProfileFileName, ProfileRemappingFileName,
-                            LookupBPI, LookupBFI))
+                            LookupBPI, LookupBFI, IsCS))
     return PreservedAnalyses::all();
 
   return PreservedAnalyses::none();
@@ -1583,7 +1659,8 @@
     return &this->getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI();
   };
 
-  return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI);
+  return annotateAllFunctions(M, ProfileFileName, "", LookupBPI, LookupBFI,
+                              IsCS);
 }
 
 static std::string getSimpleNodeName(const BasicBlock *Node) {