Stephen Hines | 37ed9c1 | 2014-12-01 14:51:49 -0800 | [diff] [blame] | 1 | //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===// |
| 2 | // |
| 3 | // This file is distributed under the University of Illinois Open Source |
| 4 | // License. See LICENSE.TXT for details. |
| 5 | // |
| 6 | //===----------------------------------------------------------------------===// |
| 7 | /// |
| 8 | /// \file |
| 9 | /// \brief A pass that instruments code with fast checks for indirect calls and |
| 10 | /// hooks for a function to check violations. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #define DEBUG_TYPE "cfi" |
| 15 | |
| 16 | #include "llvm/ADT/SmallVector.h" |
| 17 | #include "llvm/ADT/Statistic.h" |
| 18 | #include "llvm/Analysis/JumpInstrTableInfo.h" |
| 19 | #include "llvm/CodeGen/ForwardControlFlowIntegrity.h" |
| 20 | #include "llvm/CodeGen/JumpInstrTables.h" |
| 21 | #include "llvm/CodeGen/Passes.h" |
| 22 | #include "llvm/IR/Attributes.h" |
| 23 | #include "llvm/IR/CallSite.h" |
| 24 | #include "llvm/IR/Constants.h" |
| 25 | #include "llvm/IR/DerivedTypes.h" |
| 26 | #include "llvm/IR/Function.h" |
| 27 | #include "llvm/IR/GlobalValue.h" |
Stephen Hines | 37ed9c1 | 2014-12-01 14:51:49 -0800 | [diff] [blame] | 28 | #include "llvm/IR/IRBuilder.h" |
Stephen Hines | ebe69fe | 2015-03-23 12:10:34 -0700 | [diff] [blame] | 29 | #include "llvm/IR/InlineAsm.h" |
| 30 | #include "llvm/IR/Instructions.h" |
Stephen Hines | 37ed9c1 | 2014-12-01 14:51:49 -0800 | [diff] [blame] | 31 | #include "llvm/IR/LLVMContext.h" |
| 32 | #include "llvm/IR/Module.h" |
| 33 | #include "llvm/IR/Operator.h" |
| 34 | #include "llvm/IR/Type.h" |
| 35 | #include "llvm/IR/Verifier.h" |
| 36 | #include "llvm/Pass.h" |
| 37 | #include "llvm/Support/CommandLine.h" |
| 38 | #include "llvm/Support/Debug.h" |
| 39 | #include "llvm/Support/raw_ostream.h" |
| 40 | |
| 41 | using namespace llvm; |
| 42 | |
// Statistic counter for this pass. It is incremented once per collected call
// site inside updateIndirectCalls().
STATISTIC(NumCFIIndirectCalls,
          "Number of indirect call sites rewritten by the CFI pass");
| 45 | |
// Pass identifier handed to the ModulePass base class by both constructors.
char ForwardControlFlowIntegrity::ID = 0;

// Register the pass under the name "forward-cfi" and declare the two passes it
// depends on; these are the same passes requested in getAnalysisUsage().
// NOTE(review): the two trailing `true` arguments are presumably the
// CFG-only / is-analysis flags of the INITIALIZE_PASS macros -- confirm
// against PassSupport.h.
INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
                      "Control-Flow Integrity", true, true)
INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
                    "Control-Flow Integrity", true, true)
| 53 | |
| 54 | ModulePass *llvm::createForwardControlFlowIntegrityPass() { |
| 55 | return new ForwardControlFlowIntegrity(); |
| 56 | } |
| 57 | |
| 58 | ModulePass *llvm::createForwardControlFlowIntegrityPass( |
| 59 | JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, |
| 60 | StringRef CFIFuncName) { |
| 61 | return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing, |
| 62 | CFIFuncName); |
| 63 | } |
| 64 | |
| 65 | // Checks to see if a given CallSite is making an indirect call, including |
| 66 | // cases where the indirect call is made through a bitcast. |
| 67 | static bool isIndirectCall(CallSite &CS) { |
| 68 | if (CS.getCalledFunction()) |
| 69 | return false; |
| 70 | |
| 71 | // Check the value to see if it is merely a bitcast of a function. In |
| 72 | // this case, it will translate to a direct function call in the resulting |
| 73 | // assembly, so we won't treat it as an indirect call here. |
| 74 | const Value *V = CS.getCalledValue(); |
| 75 | if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { |
| 76 | return !(CE->isCast() && isa<Function>(CE->getOperand(0))); |
| 77 | } |
| 78 | |
| 79 | // Otherwise, since we know it's a call, it must be an indirect call |
| 80 | return true; |
| 81 | } |
| 82 | |
// Name of the default CFI failure function. It is used by addWarningFunction()
// and insertWarning() whenever no custom CFIFuncName has been supplied.
static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
| 84 | |
| 85 | ForwardControlFlowIntegrity::ForwardControlFlowIntegrity() |
| 86 | : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single), |
| 87 | CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") { |
| 88 | initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); |
| 89 | } |
| 90 | |
| 91 | ForwardControlFlowIntegrity::ForwardControlFlowIntegrity( |
| 92 | JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing, |
| 93 | std::string CFIFuncName) |
| 94 | : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType), |
| 95 | CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) { |
| 96 | initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry()); |
| 97 | } |
| 98 | |
| 99 | ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {} |
| 100 | |
/// Declares the passes this pass requires: JumpInstrTableInfo (queried in
/// runOnModule for the table contents and entry alignment) and
/// JumpInstrTables.
void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<JumpInstrTableInfo>();
  AU.addRequired<JumpInstrTables>();
}
| 105 | |
| 106 | void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) { |
| 107 | // To get the indirect calls, we iterate over all functions and iterate over |
| 108 | // the list of basic blocks in each. We extract a total list of indirect calls |
| 109 | // before modifying any of them, since our modifications will modify the list |
| 110 | // of basic blocks. |
| 111 | for (Function &F : M) { |
| 112 | for (BasicBlock &BB : F) { |
| 113 | for (Instruction &I : BB) { |
| 114 | CallSite CS(&I); |
| 115 | if (!(CS && isIndirectCall(CS))) |
| 116 | continue; |
| 117 | |
| 118 | Value *CalledValue = CS.getCalledValue(); |
| 119 | |
| 120 | // Don't rewrite this instruction if the indirect call is actually just |
| 121 | // inline assembly, since our transformation will generate an invalid |
| 122 | // module in that case. |
| 123 | if (isa<InlineAsm>(CalledValue)) |
| 124 | continue; |
| 125 | |
| 126 | IndirectCalls.push_back(&I); |
| 127 | } |
| 128 | } |
| 129 | } |
| 130 | } |
| 131 | |
| 132 | void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M, |
| 133 | CFITables &CFIT) { |
| 134 | Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| 135 | for (Instruction *I : IndirectCalls) { |
| 136 | CallSite CS(I); |
| 137 | Value *CalledValue = CS.getCalledValue(); |
| 138 | |
| 139 | // Get the function type for this call and look it up in the tables. |
| 140 | Type *VTy = CalledValue->getType(); |
| 141 | PointerType *PTy = dyn_cast<PointerType>(VTy); |
| 142 | Type *EltTy = PTy->getElementType(); |
| 143 | FunctionType *FunTy = dyn_cast<FunctionType>(EltTy); |
| 144 | FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy); |
| 145 | ++NumCFIIndirectCalls; |
| 146 | Constant *JumpTableStart = nullptr; |
| 147 | Constant *JumpTableMask = nullptr; |
| 148 | Constant *JumpTableSize = nullptr; |
| 149 | |
| 150 | // Some call sites have function types that don't correspond to any |
| 151 | // address-taken function in the module. This happens when function pointers |
| 152 | // are passed in from external code. |
| 153 | auto it = CFIT.find(TransformedTy); |
| 154 | if (it == CFIT.end()) { |
| 155 | // In this case, make sure that the function pointer will change by |
| 156 | // setting the mask and the start to be 0 so that the transformed |
| 157 | // function is 0. |
| 158 | JumpTableStart = ConstantInt::get(Int64Ty, 0); |
| 159 | JumpTableMask = ConstantInt::get(Int64Ty, 0); |
| 160 | JumpTableSize = ConstantInt::get(Int64Ty, 0); |
| 161 | } else { |
| 162 | JumpTableStart = it->second.StartValue; |
| 163 | JumpTableMask = it->second.MaskValue; |
| 164 | JumpTableSize = it->second.Size; |
| 165 | } |
| 166 | |
| 167 | rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask, |
| 168 | JumpTableSize); |
| 169 | } |
| 170 | |
| 171 | return; |
| 172 | } |
| 173 | |
| 174 | bool ForwardControlFlowIntegrity::runOnModule(Module &M) { |
| 175 | JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>(); |
| 176 | Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| 177 | Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); |
| 178 | |
| 179 | // JumpInstrTableInfo stores information about the alignment of each entry. |
| 180 | // The alignment returned by JumpInstrTableInfo is alignment in bytes, not |
| 181 | // in the exponent. |
| 182 | ByteAlignment = JITI->entryByteAlignment(); |
| 183 | LogByteAlignment = llvm::Log2_64(ByteAlignment); |
| 184 | |
| 185 | // Set up tables for control-flow integrity based on information about the |
| 186 | // jump-instruction tables. |
| 187 | CFITables CFIT; |
| 188 | for (const auto &KV : JITI->getTables()) { |
| 189 | uint64_t Size = static_cast<uint64_t>(KV.second.size()); |
| 190 | uint64_t TableSize = NextPowerOf2(Size); |
| 191 | |
| 192 | int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment; |
| 193 | Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue); |
| 194 | Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size); |
| 195 | |
| 196 | // The base of the table is defined to be the first jumptable function in |
| 197 | // the table. |
| 198 | Function *First = KV.second.begin()->second; |
| 199 | Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy); |
| 200 | CFIT[KV.first].StartValue = JumpTableStartValue; |
| 201 | CFIT[KV.first].MaskValue = JumpTableMaskValue; |
| 202 | CFIT[KV.first].Size = JumpTableSize; |
| 203 | } |
| 204 | |
| 205 | if (CFIT.empty()) |
| 206 | return false; |
| 207 | |
| 208 | getIndirectCalls(M); |
| 209 | |
| 210 | if (!CFIEnforcing) { |
| 211 | addWarningFunction(M); |
| 212 | } |
| 213 | |
| 214 | // Update the instructions with the check and the indirect jump through our |
| 215 | // table. |
| 216 | updateIndirectCalls(M, CFIT); |
| 217 | |
| 218 | return true; |
| 219 | } |
| 220 | |
| 221 | void ForwardControlFlowIntegrity::addWarningFunction(Module &M) { |
| 222 | PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext()); |
| 223 | |
| 224 | // Get the type of the Warning Function: void (i8*, i8*), |
| 225 | // where the first argument is the name of the function in which the violation |
| 226 | // occurs, and the second is the function pointer that violates CFI. |
| 227 | SmallVector<Type *, 2> WarningFunArgs; |
| 228 | WarningFunArgs.push_back(CharPtrTy); |
| 229 | WarningFunArgs.push_back(CharPtrTy); |
| 230 | FunctionType *WarningFunTy = |
| 231 | FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false); |
| 232 | |
| 233 | if (!CFIFuncName.empty()) { |
| 234 | Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy); |
| 235 | if (!FailureFun) |
| 236 | report_fatal_error("Could not get or insert the function specified by" |
| 237 | " -cfi-func-name"); |
| 238 | } else { |
| 239 | // The default warning function swallows the warning and lets the call |
| 240 | // continue, since there's no generic way for it to print out this |
| 241 | // information. |
| 242 | Function *WarningFun = M.getFunction(cfi_failure_func_name); |
| 243 | if (!WarningFun) { |
| 244 | WarningFun = |
| 245 | Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage, |
| 246 | cfi_failure_func_name, &M); |
| 247 | } |
| 248 | |
| 249 | BasicBlock *Entry = |
| 250 | BasicBlock::Create(M.getContext(), "entry", WarningFun, 0); |
| 251 | ReturnInst::Create(M.getContext(), Entry); |
| 252 | } |
| 253 | } |
| 254 | |
| 255 | void ForwardControlFlowIntegrity::rewriteFunctionPointer( |
| 256 | Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart, |
| 257 | Constant *JumpTableMask, Constant *JumpTableSize) { |
| 258 | IRBuilder<> TempBuilder(I); |
| 259 | |
| 260 | Type *OrigFunType = FunPtr->getType(); |
| 261 | |
| 262 | BasicBlock *CurBB = cast<BasicBlock>(I->getParent()); |
| 263 | Function *CurF = cast<Function>(CurBB->getParent()); |
| 264 | Type *Int64Ty = Type::getInt64Ty(M.getContext()); |
| 265 | |
| 266 | Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty); |
| 267 | Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty); |
| 268 | |
| 269 | Value *NewFunPtr = nullptr; |
| 270 | Value *Check = nullptr; |
| 271 | switch (CFIType) { |
| 272 | case CFIntegrity::Sub: { |
| 273 | // This is the subtract, mask, and add version. |
| 274 | // Subtract from the base. |
| 275 | Value *Sub = TempBuilder.CreateSub(TI, TStartInt); |
| 276 | |
| 277 | // Mask the difference to force this to be a table offset. |
| 278 | Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask); |
| 279 | |
| 280 | // Add it back to the base. |
| 281 | Value *Result = TempBuilder.CreateAdd(And, TStartInt); |
| 282 | |
| 283 | // Convert it back into a function pointer that we can call. |
| 284 | NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); |
| 285 | break; |
| 286 | } |
| 287 | case CFIntegrity::Ror: { |
| 288 | // This is the subtract and rotate version. |
| 289 | // Rotate right by the alignment value. The optimizer should recognize |
| 290 | // this sequence as a rotation. |
| 291 | |
| 292 | // This cast is safe, since unsigned is always a subset of uint64_t. |
| 293 | uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment); |
| 294 | Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64); |
| 295 | Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64); |
| 296 | |
| 297 | // Subtract from the base. |
| 298 | Value *Sub = TempBuilder.CreateSub(TI, TStartInt); |
| 299 | |
| 300 | // Create the equivalent of a rotate-right instruction. |
| 301 | Value *Shr = TempBuilder.CreateLShr(Sub, RightShift); |
| 302 | Value *Shl = TempBuilder.CreateShl(Sub, LeftShift); |
| 303 | Value *Or = TempBuilder.CreateOr(Shr, Shl); |
| 304 | |
| 305 | // Perform unsigned comparison to check for inclusion in the table. |
| 306 | Check = TempBuilder.CreateICmpULT(Or, JumpTableSize); |
| 307 | NewFunPtr = FunPtr; |
| 308 | break; |
| 309 | } |
| 310 | case CFIntegrity::Add: { |
| 311 | // This is the mask and add version. |
| 312 | // Mask the function pointer to turn it into an offset into the table. |
| 313 | Value *And = TempBuilder.CreateAnd(TI, JumpTableMask); |
| 314 | |
| 315 | // Then or this offset to the base and get the pointer value. |
| 316 | Value *Result = TempBuilder.CreateAdd(And, TStartInt); |
| 317 | |
| 318 | // Convert it back into a function pointer that we can call. |
| 319 | NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType); |
| 320 | break; |
| 321 | } |
| 322 | } |
| 323 | |
| 324 | if (!CFIEnforcing) { |
| 325 | // If a check hasn't been added (in the rotation version), then check to see |
| 326 | // if it's the same as the original function. This check determines whether |
| 327 | // or not we call the CFI failure function. |
| 328 | if (!Check) |
| 329 | Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr); |
| 330 | BasicBlock *InvalidPtrBlock = |
| 331 | BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0); |
| 332 | BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I); |
| 333 | |
| 334 | // Remove the unconditional branch that connects the two blocks. |
| 335 | TerminatorInst *TermInst = CurBB->getTerminator(); |
| 336 | TermInst->eraseFromParent(); |
| 337 | |
| 338 | // Add a conditional branch that depends on the Check above. |
| 339 | BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB); |
| 340 | |
| 341 | // Call the warning function for this pointer, then continue. |
| 342 | Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock); |
| 343 | insertWarning(M, InvalidPtrBlock, BI, FunPtr); |
| 344 | } else { |
| 345 | // Modify the instruction to call this value. |
| 346 | CallSite CS(I); |
| 347 | CS.setCalledFunction(NewFunPtr); |
| 348 | } |
| 349 | } |
| 350 | |
| 351 | void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block, |
| 352 | Instruction *I, Value *FunPtr) { |
| 353 | Function *ParentFun = cast<Function>(Block->getParent()); |
| 354 | |
| 355 | // Get the function to call right before the instruction. |
| 356 | Function *WarningFun = nullptr; |
| 357 | if (CFIFuncName.empty()) { |
| 358 | WarningFun = M.getFunction(cfi_failure_func_name); |
| 359 | } else { |
| 360 | WarningFun = M.getFunction(CFIFuncName); |
| 361 | } |
| 362 | |
| 363 | assert(WarningFun && "Could not find the CFI failure function"); |
| 364 | |
| 365 | Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext()); |
| 366 | |
| 367 | IRBuilder<> WarningInserter(I); |
| 368 | // Create a mergeable GlobalVariable containing the name of the function. |
| 369 | Value *ParentNameGV = |
| 370 | WarningInserter.CreateGlobalString(ParentFun->getName()); |
| 371 | Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy); |
| 372 | Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy); |
| 373 | WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr); |
| 374 | } |