CFI: Introduce -fsanitize=cfi-icall flag.

This flag causes the compiler to emit bit set entries for functions, as well
as runtime bit set checks at indirect call sites. Depends on the new function
bit set mechanism.

Differential Revision: http://reviews.llvm.org/D11857

llvm-svn: 247238
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index dc65b14..56787e5 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2473,12 +2473,9 @@
 
   SanitizerScope SanScope(this);
 
-  std::string OutName;
-  llvm::raw_string_ostream Out(OutName);
-  CGM.getCXXABI().getMangleContext().mangleCXXVTableBitSet(RD, Out);
-
   llvm::Value *BitSetName = llvm::MetadataAsValue::get(
-      getLLVMContext(), llvm::MDString::get(getLLVMContext(), Out.str()));
+      getLLVMContext(),
+      CGM.CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0)));
 
   llvm::Value *CastedVTable = Builder.CreateBitCast(VTable, Int8PtrTy);
   llvm::Value *BitSetTest =
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 8339444..5e6c4de 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -3780,6 +3780,29 @@
     }
   }
 
+  // If we are checking indirect calls and this call is indirect, check that the
+  // function pointer is a member of the bit set for the function type.
+  if (SanOpts.has(SanitizerKind::CFIICall) &&
+      (!TargetDecl || !isa<FunctionDecl>(TargetDecl))) {
+    SanitizerScope SanScope(this);
+
+    llvm::Value *BitSetName = llvm::MetadataAsValue::get(
+        getLLVMContext(),
+        CGM.CreateMetadataIdentifierForType(QualType(FnType, 0)));
+
+    llvm::Value *CastedCallee = Builder.CreateBitCast(Callee, Int8PtrTy);
+    llvm::Value *BitSetTest =
+        Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::bitset_test),
+                           {CastedCallee, BitSetName});
+
+    llvm::Constant *StaticData[] = {
+      EmitCheckSourceLocation(E->getLocStart()),
+      EmitCheckTypeDescriptor(QualType(FnType, 0)),
+    };
+    EmitCheck(std::make_pair(BitSetTest, SanitizerKind::CFIICall),
+              "cfi_bad_icall", StaticData, CastedCallee);
+  }
+
   CallArgList Args;
   if (Chain)
     Args.add(RValue::get(Builder.CreateBitCast(Chain, CGM.VoidPtrTy)),
diff --git a/clang/lib/CodeGen/CGVTables.cpp b/clang/lib/CodeGen/CGVTables.cpp
index 6aa624e..4c3202c 100644
--- a/clang/lib/CodeGen/CGVTables.cpp
+++ b/clang/lib/CodeGen/CGVTables.cpp
@@ -893,41 +893,45 @@
   CharUnits PointerWidth =
       Context.toCharUnitsFromBits(Context.getTargetInfo().getPointerWidth(0));
 
-  std::vector<llvm::MDTuple *> BitsetEntries;
+  typedef std::pair<const CXXRecordDecl *, unsigned> BSEntry;
+  std::vector<BSEntry> BitsetEntries;
   // Create a bit set entry for each address point.
   for (auto &&AP : VTLayout.getAddressPoints()) {
     if (IsCFIBlacklistedRecord(AP.first.getBase()))
       continue;
 
-    BitsetEntries.push_back(CreateVTableBitSetEntry(
-        VTable, PointerWidth * AP.second, AP.first.getBase()));
+    BitsetEntries.push_back(std::make_pair(AP.first.getBase(), AP.second));
   }
 
   // Sort the bit set entries for determinism.
-  std::sort(BitsetEntries.begin(), BitsetEntries.end(), [](llvm::MDTuple *T1,
-                                                           llvm::MDTuple *T2) {
-    if (T1 == T2)
+  std::sort(BitsetEntries.begin(), BitsetEntries.end(),
+            [this](const BSEntry &E1, const BSEntry &E2) {
+    if (&E1 == &E2)
       return false;
 
-    StringRef S1 = cast<llvm::MDString>(T1->getOperand(0))->getString();
-    StringRef S2 = cast<llvm::MDString>(T2->getOperand(0))->getString();
+    std::string S1;
+    llvm::raw_string_ostream O1(S1);
+    getCXXABI().getMangleContext().mangleTypeName(
+        QualType(E1.first->getTypeForDecl(), 0), O1);
+    O1.flush();
+
+    std::string S2;
+    llvm::raw_string_ostream O2(S2);
+    getCXXABI().getMangleContext().mangleTypeName(
+        QualType(E2.first->getTypeForDecl(), 0), O2);
+    O2.flush();
+
     if (S1 < S2)
       return true;
     if (S1 != S2)
       return false;
 
-    uint64_t Offset1 = cast<llvm::ConstantInt>(
-                           cast<llvm::ConstantAsMetadata>(T1->getOperand(2))
-                               ->getValue())->getZExtValue();
-    uint64_t Offset2 = cast<llvm::ConstantInt>(
-                           cast<llvm::ConstantAsMetadata>(T2->getOperand(2))
-                               ->getValue())->getZExtValue();
-    assert(Offset1 != Offset2);
-    return Offset1 < Offset2;
+    return E1.second < E2.second;
   });
 
   llvm::NamedMDNode *BitsetsMD =
       getModule().getOrInsertNamedMetadata("llvm.bitsets");
   for (auto BitsetEntry : BitsetEntries)
-    BitsetsMD->addOperand(BitsetEntry);
+    BitsetsMD->addOperand(CreateVTableBitSetEntry(
+        VTable, PointerWidth * BitsetEntry.second, BitsetEntry.first));
 }
diff --git a/clang/lib/CodeGen/CodeGenModule.cpp b/clang/lib/CodeGen/CodeGenModule.cpp
index a2d554b..8f0259d 100644
--- a/clang/lib/CodeGen/CodeGenModule.cpp
+++ b/clang/lib/CodeGen/CodeGenModule.cpp
@@ -941,6 +941,20 @@
   if (FD->isReplaceableGlobalAllocationFunction())
     F->addAttribute(llvm::AttributeSet::FunctionIndex,
                     llvm::Attribute::NoBuiltin);
+
+  // If we are checking indirect calls and this is not a non-static member
+  // function, emit a bit set entry for the function type.
+  if (LangOpts.Sanitize.has(SanitizerKind::CFIICall) &&
+      !(isa<CXXMethodDecl>(FD) && !cast<CXXMethodDecl>(FD)->isStatic())) {
+    llvm::NamedMDNode *BitsetsMD =
+        getModule().getOrInsertNamedMetadata("llvm.bitsets");
+
+    llvm::Metadata *BitsetOps[] = {
+        CreateMetadataIdentifierForType(FD->getType()),
+        llvm::ConstantAsMetadata::get(F),
+        llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int64Ty, 0))};
+    BitsetsMD->addOperand(llvm::MDTuple::get(getLLVMContext(), BitsetOps));
+  }
 }
 
 void CodeGenModule::addUsedGlobal(llvm::GlobalValue *GV) {
@@ -3824,12 +3838,8 @@
 
 llvm::MDTuple *CodeGenModule::CreateVTableBitSetEntry(
     llvm::GlobalVariable *VTable, CharUnits Offset, const CXXRecordDecl *RD) {
-  std::string OutName;
-  llvm::raw_string_ostream Out(OutName);
-  getCXXABI().getMangleContext().mangleCXXVTableBitSet(RD, Out);
-
   llvm::Metadata *BitsetOps[] = {
-      llvm::MDString::get(getLLVMContext(), Out.str()),
+      CreateMetadataIdentifierForType(QualType(RD->getTypeForDecl(), 0)),
       llvm::ConstantAsMetadata::get(VTable),
       llvm::ConstantAsMetadata::get(
           llvm::ConstantInt::get(Int64Ty, Offset.getQuantity()))};