[CallSite removal] move InlineCost to CallBase usage

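Convert the InlineCost interface and its internals to use CallBase instead of CallSite.
The inliners themselves are not converted yet; their calls into getInlineCost
now cast the CallSite's instruction to CallBase at the call boundary.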

Reviewed By: reames
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D60636

llvm-svn: 358982
diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index d8b87f9..a4161dd 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -27,7 +27,6 @@
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Config/llvm-config.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/CallingConv.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
@@ -121,7 +120,7 @@
   /// The candidate callsite being analyzed. Please do not use this to do
   /// analysis in the caller function; we want the inline cost query to be
   /// easily cacheable. Instead, use the cover function paramHasAttr.
-  CallSite CandidateCS;
+  CallBase &CandidateCall;
 
   /// Tunable parameters that control the analysis.
   const InlineParams &Params;
@@ -195,7 +194,7 @@
   bool isGEPFree(GetElementPtrInst &GEP);
   bool canFoldInboundsGEP(GetElementPtrInst &I);
   bool accumulateGEPOffset(GEPOperator &GEP, APInt &Offset);
-  bool simplifyCallSite(Function *F, CallSite CS);
+  bool simplifyCallSite(Function *F, CallBase &Call);
   template <typename Callable>
   bool simplifyInstruction(Instruction &I, Callable Evaluate);
   ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V);
@@ -215,16 +214,16 @@
   /// attributes and callee hotness for PGO builds. The Callee is explicitly
   /// passed to support analyzing indirect calls whose target is inferred by
   /// analysis.
-  void updateThreshold(CallSite CS, Function &Callee);
+  void updateThreshold(CallBase &Call, Function &Callee);
 
-  /// Return true if size growth is allowed when inlining the callee at CS.
-  bool allowSizeGrowth(CallSite CS);
+  /// Return true if size growth is allowed when inlining the callee at \p Call.
+  bool allowSizeGrowth(CallBase &Call);
 
-  /// Return true if \p CS is a cold callsite.
-  bool isColdCallSite(CallSite CS, BlockFrequencyInfo *CallerBFI);
+  /// Return true if \p Call is a cold callsite.
+  bool isColdCallSite(CallBase &Call, BlockFrequencyInfo *CallerBFI);
 
-  /// Return a higher threshold if \p CS is a hot callsite.
-  Optional<int> getHotCallSiteThreshold(CallSite CS,
+  /// Return a higher threshold if \p Call is a hot callsite.
+  Optional<int> getHotCallSiteThreshold(CallBase &Call,
                                         BlockFrequencyInfo *CallerBFI);
 
   // Custom analysis routines.
@@ -259,7 +258,7 @@
   bool visitStore(StoreInst &I);
   bool visitExtractValue(ExtractValueInst &I);
   bool visitInsertValue(InsertValueInst &I);
-  bool visitCallSite(CallSite CS);
+  bool visitCallBase(CallBase &Call);
   bool visitReturnInst(ReturnInst &RI);
   bool visitBranchInst(BranchInst &BI);
   bool visitSelectInst(SelectInst &SI);
@@ -275,10 +274,10 @@
                std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
                Optional<function_ref<BlockFrequencyInfo &(Function &)>> &GetBFI,
                ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE,
-               Function &Callee, CallSite CSArg, const InlineParams &Params)
+               Function &Callee, CallBase &Call, const InlineParams &Params)
       : TTI(TTI), GetAssumptionCache(GetAssumptionCache), GetBFI(GetBFI),
         PSI(PSI), F(Callee), DL(F.getParent()->getDataLayout()), ORE(ORE),
-        CandidateCS(CSArg), Params(Params), Threshold(Params.DefaultThreshold),
+        CandidateCall(Call), Params(Params), Threshold(Params.DefaultThreshold),
         Cost(0), ComputeFullInlineCost(OptComputeFullInlineCost ||
                                        Params.ComputeFullInlineCost || ORE),
         IsCallerRecursive(false), IsRecursiveCall(false),
@@ -292,7 +291,7 @@
         NumInstructionsSimplified(0), SROACostSavings(0),
         SROACostSavingsLost(0) {}
 
-  InlineResult analyzeCall(CallSite CS);
+  InlineResult analyzeCall(CallBase &Call);
 
   int getThreshold() { return Threshold; }
   int getCost() { return Cost; }
@@ -743,7 +742,7 @@
 }
 
 bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) {
-  return CandidateCS.paramHasAttr(A->getArgNo(), Attr);
+  return CandidateCall.paramHasAttr(A->getArgNo(), Attr);
 }
 
 bool CallAnalyzer::isKnownNonNullInCallee(Value *V) {
@@ -768,7 +767,7 @@
   return false;
 }
 
-bool CallAnalyzer::allowSizeGrowth(CallSite CS) {
+bool CallAnalyzer::allowSizeGrowth(CallBase &Call) {
   // If the normal destination of the invoke or the parent block of the call
   // site is unreachable-terminated, there is little point in inlining this
   // unless there is literally zero cost.
@@ -784,21 +783,21 @@
   // For now, we are not handling this corner case here as it is rare in real
   // code. In future, we should elaborate this based on BPI and BFI in more
   // general threshold adjusting heuristics in updateThreshold().
-  Instruction *Instr = CS.getInstruction();
-  if (InvokeInst *II = dyn_cast<InvokeInst>(Instr)) {
+  if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) {
     if (isa<UnreachableInst>(II->getNormalDest()->getTerminator()))
       return false;
-  } else if (isa<UnreachableInst>(Instr->getParent()->getTerminator()))
+  } else if (isa<UnreachableInst>(Call.getParent()->getTerminator()))
     return false;
 
   return true;
 }
 
-bool CallAnalyzer::isColdCallSite(CallSite CS, BlockFrequencyInfo *CallerBFI) {
+bool CallAnalyzer::isColdCallSite(CallBase &Call,
+                                  BlockFrequencyInfo *CallerBFI) {
   // If global profile summary is available, then callsite's coldness is
   // determined based on that.
   if (PSI && PSI->hasProfileSummary())
-    return PSI->isColdCallSite(CS, CallerBFI);
+    return PSI->isColdCallSite(CallSite(&Call), CallerBFI);
 
   // Otherwise we need BFI to be available.
   if (!CallerBFI)
@@ -809,20 +808,21 @@
   // complexity is not worth it unless this scaling shows up high in the
   // profiles.
   const BranchProbability ColdProb(ColdCallSiteRelFreq, 100);
-  auto CallSiteBB = CS.getInstruction()->getParent();
+  auto CallSiteBB = Call.getParent();
   auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB);
   auto CallerEntryFreq =
-      CallerBFI->getBlockFreq(&(CS.getCaller()->getEntryBlock()));
+      CallerBFI->getBlockFreq(&(Call.getCaller()->getEntryBlock()));
   return CallSiteFreq < CallerEntryFreq * ColdProb;
 }
 
 Optional<int>
-CallAnalyzer::getHotCallSiteThreshold(CallSite CS,
+CallAnalyzer::getHotCallSiteThreshold(CallBase &Call,
                                       BlockFrequencyInfo *CallerBFI) {
 
   // If global profile summary is available, then callsite's hotness is
   // determined based on that.
-  if (PSI && PSI->hasProfileSummary() && PSI->isHotCallSite(CS, CallerBFI))
+  if (PSI && PSI->hasProfileSummary() &&
+      PSI->isHotCallSite(CallSite(&Call), CallerBFI))
     return Params.HotCallSiteThreshold;
 
   // Otherwise we need BFI to be available and to have a locally hot callsite
@@ -834,7 +834,7 @@
   // potentially cache the computation of scaled entry frequency, but the added
   // complexity is not worth it unless this scaling shows up high in the
   // profiles.
-  auto CallSiteBB = CS.getInstruction()->getParent();
+  auto CallSiteBB = Call.getParent();
   auto CallSiteFreq = CallerBFI->getBlockFreq(CallSiteBB).getFrequency();
   auto CallerEntryFreq = CallerBFI->getEntryFreq();
   if (CallSiteFreq >= CallerEntryFreq * HotCallSiteRelFreq)
@@ -844,14 +844,14 @@
   return None;
 }
 
-void CallAnalyzer::updateThreshold(CallSite CS, Function &Callee) {
+void CallAnalyzer::updateThreshold(CallBase &Call, Function &Callee) {
   // If no size growth is allowed for this inlining, set Threshold to 0.
-  if (!allowSizeGrowth(CS)) {
+  if (!allowSizeGrowth(Call)) {
     Threshold = 0;
     return;
   }
 
-  Function *Caller = CS.getCaller();
+  Function *Caller = Call.getCaller();
 
   // return min(A, B) if B is valid.
   auto MinIfValid = [](int A, Optional<int> B) {
@@ -922,7 +922,7 @@
     // used (which adds hotness metadata to calls) or if caller's
     // BlockFrequencyInfo is available.
     BlockFrequencyInfo *CallerBFI = GetBFI ? &((*GetBFI)(*Caller)) : nullptr;
-    auto HotCallSiteThreshold = getHotCallSiteThreshold(CS, CallerBFI);
+    auto HotCallSiteThreshold = getHotCallSiteThreshold(Call, CallerBFI);
     if (!Caller->hasOptSize() && HotCallSiteThreshold) {
       LLVM_DEBUG(dbgs() << "Hot callsite.\n");
       // FIXME: This should update the threshold only if it exceeds the
@@ -930,7 +930,7 @@
       // behavior to prevent inlining of hot callsites during ThinLTO
       // compile phase.
       Threshold = HotCallSiteThreshold.getValue();
-    } else if (isColdCallSite(CS, CallerBFI)) {
+    } else if (isColdCallSite(Call, CallerBFI)) {
       LLVM_DEBUG(dbgs() << "Cold callsite.\n");
       // Do not apply bonuses for a cold callsite including the
       // LastCallToStatic bonus. While this bonus might result in code size
@@ -967,7 +967,7 @@
   VectorBonus = Threshold * VectorBonusPercent / 100;
 
   bool OnlyOneCallAndLocalLinkage =
-      F.hasLocalLinkage() && F.hasOneUse() && &F == CS.getCalledFunction();
+      F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction();
   // If there is only one call of the function, and it has internal linkage,
   // the cost of inlining it drops dramatically. It may seem odd to update
   // Cost in updateThreshold, but the bonus depends on the logic in this method.
@@ -1172,59 +1172,57 @@
 /// analyzing the arguments and call itself with instsimplify. Returns true if
 /// it has simplified the callsite to some other entity (a constant), making it
 /// free.
-bool CallAnalyzer::simplifyCallSite(Function *F, CallSite CS) {
+bool CallAnalyzer::simplifyCallSite(Function *F, CallBase &Call) {
   // FIXME: Using the instsimplify logic directly for this is inefficient
   // because we have to continually rebuild the argument list even when no
   // simplifications can be performed. Until that is fixed with remapping
   // inside of instsimplify, directly constant fold calls here.
-  if (!canConstantFoldCallTo(cast<CallBase>(CS.getInstruction()), F))
+  if (!canConstantFoldCallTo(&Call, F))
     return false;
 
   // Try to re-map the arguments to constants.
   SmallVector<Constant *, 4> ConstantArgs;
-  ConstantArgs.reserve(CS.arg_size());
-  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
-       ++I) {
-    Constant *C = dyn_cast<Constant>(*I);
+  ConstantArgs.reserve(Call.arg_size());
+  for (Value *I : Call.args()) {
+    Constant *C = dyn_cast<Constant>(I);
     if (!C)
-      C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(*I));
+      C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(I));
     if (!C)
       return false; // This argument doesn't map to a constant.
 
     ConstantArgs.push_back(C);
   }
-  if (Constant *C = ConstantFoldCall(cast<CallBase>(CS.getInstruction()), F,
-                                     ConstantArgs)) {
-    SimplifiedValues[CS.getInstruction()] = C;
+  if (Constant *C = ConstantFoldCall(&Call, F, ConstantArgs)) {
+    SimplifiedValues[&Call] = C;
     return true;
   }
 
   return false;
 }
 
-bool CallAnalyzer::visitCallSite(CallSite CS) {
-  if (CS.hasFnAttr(Attribute::ReturnsTwice) &&
+bool CallAnalyzer::visitCallBase(CallBase &Call) {
+  if (Call.hasFnAttr(Attribute::ReturnsTwice) &&
       !F.hasFnAttribute(Attribute::ReturnsTwice)) {
     // This aborts the entire analysis.
     ExposesReturnsTwice = true;
     return false;
   }
-  if (CS.isCall() && cast<CallInst>(CS.getInstruction())->cannotDuplicate())
+  if (isa<CallInst>(Call) && cast<CallInst>(Call).cannotDuplicate())
     ContainsNoDuplicateCall = true;
 
-  if (Function *F = CS.getCalledFunction()) {
+  if (Function *F = Call.getCalledFunction()) {
     // When we have a concrete function, first try to simplify it directly.
-    if (simplifyCallSite(F, CS))
+    if (simplifyCallSite(F, Call))
       return true;
 
     // Next check if it is an intrinsic we know about.
     // FIXME: Lift this into part of the InstVisitor.
-    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) {
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Call)) {
       switch (II->getIntrinsicID()) {
       default:
-        if (!CS.onlyReadsMemory() && !isAssumeLikeIntrinsic(II))
+        if (!Call.onlyReadsMemory() && !isAssumeLikeIntrinsic(II))
           disableLoadElimination();
-        return Base::visitCallSite(CS);
+        return Base::visitCallBase(Call);
 
       case Intrinsic::load_relative:
         // This is normally lowered to 4 LLVM instructions.
@@ -1247,7 +1245,7 @@
       }
     }
 
-    if (F == CS.getInstruction()->getFunction()) {
+    if (F == Call.getFunction()) {
       // This flag will fully abort the analysis, so don't bother with anything
       // else.
       IsRecursiveCall = true;
@@ -1257,34 +1255,34 @@
     if (TTI.isLoweredToCall(F)) {
       // We account for the average 1 instruction per call argument setup
       // here.
-      Cost += CS.arg_size() * InlineConstants::InstrCost;
+      Cost += Call.arg_size() * InlineConstants::InstrCost;
 
       // Everything other than inline ASM will also have a significant cost
       // merely from making the call.
-      if (!isa<InlineAsm>(CS.getCalledValue()))
+      if (!isa<InlineAsm>(Call.getCalledValue()))
         Cost += InlineConstants::CallPenalty;
     }
 
-    if (!CS.onlyReadsMemory())
+    if (!Call.onlyReadsMemory())
       disableLoadElimination();
-    return Base::visitCallSite(CS);
+    return Base::visitCallBase(Call);
   }
 
   // Otherwise we're in a very special case -- an indirect function call. See
   // if we can be particularly clever about this.
-  Value *Callee = CS.getCalledValue();
+  Value *Callee = Call.getCalledValue();
 
   // First, pay the price of the argument setup. We account for the average
   // 1 instruction per call argument setup here.
-  Cost += CS.arg_size() * InlineConstants::InstrCost;
+  Cost += Call.arg_size() * InlineConstants::InstrCost;
 
   // Next, check if this happens to be an indirect function call to a known
   // function in this inline context. If not, we've done all we can.
   Function *F = dyn_cast_or_null<Function>(SimplifiedValues.lookup(Callee));
   if (!F) {
-    if (!CS.onlyReadsMemory())
+    if (!Call.onlyReadsMemory())
       disableLoadElimination();
-    return Base::visitCallSite(CS);
+    return Base::visitCallBase(Call);
   }
 
   // If we have a constant that we are calling as a function, we can peer
@@ -1294,9 +1292,9 @@
   // out. Pretend to inline the function, with a custom threshold.
   auto IndirectCallParams = Params;
   IndirectCallParams.DefaultThreshold = InlineConstants::IndirectCallThreshold;
-  CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, ORE, *F, CS,
+  CallAnalyzer CA(TTI, GetAssumptionCache, GetBFI, PSI, ORE, *F, Call,
                   IndirectCallParams);
-  if (CA.analyzeCall(CS)) {
+  if (CA.analyzeCall(Call)) {
     // We were able to inline the indirect call! Subtract the cost from the
     // threshold to get the bonus we want to apply, but don't go below zero.
     Cost -= std::max(0, CA.getThreshold() - CA.getCost());
@@ -1304,7 +1302,7 @@
 
   if (!F->onlyReadsMemory())
     disableLoadElimination();
-  return Base::visitCallSite(CS);
+  return Base::visitCallBase(Call);
 }
 
 bool CallAnalyzer::visitReturnInst(ReturnInst &RI) {
@@ -1595,7 +1593,7 @@
       if (ORE)
         ORE->emit([&]() {
           return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
-                                          CandidateCS.getInstruction())
+                                          &CandidateCall)
                  << NV("Callee", &F) << " has uninlinable pattern ("
                  << NV("InlineResult", IR.message)
                  << ") and cost is not fully computed";
@@ -1612,7 +1610,7 @@
       if (ORE)
         ORE->emit([&]() {
           return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline",
-                                          CandidateCS.getInstruction())
+                                          &CandidateCall)
                  << NV("Callee", &F) << " is " << NV("InlineResult", IR.message)
                  << ". Cost is not fully computed";
         });
@@ -1712,7 +1710,7 @@
 /// factors and heuristics. If this method returns false but the computed cost
 /// is below the computed threshold, then inlining was forcibly disabled by
 /// some artifact of the routine.
-InlineResult CallAnalyzer::analyzeCall(CallSite CS) {
+InlineResult CallAnalyzer::analyzeCall(CallBase &Call) {
   ++NumCallsAnalyzed;
 
   // Perform some tweaks to the cost and threshold based on the direct
@@ -1729,7 +1727,7 @@
   assert(NumVectorInstructions == 0);
 
   // Update the threshold based on callsite properties
-  updateThreshold(CS, F);
+  updateThreshold(Call, F);
 
   // While Threshold depends on commandline options that can take negative
   // values, we want to enforce the invariant that the computed threshold and
@@ -1745,7 +1743,7 @@
 
   // Give out bonuses for the callsite, as the instructions setting them up
   // will be gone after inlining.
-  Cost -= getCallsiteCost(CS, DL);
+  Cost -= getCallsiteCost(Call, DL);
 
   // If this function uses the coldcc calling convention, prefer not to inline
   // it.
@@ -1759,14 +1757,11 @@
   if (F.empty())
     return true;
 
-  Function *Caller = CS.getInstruction()->getFunction();
+  Function *Caller = Call.getFunction();
   // Check if the caller function is recursive itself.
   for (User *U : Caller->users()) {
-    CallSite Site(U);
-    if (!Site)
-      continue;
-    Instruction *I = Site.getInstruction();
-    if (I->getFunction() == Caller) {
+    CallBase *Call = dyn_cast<CallBase>(U);
+    if (Call && Call->getFunction() == Caller) {
       IsCallerRecursive = true;
       break;
     }
@@ -1774,10 +1769,10 @@
 
   // Populate our simplified values by mapping from function arguments to call
   // arguments with known important simplifications.
-  CallSite::arg_iterator CAI = CS.arg_begin();
+  auto CAI = Call.arg_begin();
   for (Function::arg_iterator FAI = F.arg_begin(), FAE = F.arg_end();
        FAI != FAE; ++FAI, ++CAI) {
-    assert(CAI != CS.arg_end());
+    assert(CAI != Call.arg_end());
     if (Constant *C = dyn_cast<Constant>(CAI))
       SimplifiedValues[&*FAI] = C;
 
@@ -1887,7 +1882,7 @@
   }
 
   bool OnlyOneCallAndLocalLinkage =
-      F.hasLocalLinkage() && F.hasOneUse() && &F == CS.getCalledFunction();
+      F.hasLocalLinkage() && F.hasOneUse() && &F == Call.getCalledFunction();
   // If this is a noduplicate call, we can still inline as long as
   // inlining this would cause the removal of the caller (so the instruction
   // is not actually duplicated, just moved).
@@ -1953,13 +1948,13 @@
          AttributeFuncs::areInlineCompatible(*Caller, *Callee);
 }
 
-int llvm::getCallsiteCost(CallSite CS, const DataLayout &DL) {
+int llvm::getCallsiteCost(CallBase &Call, const DataLayout &DL) {
   int Cost = 0;
-  for (unsigned I = 0, E = CS.arg_size(); I != E; ++I) {
-    if (CS.isByValArgument(I)) {
+  for (unsigned I = 0, E = Call.arg_size(); I != E; ++I) {
+    if (Call.isByValArgument(I)) {
       // We approximate the number of loads and stores needed by dividing the
       // size of the byval type by the target's pointer size.
-      PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType());
+      PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType());
       unsigned TypeSize = DL.getTypeSizeInBits(PTy->getElementType());
       unsigned AS = PTy->getAddressSpace();
       unsigned PointerSize = DL.getPointerSizeInBits(AS);
@@ -1987,16 +1982,16 @@
 }
 
 InlineCost llvm::getInlineCost(
-    CallSite CS, const InlineParams &Params, TargetTransformInfo &CalleeTTI,
+    CallBase &Call, const InlineParams &Params, TargetTransformInfo &CalleeTTI,
     std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
     Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI,
     ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) {
-  return getInlineCost(CS, CS.getCalledFunction(), Params, CalleeTTI,
+  return getInlineCost(Call, Call.getCalledFunction(), Params, CalleeTTI,
                        GetAssumptionCache, GetBFI, PSI, ORE);
 }
 
 InlineCost llvm::getInlineCost(
-    CallSite CS, Function *Callee, const InlineParams &Params,
+    CallBase &Call, Function *Callee, const InlineParams &Params,
     TargetTransformInfo &CalleeTTI,
     std::function<AssumptionCache &(Function &)> &GetAssumptionCache,
     Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI,
@@ -2012,9 +2007,9 @@
   // argument is in the alloca address space (so it is a little bit complicated
   // to solve).
   unsigned AllocaAS = Callee->getParent()->getDataLayout().getAllocaAddrSpace();
-  for (unsigned I = 0, E = CS.arg_size(); I != E; ++I)
-    if (CS.isByValArgument(I)) {
-      PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType());
+  for (unsigned I = 0, E = Call.arg_size(); I != E; ++I)
+    if (Call.isByValArgument(I)) {
+      PointerType *PTy = cast<PointerType>(Call.getArgOperand(I)->getType());
       if (PTy->getAddressSpace() != AllocaAS)
         return llvm::InlineCost::getNever("byval arguments without alloca"
                                           " address space");
@@ -2022,7 +2017,7 @@
 
   // Calls to functions with always-inline attributes should be inlined
   // whenever possible.
-  if (CS.hasFnAttr(Attribute::AlwaysInline)) {
+  if (Call.hasFnAttr(Attribute::AlwaysInline)) {
     auto IsViable = isInlineViable(*Callee);
     if (IsViable)
       return llvm::InlineCost::getAlways("always inline attribute");
@@ -2031,7 +2026,7 @@
 
   // Never inline functions with conflicting attributes (unless callee has
   // always-inline attribute).
-  Function *Caller = CS.getCaller();
+  Function *Caller = Call.getCaller();
   if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI))
     return llvm::InlineCost::getNever("conflicting attributes");
 
@@ -2053,15 +2048,15 @@
     return llvm::InlineCost::getNever("noinline function attribute");
 
   // Don't inline call sites marked noinline.
-  if (CS.isNoInline())
+  if (Call.isNoInline())
     return llvm::InlineCost::getNever("noinline call site attribute");
 
   LLVM_DEBUG(llvm::dbgs() << "      Analyzing call of " << Callee->getName()
                           << "... (caller:" << Caller->getName() << ")\n");
 
-  CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee, CS,
-                  Params);
-  InlineResult ShouldInline = CA.analyzeCall(CS);
+  CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee,
+                  Call, Params);
+  InlineResult ShouldInline = CA.analyzeCall(Call);
 
   LLVM_DEBUG(CA.dump());
 
@@ -2086,22 +2081,22 @@
       return "uses block address";
 
     for (auto &II : *BI) {
-      CallSite CS(&II);
-      if (!CS)
+      CallBase *Call = dyn_cast<CallBase>(&II);
+      if (!Call)
         continue;
 
       // Disallow recursive calls.
-      if (&F == CS.getCalledFunction())
+      if (&F == Call->getCalledFunction())
         return "recursive call";
 
       // Disallow calls which expose returns-twice to a function not previously
       // attributed as such.
-      if (!ReturnsTwice && CS.isCall() &&
-          cast<CallInst>(CS.getInstruction())->canReturnTwice())
+      if (!ReturnsTwice && isa<CallInst>(Call) &&
+          cast<CallInst>(Call)->canReturnTwice())
         return "exposes returns-twice attribute";
 
-      if (CS.getCalledFunction())
-        switch (CS.getCalledFunction()->getIntrinsicID()) {
+      if (Call->getCalledFunction())
+        switch (Call->getCalledFunction()->getIntrinsicID()) {
         default:
           break;
         // Disallow inlining of @llvm.icall.branch.funnel because current
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInline.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInline.cpp
index 0ad78e0..8169421 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInline.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInline.cpp
@@ -206,6 +206,7 @@
     return ACT->getAssumptionCache(F);
   };
 
-  return llvm::getInlineCost(CS, Callee, LocalParams, TTI, GetAssumptionCache,
-                             None, PSI, RemarksEnabled ? &ORE : nullptr);
+  return llvm::getInlineCost(cast<CallBase>(*CS.getInstruction()), Callee,
+                             LocalParams, TTI, GetAssumptionCache, None, PSI,
+                             RemarksEnabled ? &ORE : nullptr);
 }
diff --git a/llvm/lib/Transforms/IPO/InlineSimple.cpp b/llvm/lib/Transforms/IPO/InlineSimple.cpp
index c9ce14f..efb71b7 100644
--- a/llvm/lib/Transforms/IPO/InlineSimple.cpp
+++ b/llvm/lib/Transforms/IPO/InlineSimple.cpp
@@ -68,9 +68,9 @@
         [&](Function &F) -> AssumptionCache & {
       return ACT->getAssumptionCache(F);
     };
-    return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache,
-                               /*GetBFI=*/None, PSI,
-                               RemarksEnabled ? &ORE : nullptr);
+    return llvm::getInlineCost(
+        cast<CallBase>(*CS.getInstruction()), Params, TTI, GetAssumptionCache,
+        /*GetBFI=*/None, PSI, RemarksEnabled ? &ORE : nullptr);
   }
 
   bool runOnSCC(CallGraphSCC &SCC) override;
diff --git a/llvm/lib/Transforms/IPO/Inliner.cpp b/llvm/lib/Transforms/IPO/Inliner.cpp
index f6e8b53..945f8af 100644
--- a/llvm/lib/Transforms/IPO/Inliner.cpp
+++ b/llvm/lib/Transforms/IPO/Inliner.cpp
@@ -1008,8 +1008,9 @@
       bool RemarksEnabled =
           Callee.getContext().getDiagHandlerPtr()->isMissedOptRemarkEnabled(
               DEBUG_TYPE);
-      return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI},
-                           PSI, RemarksEnabled ? &ORE : nullptr);
+      return getInlineCost(cast<CallBase>(*CS.getInstruction()), Params,
+                           CalleeTTI, GetAssumptionCache, {GetBFI}, PSI,
+                           RemarksEnabled ? &ORE : nullptr);
     };
 
     // Now process as many calls as we have within this caller in the sequnece.
diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp
index a172aad..733782e 100644
--- a/llvm/lib/Transforms/IPO/PartialInlining.cpp
+++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp
@@ -775,9 +775,10 @@
   bool RemarksEnabled =
       Callee->getContext().getDiagHandlerPtr()->isMissedOptRemarkEnabled(
           DEBUG_TYPE);
-  InlineCost IC =
-      getInlineCost(CS, getInlineParams(), CalleeTTI, *GetAssumptionCache,
-                    GetBFI, PSI, RemarksEnabled ? &ORE : nullptr);
+  assert(Call && "invalid callsite for partial inline");
+  InlineCost IC = getInlineCost(cast<CallBase>(*Call), getInlineParams(),
+                                CalleeTTI, *GetAssumptionCache, GetBFI, PSI,
+                                RemarksEnabled ? &ORE : nullptr);
 
   if (IC.isAlways()) {
     ORE.emit([&]() {
@@ -811,7 +812,7 @@
   const DataLayout &DL = Caller->getParent()->getDataLayout();
 
   // The savings of eliminating the call:
-  int NonWeightedSavings = getCallsiteCost(CS, DL);
+  int NonWeightedSavings = getCallsiteCost(cast<CallBase>(*Call), DL);
   BlockFrequency NormWeightedSavings(NonWeightedSavings);
 
   // Weighted saving is smaller than weighted cost, return false
@@ -868,12 +869,12 @@
       continue;
 
     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
-      InlineCost += getCallsiteCost(CallSite(CI), DL);
+      InlineCost += getCallsiteCost(*CI, DL);
       continue;
     }
 
     if (InvokeInst *II = dyn_cast<InvokeInst>(&I)) {
-      InlineCost += getCallsiteCost(CallSite(II), DL);
+      InlineCost += getCallsiteCost(*II, DL);
       continue;
     }
 
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 2955c33..877d20e 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -752,8 +752,9 @@
   // when cost exceeds threshold without checking all IRs in the callee.
   // The acutal cost does not matter because we only checks isNever() to
   // see if it is legal to inline the callsite.
-  InlineCost Cost = getInlineCost(CS, Params, GetTTI(*CalledFunction), GetAC,
-                                  None, nullptr, nullptr);
+  InlineCost Cost =
+      getInlineCost(cast<CallBase>(*I), Params, GetTTI(*CalledFunction), GetAC,
+                    None, nullptr, nullptr);
   if (Cost.isNever()) {
     ORE->emit(OptimizationRemark(DEBUG_TYPE, "Not inline", DLoc, BB)
               << "incompatible inlining");