[SimplifyLibCalls] Add dereferenceable bytes from known callsites

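Summary:
When a memchr, memcmp, memcpy, memmove or memset call (including the
llvm.memcpy/llvm.memmove/llvm.memset intrinsics) has a constant length N,
mark the pointer arguments of that call site as dereferenceable(N).

For example, given: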
int mm(char *a, char *b) {
    return memcmp(a,b,16);
}

Currently:
define dso_local i32 @mm(i8* nocapture readonly %a, i8* nocapture readonly %b) local_unnamed_addr #1 {
entry:
  %call = tail call i32 @memcmp(i8* %a, i8* %b, i64 16)
  ret i32 %call
}

After patch:
define dso_local i32 @mm(i8* nocapture readonly %a, i8* nocapture readonly %b) local_unnamed_addr #1 {
entry:
  %call = tail call i32 @memcmp(i8* dereferenceable(16) %a, i8* dereferenceable(16) %b, i64 16)
  ret i32 %call
}




Reviewers: jdoerfert, efriedma

Reviewed By: jdoerfert

Subscribers: javed.absar, spatel, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D66079

llvm-svn: 368657
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 32b845c..396690e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -186,6 +186,31 @@
   return true;
 }
 
+/// Mark every argument of \p CI listed in \p ArgNos as
+/// dereferenceable(\p DerefBytes).
+///
+/// An argument whose dereferenceable attribute already covers at least
+/// \p DerefBytes bytes is left untouched.
+static void annotateDereferenceableBytes(CallInst *CI,
+                                         ArrayRef<unsigned> ArgNos,
+                                         uint64_t DerefBytes) {
+  for (unsigned ArgNo : ArgNos) {
+    // The attribute queries are indexed from AttributeList::FirstArgIndex,
+    // whereas the *ParamAttr helpers take the plain argument number.
+    const unsigned AttrIdx = ArgNo + AttributeList::FirstArgIndex;
+    if (CI->getDereferenceableBytes(AttrIdx) >= DerefBytes)
+      continue;
+
+    CI->removeParamAttr(ArgNo, Attribute::Dereferenceable);
+    // Drop dereferenceable_or_null only when the new attribute subsumes it; a
+    // larger existing bound is information that must not be thrown away.
+    if (CI->getDereferenceableOrNullBytes(AttrIdx) <= DerefBytes)
+      CI->removeParamAttr(ArgNo, Attribute::DereferenceableOrNull);
+    CI->addParamAttr(ArgNo, Attribute::getWithDereferenceableBytes(
+                                CI->getContext(), DerefBytes));
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // String and Memory Library Call Optimizations
 //===----------------------------------------------------------------------===//
@@ -765,9 +779,11 @@
   ConstantInt *LenC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
 
   // memchr(x, y, 0) -> null
-  if (LenC && LenC->isZero())
-    return Constant::getNullValue(CI->getType());
-
+  if (LenC) {
+    if (LenC->isZero())
+      return Constant::getNullValue(CI->getType());
+    annotateDereferenceableBytes(CI, {0}, LenC->getZExtValue());
+  }
   // From now on we need at least constant length and string.
   StringRef Str;
   if (!LenC || !getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false))
@@ -926,10 +942,12 @@
     return Constant::getNullValue(CI->getType());
 
   // Handle constant lengths.
-  if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size))
+  if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) {
     if (Value *Res = optimizeMemCmpConstantSize(CI, LHS, RHS,
                                                 LenC->getZExtValue(), B, DL))
       return Res;
+    annotateDereferenceableBytes(CI, {0, 1}, LenC->getZExtValue());
+  }
 
   return nullptr;
 }
@@ -955,18 +973,38 @@
   return optimizeMemCmpBCmpCommon(CI, B);
 }
 
-Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B) {
+Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilder<> &B,
+                                         bool isIntrinsic) {
+  // With a constant length, both pointers can be marked dereferenceable for
+  // that many bytes.
+  Value *Size = CI->getArgOperand(2);
+  if (auto *LenC = dyn_cast<ConstantInt>(Size))
+    annotateDereferenceableBytes(CI, {0, 1}, LenC->getZExtValue());
+
+  // A call that already is the intrinsic is annotated only, not rewritten.
+  if (isIntrinsic)
+    return nullptr;
+
   // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n)
-  B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
-                 CI->getArgOperand(2));
+  B.CreateMemCpy(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size);
   return CI->getArgOperand(0);
 }
 
-Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B) {
+Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilder<> &B,
+                                          bool isIntrinsic) {
+  // With a constant length, both pointers can be marked dereferenceable for
+  // that many bytes.
+  Value *Size = CI->getArgOperand(2);
+  if (auto *LenC = dyn_cast<ConstantInt>(Size))
+    annotateDereferenceableBytes(CI, {0, 1}, LenC->getZExtValue());
+
+  // A call that already is the intrinsic is annotated only, not rewritten.
+  if (isIntrinsic)
+    return nullptr;
+
   // memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n)
-  B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1,
-                  CI->getArgOperand(2));
+  B.CreateMemMove(CI->getArgOperand(0), 1, CI->getArgOperand(1), 1, Size);
   return CI->getArgOperand(0);
 }
 
 /// Fold memset[_chk](malloc(n), 0, n) --> calloc(1, n).
@@ -1015,13 +1046,24 @@
   return Calloc;
 }
 
-Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B) {
+Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilder<> &B,
+                                         bool isIntrinsic) {
+  // With a constant length, the destination can be marked dereferenceable for
+  // that many bytes.
+  Value *Size = CI->getArgOperand(2);
+  if (auto *LenC = dyn_cast<ConstantInt>(Size))
+    annotateDereferenceableBytes(CI, {0}, LenC->getZExtValue());
+
+  // A call that already is the intrinsic is annotated only, not rewritten.
+  if (isIntrinsic)
+    return nullptr;
+
   if (auto *Calloc = foldMallocMemset(CI, B))
     return Calloc;
 
   // memset(p, v, n) -> llvm.memset(align 1 p, v, n)
   Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
-  B.CreateMemSet(CI->getArgOperand(0), Val, CI->getArgOperand(2), 1);
+  B.CreateMemSet(CI->getArgOperand(0), Val, Size, 1);
   return CI->getArgOperand(0);
 }
 
@@ -2710,6 +2749,12 @@
     case Intrinsic::sqrt:
       return optimizeSqrt(CI, Builder);
     // TODO: Use foldMallocMemset() with memset intrinsic.
+    case Intrinsic::memset:
+      return optimizeMemSet(CI, Builder, true);
+    case Intrinsic::memcpy:
+      return optimizeMemCpy(CI, Builder, true);
+    case Intrinsic::memmove:
+      return optimizeMemMove(CI, Builder, true);
     default:
       return nullptr;
     }