Use branch funnels for virtual calls when retpoline mitigation is enabled.

The retpoline mitigation for variant 2 of CVE-2017-5715 inhibits the
branch predictor, and as a result it can lead to a measurable loss of
performance. We can reduce the performance impact of retpolined virtual
calls by replacing them with a special construct known as a branch
funnel, which is an instruction sequence that implements virtual calls
to a set of known targets using a binary tree of direct branches. This
allows the processor to speculatively execute valid implementations of the
virtual function without allowing for speculative execution of calls
to arbitrary addresses.

This patch extends the whole-program devirtualization pass to replace
certain virtual calls with calls to branch funnels, which are
represented using a new llvm.icall.branch.funnel intrinsic. It also extends
the LowerTypeTests pass to recognize the new intrinsic, generate code
for the branch funnels (x86_64 only for now) and lay out virtual tables
as required for each branch funnel.

The implementation supports full LTO as well as ThinLTO, and extends the
ThinLTO summary format used for whole-program devirtualization to
support branch funnels.

For more details see RFC:
http://lists.llvm.org/pipermail/llvm-dev/2018-January/120672.html

Differential Revision: https://reviews.llvm.org/D42453

llvm-svn: 327163
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index aa1755b..a3aa7c4 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -316,12 +316,17 @@
   /// cases we are directly operating on the call sites at the IR level.
   std::vector<VirtualCallSite> CallSites;
 
+  /// Whether all call sites represented by this CallSiteInfo, including those
+  /// in summaries, have been devirtualized. This starts off as true because a
+  /// default constructed CallSiteInfo represents no call sites.
+  bool AllCallSitesDevirted = true;
+
   // These fields are used during the export phase of ThinLTO and reflect
   // information collected from function summaries.
 
   /// Whether any function summary contains an llvm.assume(llvm.type.test) for
   /// this slot.
-  bool SummaryHasTypeTestAssumeUsers;
+  bool SummaryHasTypeTestAssumeUsers = false;
 
   /// CFI-specific: a vector containing the list of function summaries that use
   /// the llvm.type.checked.load intrinsic and therefore will require
@@ -337,8 +342,22 @@
            !SummaryTypeCheckedLoadUsers.empty();
   }
 
-  /// As explained in the comment for SummaryTypeCheckedLoadUsers.
-  void markDevirt() { SummaryTypeCheckedLoadUsers.clear(); }
+  void markSummaryHasTypeTestAssumeUsers() {
+    SummaryHasTypeTestAssumeUsers = true;
+    AllCallSitesDevirted = false; // Summary call sites aren't devirtualized.
+  }
+
+  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
+    SummaryTypeCheckedLoadUsers.push_back(FS);
+    AllCallSitesDevirted = false; // FS has an unresolved call site.
+  }
+
+  void markDevirt() {
+    AllCallSitesDevirted = true; // Incl. call sites known from summaries.
+
+    // As explained in the comment for SummaryTypeCheckedLoadUsers.
+    SummaryTypeCheckedLoadUsers.clear();
+  }
 };
 
 // Call site information collected for a specific VTableSlot.
@@ -373,7 +392,9 @@
 
 void VTableSlotInfo::addCallSite(Value *VTable, CallSite CS,
                                  unsigned *NumUnsafeUses) {
-  findCallSiteInfo(CS).CallSites.push_back({VTable, CS, NumUnsafeUses});
+  auto &CSI = findCallSiteInfo(CS);
+  CSI.AllCallSitesDevirted = false; // New IR call site still indirect.
+  CSI.CallSites.push_back({VTable, CS, NumUnsafeUses});
 }
 
 struct DevirtModule {
@@ -438,6 +459,12 @@
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res);
 
+  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
+                              bool &IsExported);
+  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
+                            VTableSlotInfo &SlotInfo,
+                            WholeProgramDevirtResolution *Res, VTableSlot Slot);
+
   bool tryEvaluateFunctionsWithArgs(
       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
       ArrayRef<uint64_t> Args);
@@ -471,6 +498,8 @@
                            StringRef Name, IntegerType *IntTy,
                            uint32_t Storage);
 
+  Constant *getMemberAddr(const TypeMemberInfo *M);
+
   void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                             Constant *UniqueMemberAddr);
   bool tryUniqueRetValOpt(unsigned BitWidth,
@@ -726,10 +755,9 @@
       if (VCallSite.NumUnsafeUses)
         --*VCallSite.NumUnsafeUses;
     }
-    if (CSInfo.isExported()) {
+    if (CSInfo.isExported())
       IsExported = true;
-      CSInfo.markDevirt();
-    }
+    CSInfo.markDevirt();
   };
   Apply(SlotInfo.CSInfo);
   for (auto &P : SlotInfo.ConstCSInfo)
@@ -785,6 +813,134 @@
   return true;
 }
 
+void DevirtModule::tryICallBranchFunnel( // Funnel still-indirect calls.
+    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
+    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
+  Triple T(M.getTargetTriple()); // Only x86_64 is supported.
+  if (T.getArch() != Triple::x86_64)
+    return;
+
+  const unsigned kBranchFunnelThreshold = 10; // Max targets per funnel.
+  if (TargetsForSlot.size() > kBranchFunnelThreshold)
+    return;
+
+  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
+  if (!HasNonDevirt)
+    for (auto &P : SlotInfo.ConstCSInfo)
+      if (!P.second.AllCallSitesDevirted) {
+        HasNonDevirt = true;
+        break;
+      }
+
+  if (!HasNonDevirt) // Every call site was already devirtualized.
+    return;
+
+  FunctionType *FT = // void (i8*, ...)
+      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
+  Function *JT;
+  if (isa<MDString>(Slot.TypeID)) { // Named type ID: hidden global symbol.
+    JT = Function::Create(FT, Function::ExternalLinkage,
+                          getGlobalName(Slot, {}, "branch_funnel"), &M);
+    JT->setVisibility(GlobalValue::HiddenVisibility);
+  } else {
+    JT = Function::Create(FT, Function::InternalLinkage, "branch_funnel", &M);
+  }
+  JT->addAttribute(1, Attribute::Nest); // Vtable ptr arrives in 'nest'.
+
+  std::vector<Value *> JTArgs; // Vtable, then (member addr, fn) pairs.
+  JTArgs.push_back(JT->arg_begin());
+  for (auto &T : TargetsForSlot) { // NOTE(review): shadows Triple T.
+    JTArgs.push_back(getMemberAddr(T.TM));
+    JTArgs.push_back(T.Fn);
+  }
+
+  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
+  Constant *Intr =
+      Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});
+
+  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
+  CI->setTailCallKind(CallInst::TCK_MustTail); // Must be a tail call.
+  ReturnInst::Create(M.getContext(), nullptr, BB);
+
+  bool IsExported = false;
+  applyICallBranchFunnel(SlotInfo, JT, IsExported);
+  if (IsExported) // NOTE(review): assumes Res is non-null when exported.
+    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
+}
+
+void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
+                                          Constant *JT, bool &IsExported) {
+  auto Apply = [&](CallSiteInfo &CSInfo) { // Rewrite one CallSiteInfo.
+    if (CSInfo.isExported())
+      IsExported = true;
+    if (CSInfo.AllCallSitesDevirted) // Nothing left to rewrite.
+      return;
+    for (auto &&VCallSite : CSInfo.CallSites) {
+      CallSite CS = VCallSite.CS;
+
+      // Branch funnels only pay off if the retpoline mitigation is enabled.
+      Attribute FSAttr = CS.getCaller()->getFnAttribute("target-features");
+      if (FSAttr.hasAttribute(Attribute::None) ||
+          !FSAttr.getValueAsString().contains("+retpoline"))
+        continue;
+
+      if (RemarksEnabled)
+        VCallSite.emitRemark("branch-funnel", JT->getName(), OREGetter);
+
+      // Pass the address of the vtable in the nest register, which is r10 on
+      // x86_64.
+      std::vector<Type *> NewArgs; // Prepend the i8* vtable param.
+      NewArgs.push_back(Int8PtrTy);
+      for (Type *T : CS.getFunctionType()->params())
+        NewArgs.push_back(T);
+      PointerType *NewFT = PointerType::getUnqual(
+          FunctionType::get(CS.getFunctionType()->getReturnType(), NewArgs,
+                            CS.getFunctionType()->isVarArg()));
+
+      IRBuilder<> IRB(CS.getInstruction());
+      std::vector<Value *> Args;
+      Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
+      for (unsigned I = 0; I != CS.getNumArgOperands(); ++I)
+        Args.push_back(CS.getArgOperand(I));
+
+      CallSite NewCS;
+      if (CS.isCall())
+        NewCS = IRB.CreateCall(IRB.CreateBitCast(JT, NewFT), Args);
+      else
+        NewCS = IRB.CreateInvoke(
+            IRB.CreateBitCast(JT, NewFT),
+            cast<InvokeInst>(CS.getInstruction())->getNormalDest(),
+            cast<InvokeInst>(CS.getInstruction())->getUnwindDest(), Args);
+      NewCS.setCallingConv(CS.getCallingConv());
+
+      AttributeList Attrs = CS.getAttributes();
+      std::vector<AttributeSet> NewArgAttrs; // nest + original param attrs.
+      NewArgAttrs.push_back(AttributeSet::get(
+          M.getContext(), ArrayRef<Attribute>{Attribute::get(
+                              M.getContext(), Attribute::Nest)}));
+      for (unsigned I = 0; I + 2 <  Attrs.getNumAttrSets(); ++I)
+        NewArgAttrs.push_back(Attrs.getParamAttributes(I));
+      NewCS.setAttributes(
+          AttributeList::get(M.getContext(), Attrs.getFnAttributes(),
+                             Attrs.getRetAttributes(), NewArgAttrs));
+
+      CS->replaceAllUsesWith(NewCS.getInstruction());
+      CS->eraseFromParent();
+
+      // This use is no longer unsafe.
+      if (VCallSite.NumUnsafeUses)
+        --*VCallSite.NumUnsafeUses;
+    }
+    // Don't mark as devirtualized because there may be callers compiled without
+    // retpoline mitigation, which would mean that they are lowered to
+    // llvm.type.test and therefore require an llvm.type.test resolution for the
+    // type identifier.
+  };
+  Apply(SlotInfo.CSInfo);
+  for (auto &P : SlotInfo.ConstCSInfo)
+    Apply(P.second);
+}
+
 bool DevirtModule::tryEvaluateFunctionsWithArgs(
     MutableArrayRef<VirtualCallTarget> TargetsForSlot,
     ArrayRef<uint64_t> Args) {
@@ -937,6 +1093,12 @@
   CSInfo.markDevirt();
 }
 
+Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
+  Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
+  return ConstantExpr::getGetElementPtr(Int8Ty, C, // (i8*)GV + Offset bytes
+                                        ConstantInt::get(Int64Ty, M->Offset));
+} // NOTE(review): param M shadows the DevirtModule Module member M.
+
 bool DevirtModule::tryUniqueRetValOpt(
     unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
     CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
@@ -956,12 +1118,7 @@
     // checked for a uniform return value in tryUniformRetValOpt.
     assert(UniqueMember);
 
-    Constant *UniqueMemberAddr =
-        ConstantExpr::getBitCast(UniqueMember->Bits->GV, Int8PtrTy);
-    UniqueMemberAddr = ConstantExpr::getGetElementPtr(
-        Int8Ty, UniqueMemberAddr,
-        ConstantInt::get(Int64Ty, UniqueMember->Offset));
-
+    Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
     if (CSInfo.isExported()) {
       Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
       Res->Info = IsOne;
@@ -1348,6 +1505,14 @@
       break;
     }
   }
+
+  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
+    auto *JT = M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
+                                     Type::getVoidTy(M.getContext()));
+    bool IsExported = false;
+    applyICallBranchFunnel(SlotInfo, JT, IsExported);
+    assert(!IsExported);
+  }
 }
 
 void DevirtModule::removeRedundantTypeTests() {
@@ -1417,14 +1582,13 @@
         // FIXME: Only add live functions.
         for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
-            CallSlots[{MD, VF.Offset}].CSInfo.SummaryHasTypeTestAssumeUsers =
-                true;
+            CallSlots[{MD, VF.Offset}]
+                .CSInfo.markSummaryHasTypeTestAssumeUsers();
           }
         }
         for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
           for (Metadata *MD : MetadataByGUID[VF.GUID]) {
-            CallSlots[{MD, VF.Offset}]
-                .CSInfo.SummaryTypeCheckedLoadUsers.push_back(FS);
+            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
           }
         }
         for (const FunctionSummary::ConstVCall &VC :
@@ -1432,7 +1596,7 @@
           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
             CallSlots[{MD, VC.VFunc.Offset}]
                 .ConstCSInfo[VC.Args]
-                .SummaryHasTypeTestAssumeUsers = true;
+                .markSummaryHasTypeTestAssumeUsers();
           }
         }
         for (const FunctionSummary::ConstVCall &VC :
@@ -1440,7 +1604,7 @@
           for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
             CallSlots[{MD, VC.VFunc.Offset}]
                 .ConstCSInfo[VC.Args]
-                .SummaryTypeCheckedLoadUsers.push_back(FS);
+                .addSummaryTypeCheckedLoadUser(FS);
           }
         }
       }
@@ -1464,9 +1628,12 @@
                        cast<MDString>(S.first.TypeID)->getString())
                    .WPDRes[S.first.ByteOffset];
 
-      if (!trySingleImplDevirt(TargetsForSlot, S.second, Res) &&
-          tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first))
-        DidVirtualConstProp = true;
+      if (!trySingleImplDevirt(TargetsForSlot, S.second, Res)) {
+        DidVirtualConstProp |=
+            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);
+
+        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
+      }
 
       // Collect functions devirtualized at least for one call site for stats.
       if (RemarksEnabled)