Cross-DSO control flow integrity (Clang part).

Clang-side cross-DSO CFI.

* Adds a command line flag -f[no-]sanitize-cfi-cross-dso.
* Links a runtime library when enabled.
* Emits __cfi_slowpath calls if the bitset test fails.
* Emits extra hash-based bitsets for external CFI checks.
* Sets a module flag to enable __cfi_check generation during LTO.

This mode does not yet support diagnostics.

llvm-svn: 255694
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index 77ec7f2..2e566de 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2552,15 +2552,22 @@
 
   SanitizerScope SanScope(this);
 
-  llvm::Value *BitSetName = llvm::MetadataAsValue::get(
-      getLLVMContext(),
-      CGM.CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0)));
+  llvm::Metadata *MD =
+      CGM.CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0));
+  llvm::Value *BitSetName = llvm::MetadataAsValue::get(getLLVMContext(), MD);
 
   llvm::Value *CastedVTable = Builder.CreateBitCast(VTable, Int8PtrTy);
   llvm::Value *BitSetTest =
       Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::bitset_test),
                          {CastedVTable, BitSetName});
 
+  if (CGM.getCodeGenOpts().SanitizeCfiCrossDso) {
+    if (auto TypeId = CGM.CreateCfiIdForTypeMetadata(MD)) {
+      EmitCfiSlowPathCheck(BitSetTest, TypeId, CastedVTable);
+      return;
+    }
+  }
+
   SanitizerMask M;
   switch (TCK) {
   case CFITCK_VCall:
@@ -2578,9 +2585,9 @@
   }
 
   llvm::Constant *StaticData[] = {
-    EmitCheckSourceLocation(Loc),
-    EmitCheckTypeDescriptor(QualType(RD->getTypeForDecl(), 0)),
-    llvm::ConstantInt::get(Int8Ty, TCK),
+      EmitCheckSourceLocation(Loc),
+      EmitCheckTypeDescriptor(QualType(RD->getTypeForDecl(), 0)),
+      llvm::ConstantInt::get(Int8Ty, TCK),
   };
   EmitCheck(std::make_pair(BitSetTest, M), "cfi_bad_type", StaticData,
             CastedVTable);
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index c2db728..dabd2b1 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2532,6 +2532,34 @@
   EmitBlock(Cont);
 }
 
+void CodeGenFunction::EmitCfiSlowPathCheck(llvm::Value *Cond,
+                                           llvm::ConstantInt *TypeId,
+                                           llvm::Value *Ptr) {
+  auto &Ctx = getLLVMContext();
+  llvm::BasicBlock *Cont = createBasicBlock("cfi.cont");
+
+  llvm::BasicBlock *CheckBB = createBasicBlock("cfi.slowpath");
+  llvm::BranchInst *BI = Builder.CreateCondBr(Cond, Cont, CheckBB);
+
+  llvm::MDBuilder MDHelper(getLLVMContext());
+  llvm::MDNode *Node = MDHelper.createBranchWeights((1U << 20) - 1, 1);
+  BI->setMetadata(llvm::LLVMContext::MD_prof, Node);
+
+  EmitBlock(CheckBB);
+
+  llvm::Constant *SlowPathFn = CGM.getModule().getOrInsertFunction(
+      "__cfi_slowpath",
+      llvm::FunctionType::get(
+          llvm::Type::getVoidTy(Ctx),
+          {llvm::Type::getInt64Ty(Ctx),
+           llvm::PointerType::getUnqual(llvm::Type::getInt8Ty(Ctx))},
+          false));
+  llvm::CallInst *CheckCall = Builder.CreateCall(SlowPathFn, {TypeId, Ptr});
+  CheckCall->setDoesNotThrow();
+
+  EmitBlock(Cont);
+}
+
 void CodeGenFunction::EmitTrapCheck(llvm::Value *Checked) {
   llvm::BasicBlock *Cont = createBasicBlock("cont");
 
@@ -3823,21 +3851,25 @@
       (!TargetDecl || !isa<FunctionDecl>(TargetDecl))) {
     SanitizerScope SanScope(this);
 
-    llvm::Value *BitSetName = llvm::MetadataAsValue::get(
-        getLLVMContext(),
-        CGM.CreateMetadataIdentifierForType(QualType(FnType, 0)));
+    llvm::Metadata *MD = CGM.CreateMetadataIdentifierForType(QualType(FnType, 0));
+    llvm::Value *BitSetName = llvm::MetadataAsValue::get(getLLVMContext(), MD);
 
     llvm::Value *CastedCallee = Builder.CreateBitCast(Callee, Int8PtrTy);
     llvm::Value *BitSetTest =
         Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::bitset_test),
                            {CastedCallee, BitSetName});
 
-    llvm::Constant *StaticData[] = {
-      EmitCheckSourceLocation(E->getLocStart()),
-      EmitCheckTypeDescriptor(QualType(FnType, 0)),
-    };
-    EmitCheck(std::make_pair(BitSetTest, SanitizerKind::CFIICall),
-              "cfi_bad_icall", StaticData, CastedCallee);
+    auto TypeId = CGM.CreateCfiIdForTypeMetadata(MD);
+    if (CGM.getCodeGenOpts().SanitizeCfiCrossDso && TypeId) {
+      EmitCfiSlowPathCheck(BitSetTest, TypeId, CastedCallee);
+    } else {
+      llvm::Constant *StaticData[] = {
+          EmitCheckSourceLocation(E->getLocStart()),
+          EmitCheckTypeDescriptor(QualType(FnType, 0)),
+      };
+      EmitCheck(std::make_pair(BitSetTest, SanitizerKind::CFIICall),
+                "cfi_bad_icall", StaticData, CastedCallee);
+    }
   }
 
   CallArgList Args;
diff --git a/clang/lib/CodeGen/CGVTables.cpp b/clang/lib/CodeGen/CGVTables.cpp
index 3d8c41c..797c408 100644
--- a/clang/lib/CodeGen/CGVTables.cpp
+++ b/clang/lib/CodeGen/CGVTables.cpp
@@ -934,6 +934,7 @@
   llvm::NamedMDNode *BitsetsMD =
       getModule().getOrInsertNamedMetadata("llvm.bitsets");
   for (auto BitsetEntry : BitsetEntries)
-    BitsetsMD->addOperand(CreateVTableBitSetEntry(
-        VTable, PointerWidth * BitsetEntry.second, BitsetEntry.first));
+    CreateVTableBitSetEntry(BitsetsMD, VTable,
+                            PointerWidth * BitsetEntry.second,
+                            BitsetEntry.first);
 }
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index fa98b2b..b4a9186 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -3009,6 +3009,11 @@
                  StringRef CheckName, ArrayRef<llvm::Constant *> StaticArgs,
                  ArrayRef<llvm::Value *> DynamicArgs);
 
+  /// \brief Emit a slow path cross-DSO CFI check which calls __cfi_slowpath
+  /// if Cond is false.
+  void EmitCfiSlowPathCheck(llvm::Value *Cond, llvm::ConstantInt *TypeId,
+                            llvm::Value *Ptr);
+
   /// \brief Create a basic block that will call the trap intrinsic, and emit a
   /// conditional branch to it, for the -ftrapv checks.
   void EmitTrapCheck(llvm::Value *Checked);
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index b207bac..3b3fc87 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -53,6 +53,7 @@
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/Support/ConvertUTF.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MD5.h"
 
 using namespace clang;
 using namespace CodeGen;
@@ -439,6 +440,11 @@
     getModule().addModuleFlag(llvm::Module::Error, "min_enum_size", EnumWidth);
   }
 
+  if (CodeGenOpts.SanitizeCfiCrossDso) {
+    // Indicate that we want cross-DSO control flow integrity checks.
+    getModule().addModuleFlag(llvm::Module::Override, "Cross-DSO CFI", 1);
+  }
+
   if (uint32_t PLevel = Context.getLangOpts().PICLevel) {
     llvm::PICLevel::Level PL = llvm::PICLevel::Default;
     switch (PLevel) {
@@ -736,6 +742,21 @@
     F->setDLLStorageClass(llvm::GlobalVariable::DefaultStorageClass);
 }
 
+llvm::ConstantInt *
+CodeGenModule::CreateCfiIdForTypeMetadata(llvm::Metadata *MD) {
+  llvm::MDString *MDS = dyn_cast<llvm::MDString>(MD);
+  if (!MDS) return nullptr;
+
+  llvm::MD5 md5;
+  llvm::MD5::MD5Result result;
+  md5.update(MDS->getString());
+  md5.final(result);
+  uint64_t id = 0;
+  for (int i = 0; i < 8; ++i)
+    id |= static_cast<uint64_t>(result[i]) << (i * 8);
+  return llvm::ConstantInt::get(Int64Ty, id);
+}
+
 void CodeGenModule::setFunctionDefinitionAttributes(const FunctionDecl *D,
                                                     llvm::Function *F) {
   setNonAliasAttributes(D, F);
@@ -928,6 +949,49 @@
   }
 }
 
+void CodeGenModule::CreateFunctionBitSetEntry(const FunctionDecl *FD,
+                                              llvm::Function *F) {
+  // Only if we are checking indirect calls.
+  if (!LangOpts.Sanitize.has(SanitizerKind::CFIICall))
+    return;
+
+  // Non-static class methods are handled via vtable pointer checks elsewhere.
+  if (isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic())
+    return;
+
+  // Additionally, if building with cross-DSO support...
+  if (CodeGenOpts.SanitizeCfiCrossDso) {
+    // Don't emit entries for function declarations. In cross-DSO mode these are
+    // handled with better precision at run time.
+    if (!FD->hasBody())
+      return;
+    // Skip available_externally functions. They won't be codegen'ed in the
+    // current module anyway.
+    if (getContext().GetGVALinkageForFunction(FD) == GVA_AvailableExternally)
+      return;
+  }
+
+  llvm::NamedMDNode *BitsetsMD =
+      getModule().getOrInsertNamedMetadata("llvm.bitsets");
+
+  llvm::Metadata *MD = CreateMetadataIdentifierForType(FD->getType());
+  llvm::Metadata *BitsetOps[] = {
+      MD, llvm::ConstantAsMetadata::get(F),
+      llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64Ty, 0))};
+  BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps));
+
+  // Emit a hash-based bit set entry for cross-DSO calls.
+  if (CodeGenOpts.SanitizeCfiCrossDso) {
+    if (auto TypeId = CreateCfiIdForTypeMetadata(MD)) {
+      llvm::Metadata *BitsetOps2[] = {
+          llvm::ConstantAsMetadata::get(TypeId),
+          llvm::ConstantAsMetadata::get(F),
+          llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64Ty, 0))};
+      BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps2));
+    }
+  }
+}
+
 void CodeGenModule::SetFunctionAttributes(GlobalDecl GD, llvm::Function *F,
                                           bool IsIncompleteFunction,
                                           bool IsThunk) {
@@ -970,19 +1034,7 @@
     F->addAttribute(llvm::AttributeSet::FunctionIndex,
                     llvm::Attribute::NoBuiltin);
 
-  // If we are checking indirect calls and this is not a non-static member
-  // function, emit a bit set entry for the function type.
-  if (LangOpts.Sanitize.has(SanitizerKind::CFIICall) &&
-      !(isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic())) {
-    llvm::NamedMDNode *BitsetsMD =
-        getModule().getOrInsertNamedMetadata("llvm.bitsets");
-
-    llvm::Metadata *BitsetOps[] = {
-        CreateMetadataIdentifierForType(FD->getType()),
-        llvm::ConstantAsMetadata::get(F),
-        llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64Ty, 0))};
-    BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps));
-  }
+  CreateFunctionBitSetEntry(FD, F);
 }
 
 void CodeGenModule::addUsedGlobal(llvm::GlobalValue *GV) {
@@ -3874,14 +3926,28 @@
   return InternalId;
 }
 
-llvm::MDTuple *CodeGenModule::CreateVTableBitSetEntry(
-    llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD) {
+void CodeGenModule::CreateVTableBitSetEntry(llvm::NamedMDNode *BitsetsMD,
+                                            llvm::GlobalVariable *VTable,
+                                            CharUnits Offset,
+                                            const CXXRecordDecl *RD) {
+  llvm::Metadata *MD =
+      CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0));
   llvm::Metadata *BitsetOps[] = {
-      CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0)),
-      llvm::ConstantAsMetadata::get(VTable),
+      MD, llvm::ConstantAsMetadata::get(VTable),
       llvm::ConstantAsMetadata::get(
           llvm::ConstantInt::get(Int64Ty, Offset.getQuantity()))};
-  return llvm::MDTuple::get(getLLVMContext(), BitsetOps);
+  BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps));
+
+  if (CodeGenOpts.SanitizeCfiCrossDso) {
+    if (auto TypeId = CreateCfiIdForTypeMetadata(MD)) {
+      llvm::Metadata *BitsetOps2[] = {
+          llvm::ConstantAsMetadata::get(TypeId),
+          llvm::ConstantAsMetadata::get(VTable),
+          llvm::ConstantAsMetadata::get(
+              llvm::ConstantInt::get(Int64Ty, Offset.getQuantity()))};
+      BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps2));
+    }
+  }
 }
 
 // Fills in the supplied string map with the set of target features for the
diff --git a/clang/lib/CodeGen/CodeGenModule.h b/clang/lib/CodeGen/CodeGenModule.h
index cede404..3311383 100644
--- a/clang/lib/CodeGen/CodeGenModule.h
+++ b/clang/lib/CodeGen/CodeGenModule.h
@@ -1106,15 +1106,21 @@
   void EmitVTableBitSetEntries(llvm::GlobalVariable *VTable,
                                const VTableLayout &VTLayout);
 
+  /// Generate a cross-DSO type identifier for MD.
+  llvm::ConstantInt *CreateCfiIdForTypeMetadata(llvm::Metadata *MD);
+
   /// Create a metadata identifier for the given type. This may either be an
   /// MDString (for external identifiers) or a distinct unnamed MDNode (for
   /// internal identifiers).
   llvm::Metadata *CreateMetadataIdentifierForType(QualType T);
 
-  /// Create a bitset entry for the given vtable.
-  llvm::MDTuple *CreateVTableBitSetEntry(llvm::GlobalVariable *VTable,
-                                         CharUnits Offset,
-                                         const CXXRecordDecl *RD);
+  /// Create a bitset entry for the given function and add it to llvm.bitsets.
+  void CreateFunctionBitSetEntry(const FunctionDecl *FD, llvm::Function *F);
+
+  /// Create a bitset entry for the given vtable and add it to BitsetsMD.
+  void CreateVTableBitSetEntry(llvm::NamedMDNode *BitsetsMD,
+                               llvm::GlobalVariable *VTable, CharUnits Offset,
+                               const CXXRecordDecl *RD);
 
   /// \breif Get the declaration of std::terminate for the platform.
   llvm::Constant *getTerminateFn();
diff --git a/clang/lib/CodeGen/MicrosoftCXXABI.cpp b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
index f22c5b8..93210d5 100644
--- a/clang/lib/CodeGen/MicrosoftCXXABI.cpp
+++ b/clang/lib/CodeGen/MicrosoftCXXABI.cpp
@@ -1523,15 +1523,14 @@
 
   if (Info->PathToBaseWithVPtr.empty()) {
     if (!CGM.IsCFIBlacklistedRecord(RD))
-      BitsetsMD->addOperand(
-          CGM.CreateVTableBitSetEntry(VTable, AddressPoint, RD));
+      CGM.CreateVTableBitSetEntry(BitsetsMD, VTable, AddressPoint, RD);
     return;
   }
 
   // Add a bitset entry for the least derived base belonging to this vftable.
   if (!CGM.IsCFIBlacklistedRecord(Info->PathToBaseWithVPtr.back()))
-    BitsetsMD->addOperand(CGM.CreateVTableBitSetEntry(
-        VTable, AddressPoint, Info->PathToBaseWithVPtr.back()));
+    CGM.CreateVTableBitSetEntry(BitsetsMD, VTable, AddressPoint,
+                                Info->PathToBaseWithVPtr.back());
 
   // Add a bitset entry for each derived class that is laid out at the same
   // offset as the least derived base.
@@ -1550,14 +1549,12 @@
     if (!Offset.isZero())
       return;
     if (!CGM.IsCFIBlacklistedRecord(DerivedRD))
-      BitsetsMD->addOperand(
-          CGM.CreateVTableBitSetEntry(VTable, AddressPoint, DerivedRD));
+      CGM.CreateVTableBitSetEntry(BitsetsMD, VTable, AddressPoint, DerivedRD);
   }
 
   // Finally do the same for the most derived class.
   if (Info->FullOffsetInMDC.isZero() && !CGM.IsCFIBlacklistedRecord(RD))
-    BitsetsMD->addOperand(
-        CGM.CreateVTableBitSetEntry(VTable, AddressPoint, RD));
+    CGM.CreateVTableBitSetEntry(BitsetsMD, VTable, AddressPoint, RD);
 }
 
 void MicrosoftCXXABI::emitVTableDefinitions(CodeGenVTables &CGVT,