[HWASAN] Add support for memory intrinsics

Differential revision: https://reviews.llvm.org/D55117

llvm-svn: 349728
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 042e8ea..22c9163 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -152,6 +152,10 @@
                               cl::desc("create static frame descriptions"),
                               cl::Hidden, cl::init(true));
 
+static cl::opt<bool>
+    ClInstrumentMemIntrinsics("hwasan-instrument-mem-intrinsics",
+                              cl::desc("instrument memory intrinsics"),
+                              cl::Hidden, cl::init(false)); // Off by default.
 namespace {
 
 /// An instrumentation pass implementing detection of addressability bugs
@@ -182,6 +186,7 @@
   void instrumentMemAccessInline(Value *PtrLong, bool IsWrite,
                                  unsigned AccessSizeIndex,
                                  Instruction *InsertBefore);
+  void instrumentMemIntrinsic(MemIntrinsic *MI);
   bool instrumentMemAccess(Instruction *I);
   Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite,
                                    uint64_t *TypeSize, unsigned *Alignment,
@@ -206,6 +211,7 @@
   LLVMContext *C;
   std::string CurModuleUniqueId;
   Triple TargetTriple;
+  Function *HWAsanMemmove, *HWAsanMemcpy, *HWAsanMemset;
 
   // Frame description is a way to pass names/sizes of local variables
   // to the run-time w/o adding extra executable code in every function.
@@ -373,6 +379,18 @@
   if (Mapping.InGlobal)
     ShadowGlobal = M.getOrInsertGlobal("__hwasan_shadow",
                                        ArrayType::get(IRB.getInt8Ty(), 0));
+
+  const std::string MemIntrinCallbackPrefix =
+      CompileKernel ? std::string("") : ClMemoryAccessCallbackPrefix;
+  HWAsanMemmove = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memmove", IRB.getInt8PtrTy(),
+      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));
+  HWAsanMemcpy = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memcpy", IRB.getInt8PtrTy(),
+      IRB.getInt8PtrTy(), IRB.getInt8PtrTy(), IntptrTy));
+  HWAsanMemset = checkSanitizerInterfaceFunction(M.getOrInsertFunction(
+      MemIntrinCallbackPrefix + "memset", IRB.getInt8PtrTy(),
+      IRB.getInt8PtrTy(), IRB.getInt32Ty(), IntptrTy));
 }
 
 Value *HWAddressSanitizer::getDynamicShadowNonTls(IRBuilder<> &IRB) {
@@ -548,12 +566,36 @@
   IRB.CreateCall(Asm, PtrLong);
 }
 
+/// Swap \p MI for a call to HWAsanMemmove, HWAsanMemcpy or HWAsanMemset.
+void HWAddressSanitizer::instrumentMemIntrinsic(MemIntrinsic *MI) {
+  IRBuilder<> IRB(MI); // New instructions are inserted right before MI.
+  if (isa<MemTransferInst>(MI)) {
+    Function *Callee = isa<MemMoveInst>(MI) ? HWAsanMemmove : HWAsanMemcpy;
+    Value *Dst = IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy());
+    Value *Src = IRB.CreatePointerCast(MI->getOperand(1), IRB.getInt8PtrTy());
+    Value *Len = IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false);
+    IRB.CreateCall(Callee, {Dst, Src, Len});
+  } else if (isa<MemSetInst>(MI)) {
+    Value *Dst = IRB.CreatePointerCast(MI->getOperand(0), IRB.getInt8PtrTy());
+    Value *Val = IRB.CreateIntCast(MI->getOperand(1), IRB.getInt32Ty(), false);
+    Value *Len = IRB.CreateIntCast(MI->getOperand(2), IntptrTy, false);
+    IRB.CreateCall(HWAsanMemset, {Dst, Val, Len});
+  }
+  MI->eraseFromParent(); // The intrinsic itself is dropped from the function.
+}
+
 bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) {
   LLVM_DEBUG(dbgs() << "Instrumenting: " << *I << "\n");
   bool IsWrite = false;
   unsigned Alignment = 0;
   uint64_t TypeSize = 0;
   Value *MaybeMask = nullptr;
+
+  if (ClInstrumentMemIntrinsics && isa<MemIntrinsic>(I)) {
+    instrumentMemIntrinsic(cast<MemIntrinsic>(I));
+    return true;
+  }
+
   Value *Addr =
       isInterestingMemoryAccess(I, &IsWrite, &TypeSize, &Alignment, &MaybeMask);