Implement bad cast checks using control flow integrity information.
This scheme checks that pointer and lvalue casts are made to an object of
the correct dynamic type; that is, the dynamic type of the object must be
a derived class of the pointee type of the cast. The checks are currently
only introduced where the class being cast to is a polymorphic class.
Differential Revision: http://reviews.llvm.org/D8312
llvm-svn: 232241
diff --git a/clang/lib/CodeGen/CGClass.cpp b/clang/lib/CodeGen/CGClass.cpp
index 5649708..84d6437 100644
--- a/clang/lib/CodeGen/CGClass.cpp
+++ b/clang/lib/CodeGen/CGClass.cpp
@@ -2093,7 +2093,107 @@
   if (!SanOpts.has(SanitizerKind::CFIVptr))
     return;
 
-  const CXXRecordDecl *RD = MD->getParent();
+  EmitVTablePtrCheck(MD->getParent(), VTable);
+}
+
+// If a class has a single non-virtual base, adds no fields, and declares no
+// virtual member functions other than an implicit destructor, it will have the
+// same layout as its base. This function walks up the inheritance chain from RD
+// and returns the first class that does not meet these conditions, i.e. the
+// least derived class with the same layout as RD.
+//
+// Casting an instance of a base class to such a derived class is technically
+// undefined behavior, but it is a relatively common hack for introducing member
+// functions on class instances with specific properties (e.g. llvm::Operator)
+// that works under most compilers and should not have security implications, so
+// we allow it by default. It can be disabled with -fsanitize=cfi-cast-strict.
+static const CXXRecordDecl *
+LeastDerivedClassWithSameLayout(const CXXRecordDecl *RD) {
+  if (!RD->field_empty())
+    return RD;
+
+  if (RD->getNumVBases() != 0)
+    return RD;
+
+  if (RD->getNumBases() != 1)
+    return RD;
+
+  for (const CXXMethodDecl *MD : RD->methods()) {
+    if (MD->isVirtual()) {
+      // Virtual member functions are only ok if they are implicit destructors
+      // because the implicit destructor will have the same semantics as the
+      // base class's destructor if no fields are added.
+      if (isa<CXXDestructorDecl>(MD) && MD->isImplicit())
+        continue;
+      return RD;
+    }
+  }
+
+  // All conditions hold for RD; apply the same test to its single base.
+  return LeastDerivedClassWithSameLayout(
+      RD->bases_begin()->getType()->getAsCXXRecordDecl());
+}
+
+/// Emit a check that the vtable of the object at Derived, the result of a cast
+/// to type T, is valid for T. Nothing is emitted unless the language is C++ and
+/// T is a complete dynamic class whose mangled RTTI name is not blacklisted.
+/// Unless cfi-cast-strict is enabled, the check is made against the least
+/// derived class with the same layout as T. If MayBeNull is true, the check is
+/// only performed at run time when Derived is non-null.
+void CodeGenFunction::EmitVTablePtrCheckForCast(QualType T,
+                                                llvm::Value *Derived,
+                                                bool MayBeNull) {
+  if (!getLangOpts().CPlusPlus)
+    return;
+
+  auto *ClassTy = T->getAs<RecordType>();
+  if (!ClassTy)
+    return;
+
+  const CXXRecordDecl *ClassDecl = cast<CXXRecordDecl>(ClassTy->getDecl());
+
+  if (!ClassDecl->isCompleteDefinition() || !ClassDecl->isDynamicClass())
+    return;
+
+  SmallString<64> MangledName;
+  llvm::raw_svector_ostream Out(MangledName);
+  CGM.getCXXABI().getMangleContext().mangleCXXRTTI(T.getUnqualifiedType(),
+                                                   Out);
+
+  // Blacklist based on the mangled type.
+  if (CGM.getContext().getSanitizerBlacklist().isBlacklistedType(Out.str()))
+    return;
+
+  if (!SanOpts.has(SanitizerKind::CFICastStrict))
+    ClassDecl = LeastDerivedClassWithSameLayout(ClassDecl);
+
+  // Stays null unless a null check is emitted below.
+  llvm::BasicBlock *ContBlock = nullptr;
+
+  if (MayBeNull) {
+    // Branch around the vtable check when Derived is null.
+    llvm::Value *DerivedNotNull =
+        Builder.CreateIsNotNull(Derived, "cast.nonnull");
+
+    llvm::BasicBlock *CheckBlock = createBasicBlock("cast.check");
+    ContBlock = createBasicBlock("cast.cont");
+
+    Builder.CreateCondBr(DerivedNotNull, CheckBlock, ContBlock);
+
+    EmitBlock(CheckBlock);
+  }
+
+  llvm::Value *VTable = GetVTablePtr(Derived, Int8PtrTy);
+  EmitVTablePtrCheck(ClassDecl, VTable);
+
+  if (MayBeNull) {
+    Builder.CreateBr(ContBlock);
+    EmitBlock(ContBlock);
+  }
+}
+
+void CodeGenFunction::EmitVTablePtrCheck(const CXXRecordDecl *RD,
+                                         llvm::Value *VTable) {
   // FIXME: Add blacklisting scheme.
   if (RD->isInStdNamespace())
     return;
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 4a38200..35275e5 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -3031,6 +3031,9 @@
       EmitTypeCheck(TCK_DowncastReference, E->getExprLoc(),
                     Derived, E->getType());
 
+    if (SanOpts.has(SanitizerKind::CFIDerivedCast))
+      EmitVTablePtrCheckForCast(E->getType(), Derived, /*MayBeNull=*/false);
+
     return MakeAddrLValue(Derived, E->getType());
   }
   case CK_LValueBitCast: {
@@ -3040,6 +3043,10 @@
     LValue LV = EmitLValue(E->getSubExpr());
     llvm::Value *V = Builder.CreateBitCast(LV.getAddress(),
                                            ConvertType(CE->getTypeAsWritten()));
+
+    if (SanOpts.has(SanitizerKind::CFIUnrelatedCast))
+      EmitVTablePtrCheckForCast(E->getType(), V, /*MayBeNull=*/false);
+
     return MakeAddrLValue(V, E->getType());
   }
   case CK_ObjCObjectLValueCast: {
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index d969412..e5d612c 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -1355,6 +1355,13 @@
       llvm_unreachable("wrong cast for pointers in different address spaces"
                        "(must be an address space cast)!");
     }
+
+    if (CGF.SanOpts.has(SanitizerKind::CFIUnrelatedCast)) {
+      if (auto PT = DestTy->getAs<PointerType>())
+        CGF.EmitVTablePtrCheckForCast(PT->getPointeeType(), Src,
+                                      /*MayBeNull=*/true);
+    }
+
     return Builder.CreateBitCast(Src, DstTy);
   }
   case CK_AddressSpaceConversion: {
@@ -1384,6 +1391,10 @@
       CGF.EmitTypeCheck(CodeGenFunction::TCK_DowncastPointer, CE->getExprLoc(),
                         Derived, DestTy->getPointeeType());
 
+    if (CGF.SanOpts.has(SanitizerKind::CFIDerivedCast))
+      CGF.EmitVTablePtrCheckForCast(DestTy->getPointeeType(), Derived,
+                                    /*MayBeNull=*/true);
+
     return Derived;
   }
   case CK_UncheckedDerivedToBase:
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 1795ba3..474668a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -1288,10 +1288,20 @@
   /// to by This.
   llvm::Value *GetVTablePtr(llvm::Value *This, llvm::Type *Ty);
 
+  /// \brief Derived is the presumed address of an object of type T after a
+  /// cast. If T is a dynamic class, emit a check that Derived's vtable belongs
+  /// to a class derived from T. If MayBeNull, a null Derived is not checked.
+  void EmitVTablePtrCheckForCast(QualType T, llvm::Value *Derived,
+                                 bool MayBeNull);
+
   /// EmitVTablePtrCheckForCall - Virtual method MD is being called via VTable.
   /// If vptr CFI is enabled, emit a check that VTable is valid.
   void EmitVTablePtrCheckForCall(const CXXMethodDecl *MD, llvm::Value *VTable);
 
+  /// EmitVTablePtrCheck - Emit a check that VTable is a valid virtual table for
+  /// RD using llvm.bitset.test.
+  void EmitVTablePtrCheck(const CXXRecordDecl *RD, llvm::Value *VTable);
+
   /// CanDevirtualizeMemberFunctionCalls - Checks whether virtual calls on given
   /// expr can be devirtualized.
   bool CanDevirtualizeMemberFunctionCall(const Expr *Base,