Fixes and cleanups for the strncmp and strcpy optimizers
llvm-svn: 35709
diff --git a/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp b/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
index 0f19740..155fafd 100644
--- a/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/IPO/SimplifyLibCalls.cpp
@@ -705,7 +705,7 @@
     Value *Str1P = CI->getOperand(1);
     Value *Str2P = CI->getOperand(2);
     if (Str1P == Str2P) {
-      // strcmp(x,x)  -> 0
+      // strncmp(x,x,n) -> 0 for any length n
       CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
       CI->eraseFromParent();
       return true;
@@ -713,8 +713,13 @@
     
     // Check the length argument, if it is Constant zero then the strings are
     // considered equal.
-    ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (LengthArg && LengthArg->isZero()) {
+    uint64_t Length;
+    if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(CI->getOperand(3)))
+      Length = LengthArg->getZExtValue();
+    else
+      return false;  // Unknown length: every fold below depends on knowing n.
+    
+    if (Length == 0) {
       // strncmp(x,y,0)   -> 0
       CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
       CI->eraseFromParent();
@@ -725,7 +730,8 @@
     ConstantArray *A1;
     bool Str1IsCst = GetConstantStringInfo(Str1P, A1, Str1Len, Str1StartIdx);
     if (Str1IsCst && Str1Len == 0) {
-      // strcmp("", x) -> *x
+      // strncmp("", x, n) -> -*x  (n is a non-zero constant here; the result is '\0' - x[0])
       Value *V = new LoadInst(Str2P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
+      V = BinaryOperator::createNeg(V, CI->getName()+".neg", CI);
       CI->replaceAllUsesWith(V);
@@ -737,7 +742,7 @@
     ConstantArray* A2;
     bool Str2IsCst = GetConstantStringInfo(Str2P, A2, Str2Len, Str2StartIdx);
     if (Str2IsCst && Str2Len == 0) {
-      // strcmp(x,"") -> *x
+      // strncmp(x,"",n) -> *x  (n is a non-zero constant at this point)
       Value *V = new LoadInst(Str1P, CI->getName()+".load", CI);
       V = new ZExtInst(V, CI->getType(), CI->getName()+".int", CI);
       CI->replaceAllUsesWith(V);
@@ -745,13 +750,12 @@
       return true;
     }
     
-    if (LengthArg && Str1IsCst && Str2IsCst && A1->isCString() &&
+    if (Str1IsCst && Str2IsCst && A1->isCString() &&
         A2->isCString()) {
-      // strcmp(x, y)  -> cnst  (if both x and y are constant strings)
+      // strncmp(x, y, n) -> cnst  (x and y constant C strings, n constant)
       std::string S1 = A1->getAsString();
       std::string S2 = A2->getAsString();
-      int R = strncmp(S1.c_str()+Str1StartIdx, S2.c_str()+Str2StartIdx,
-                      LengthArg->getZExtValue());
+      int R = strncmp(S1.c_str()+Str1StartIdx, S2.c_str()+Str2StartIdx, Length);
       CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), R));
       CI->eraseFromParent();
       return true;
@@ -771,71 +775,57 @@
       "Number of 'strcpy' calls simplified") {}
 
   /// @brief Make sure that the "strcpy" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC){
-    if (f->getReturnType() == PointerType::get(Type::Int8Ty))
-      if (f->arg_size() == 2) {
-        Function::const_arg_iterator AI = f->arg_begin();
-        if (AI++->getType() == PointerType::get(Type::Int8Ty))
-          if (AI->getType() == PointerType::get(Type::Int8Ty)) {
-            // Indicate this is a suitable call type.
-            return true;
-          }
-      }
-    return false;
+  virtual bool ValidateCalledFunction(const Function *F, SimplifyLibCalls &SLC){
+    const FunctionType *FT = F->getFunctionType();
+    return FT->getNumParams() == 2 &&
+           FT->getParamType(0) == FT->getParamType(1) &&
+           FT->getReturnType() == FT->getParamType(0) &&
+           FT->getParamType(0) == PointerType::get(Type::Int8Ty);
   }
 
   /// @brief Perform the strcpy optimization
-  virtual bool OptimizeCall(CallInst* ci, SimplifyLibCalls& SLC) {
+  virtual bool OptimizeCall(CallInst *CI, SimplifyLibCalls &SLC) {
     // First, check to see if src and destination are the same. If they are,
     // then the optimization is to replace the CallInst with the destination
     // because the call is a no-op. Note that this corresponds to the
     // degenerate strcpy(X,X) case which should have "undefined" results
     // according to the C specification. However, it occurs sometimes and
     // we optimize it as a no-op.
-    Value* dest = ci->getOperand(1);
-    Value* src = ci->getOperand(2);
-    if (dest == src) {
-      ci->replaceAllUsesWith(dest);
-      ci->eraseFromParent();
+    Value *Dst = CI->getOperand(1);
+    Value *Src = CI->getOperand(2);
+    if (Dst == Src) {
+      // strcpy(x, x) -> x
+      CI->replaceAllUsesWith(Dst);
+      CI->eraseFromParent();
       return true;
     }
-
-    // Get the length of the constant string referenced by the second operand,
-    // the "src" parameter. Fail the optimization if we can't get the length
-    // (note that GetConstantStringInfo does lots of checks to make sure this
-    // is valid).
-    uint64_t len, StartIdx;
-    ConstantArray *A;
-    if (!GetConstantStringInfo(ci->getOperand(2), A, len, StartIdx))
+
+    // Get the length of the constant string referenced by the Src operand.
+    uint64_t SrcLen, SrcStartIdx;
+    ConstantArray *SrcArr;
+    if (!GetConstantStringInfo(Src, SrcArr, SrcLen, SrcStartIdx))
       return false;
 
     // If the constant string's length is zero we can optimize this by just
     // doing a store of 0 at the first byte of the destination
-    if (len == 0) {
-      new StoreInst(ConstantInt::get(Type::Int8Ty,0),ci->getOperand(1),ci);
-      ci->replaceAllUsesWith(dest);
-      ci->eraseFromParent();
+    if (SrcLen == 0) {
+      new StoreInst(ConstantInt::get(Type::Int8Ty, 0), Dst, CI);
+      CI->replaceAllUsesWith(Dst);
+      CI->eraseFromParent();
       return true;
     }
 
-    // Increment the length because we actually want to memcpy the null
-    // terminator as well.
-    len++;
-
-    // We have enough information to now generate the memcpy call to
-    // do the concatenation for us.
-    Value *vals[4] = {
-      dest, src,
-      ConstantInt::get(SLC.getIntPtrType(),len), // length
+    // Generate a memcpy of the source string, including its nul terminator
+    // (SrcLen does not count the nul, hence the +1 below).
+    Value *MemcpyOps[] = {
+      Dst, Src,
+      ConstantInt::get(SLC.getIntPtrType(), SrcLen+1), // length including nul.
       ConstantInt::get(Type::Int32Ty, 1) // alignment
     };
-    new CallInst(SLC.get_memcpy(), vals, 4, "", ci);
+    new CallInst(SLC.get_memcpy(), MemcpyOps, 4, "", CI);
 
-    // Finally, substitute the first operand of the strcat call for the
-    // strcat call itself since strcat returns its first operand; and,
-    // kill the strcat CallInst.
-    ci->replaceAllUsesWith(dest);
-    ci->eraseFromParent();
+    CI->replaceAllUsesWith(Dst);  // strcpy returns its destination.
+    CI->eraseFromParent();
     return true;
   }
 } StrCpyOptimizer;
@@ -849,8 +839,7 @@
       "Number of 'strlen' calls simplified") {}
 
   /// @brief Make sure that the "strlen" function has the right prototype
-  virtual bool ValidateCalledFunction(const Function* f, SimplifyLibCalls& SLC)
-  {
+  virtual bool ValidateCalledFunction(const Function *f, SimplifyLibCalls &SLC){
     if (f->getReturnType() == SLC.getTargetData()->getIntPtrType())
       if (f->arg_size() == 1)
         if (Function::const_arg_iterator AI = f->arg_begin())