Teach SimplifyLibCalls to transform __strcpy_chk into __memcpy_chk to enable downstream optimizations.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@99282 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/SimplifyLibCalls.cpp b/lib/Transforms/Scalar/SimplifyLibCalls.cpp
index 22f3628..058cd3c 100644
--- a/lib/Transforms/Scalar/SimplifyLibCalls.cpp
+++ b/lib/Transforms/Scalar/SimplifyLibCalls.cpp
@@ -49,10 +49,13 @@
   Function *Caller;
   const TargetData *TD;
   LLVMContext* Context;
+  bool OptChkCall;  // True when handling a *_chk variant (e.g. __strcpy_chk).
 public:
-  LibCallOptimization() { }
+  LibCallOptimization() : OptChkCall(false) { }
   virtual ~LibCallOptimization() {}
 
+  void setOptChkCall(bool c) { OptChkCall = c; }  // Toggle *_chk handling.
+
   /// CallOptimizer - This pure virtual method is implemented by base classes to
   /// do various optimizations.  If this returns null then no transformation was
   /// performed.  If it returns CI, then it transformed the call and CI is to be
@@ -352,8 +355,10 @@
 struct StrCpyOpt : public LibCallOptimization {
   virtual Value *CallOptimizer(Function *Callee, CallInst *CI, IRBuilder<> &B) {
     // Verify the "strcpy" function prototype.
+    unsigned NumParams = OptChkCall ? 3 : 2;  // *_chk adds an objsize arg.
     const FunctionType *FT = Callee->getFunctionType();
-    if (FT->getNumParams() != 2 || FT->getReturnType() != FT->getParamType(0) ||
+    if (FT->getNumParams() != NumParams ||
+        FT->getReturnType() != FT->getParamType(0) ||
         FT->getParamType(0) != FT->getParamType(1) ||
         FT->getParamType(0) != Type::getInt8PtrTy(*Context))
       return 0;
@@ -371,8 +376,13 @@
 
     // We have enough information to now generate the memcpy call to do the
     // concatenation for us.  Make a memcpy to copy the nul byte with align = 1.
-    EmitMemCpy(Dst, Src,
-               ConstantInt::get(TD->getIntPtrType(*Context), Len), 1, B, TD);
+    Value *LenV = ConstantInt::get(TD->getIntPtrType(*Context), Len);
+    if (!OptChkCall)
+      EmitMemCpy(Dst, Src, LenV, 1, B, TD);
+    else if (CI->getOperand(3)->getType() == LenV->getType())
+      EmitMemCpyChk(Dst, Src, LenV, CI->getOperand(3), B, TD);
+    else
+      return 0;  // ObjSize must be intptr_t to match EmitMemCpyChk's decl.
     return Dst;
   }
 };
@@ -1162,7 +1172,8 @@
     StringMap<LibCallOptimization*> Optimizations;
     // String and Memory LibCall Optimizations
     StrCatOpt StrCat; StrNCatOpt StrNCat; StrChrOpt StrChr; StrCmpOpt StrCmp;
     StrNCmpOpt StrNCmp; StrCpyOpt StrCpy; StrNCpyOpt StrNCpy; StrLenOpt StrLen;
+    StrCpyOpt StrCpyChk;  // Same optimizer, run in *_chk mode (__strcpy_chk).
     StrToOpt StrTo; StrStrOpt StrStr;
     MemCmpOpt MemCmp; MemCpyOpt MemCpy; MemMoveOpt MemMove; MemSetOpt MemSet;
     // Math Library Optimizations
@@ -1228,6 +1239,10 @@
   Optimizations["memmove"] = &MemMove;
   Optimizations["memset"] = &MemSet;
 
+  // *_chk variants: the same optimizers, flagged to expect an objsize arg.
+  StrCpyChk.setOptChkCall(true);
+  Optimizations["__strcpy_chk"] = &StrCpyChk;
+
   // Math Library Optimizations
   Optimizations["powf"] = &Pow;
   Optimizations["pow"] = &Pow;
diff --git a/lib/Transforms/Utils/BuildLibCalls.cpp b/lib/Transforms/Utils/BuildLibCalls.cpp
index 2e6edfa..0afccf4 100644
--- a/lib/Transforms/Utils/BuildLibCalls.cpp
+++ b/lib/Transforms/Utils/BuildLibCalls.cpp
@@ -108,7 +108,7 @@
 
 
 /// EmitMemCpy - Emit a call to the memcpy function to the builder.  This always
-/// expects that the size has type 'intptr_t' and Dst/Src are pointers.
+/// expects that Len has type 'intptr_t' and that Dst/Src are pointers.
 Value *llvm::EmitMemCpy(Value *Dst, Value *Src, Value *Len,
                         unsigned Align, IRBuilder<> &B, const TargetData *TD) {
   Module *M = B.GetInsertBlock()->getParent()->getParent();
@@ -120,6 +120,30 @@
                        ConstantInt::get(B.getInt32Ty(), Align));
 }
 
+/// EmitMemCpyChk - Emit a call to the __memcpy_chk function to the builder.
+/// Len and ObjSize must already have type 'intptr_t'; Dst and Src must be
+/// pointers and are cast to i8* before the call is built.
+Value *llvm::EmitMemCpyChk(Value *Dst, Value *Src, Value *Len, Value *ObjSize,
+                           IRBuilder<> &B, const TargetData *TD) {
+  Module *M = B.GetInsertBlock()->getParent()->getParent();
+  LLVMContext &Context = B.GetInsertBlock()->getContext();
+  const Type *IntPtrTy = TD->getIntPtrType(Context);
+  AttributeWithIndex AWI = AttributeWithIndex::get(~0u, Attribute::NoUnwind);
+  // Declare (or reuse) i8* __memcpy_chk(i8*, i8*, intptr_t, intptr_t).
+  Value *MemCpyChk =
+      M->getOrInsertFunction("__memcpy_chk", AttrListPtr::get(&AWI, 1),
+                             B.getInt8PtrTy(), B.getInt8PtrTy(),
+                             B.getInt8PtrTy(), IntPtrTy, IntPtrTy, NULL);
+  Value *DstPtr = CastToCStr(Dst, B);
+  Value *SrcPtr = CastToCStr(Src, B);
+  CallInst *CI = B.CreateCall4(MemCpyChk, DstPtr, SrcPtr, Len, ObjSize);
+  // Use the calling convention of the underlying __memcpy_chk declaration,
+  // looking through any cast that getOrInsertFunction may have returned.
+  if (const Function *F = dyn_cast<Function>(MemCpyChk->stripPointerCasts()))
+    CI->setCallingConv(F->getCallingConv());
+  return CI;
+}
+
 /// EmitMemMove - Emit a call to the memmove function to the builder.  This
 /// always expects that the size has type 'intptr_t' and Dst/Src are pointers.
 Value *llvm::EmitMemMove(Value *Dst, Value *Src, Value *Len,