SimplifyLibCalls: Optimize wcslen

Refactor the strlen optimization code to work for both strlen and wcslen.

This especially helps with programs in the wild where people pass
L"string"s to const std::wstring& function parameters and the wstring
constructor gets inlined.

This also fixes a lingering API problem/bug in getConstantStringInfo()
where zeroinitializers would always give you an empty string (without a
length) back regardless of the actual length of the initializer. This
did not work well in the TrimAtNul==false case, causing the PR mentioned
below.

Note that the fixed getConstantStringInfo() needed fixes to SelectionDAG
memcpy lowering and may lead to some cases of out-of-bounds
zeroinitializer accesses not getting optimized anymore. So some code
with UB may now produce out-of-bounds memory reads instead of just
producing zeros.

The refactoring "accidentally" fixes http://llvm.org/PR32124

Differential Revision: https://reviews.llvm.org/D32839

llvm-svn: 303461
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index 1de579e..85c9464 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -426,57 +426,70 @@
   return Dst;
 }
 
-Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
+Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B,
+                                               unsigned CharSize) {
   Value *Src = CI->getArgOperand(0);
 
   // Constant folding: strlen("xyz") -> 3
-  if (uint64_t Len = GetStringLength(Src))
+  if (uint64_t Len = GetStringLength(Src, CharSize))
     return ConstantInt::get(CI->getType(), Len - 1);
 
   // If s is a constant pointer pointing to a string literal, we can fold
-  // strlen(s + x) to strlen(s) - x, when x is known to be in the range 
+  // strlen(s + x) to strlen(s) - x, when x is known to be in the range
   // [0, strlen(s)] or the string has a single null terminator '\0' at the end.
-  // We only try to simplify strlen when the pointer s points to an array 
+  // We only try to simplify strlen when the pointer s points to an array
   // of i8. Otherwise, we would need to scale the offset x before doing the
-  // subtraction. This will make the optimization more complex, and it's not 
-  // very useful because calling strlen for a pointer of other types is 
+  // subtraction. This will make the optimization more complex, and it's not
+  // very useful because calling strlen for a pointer of other types is
   // very uncommon.
   if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
-    if (!isGEPBasedOnPointerToString(GEP))
+    if (!isGEPBasedOnPointerToString(GEP, CharSize))
       return nullptr;
 
-    StringRef Str;
-    if (getConstantStringInfo(GEP->getOperand(0), Str, 0, false)) {
-      size_t NullTermIdx = Str.find('\0');
-      
-      // If the string does not have '\0', leave it to strlen to compute
-      // its length.
-      if (NullTermIdx == StringRef::npos)
-        return nullptr;
-     
+    ConstantDataArraySlice Slice;
+    if (getConstantDataArrayInfo(GEP->getOperand(0), Slice, CharSize)) {
+      uint64_t NullTermIdx;
+      if (Slice.Array == nullptr) {
+        NullTermIdx = 0;
+      } else {
+        NullTermIdx = ~((uint64_t)0);
+        for (uint64_t I = 0, E = Slice.Length; I < E; ++I) {
+          if (Slice.Array->getElementAsInteger(I + Slice.Offset) == 0) {
+            NullTermIdx = I;
+            break;
+          }
+        }
+        // If the string does not have '\0', leave it to strlen to compute
+        // its length.
+        if (NullTermIdx == ~((uint64_t)0))
+          return nullptr;
+      }
+
       Value *Offset = GEP->getOperand(2);
       unsigned BitWidth = Offset->getType()->getIntegerBitWidth();
       KnownBits Known(BitWidth);
       computeKnownBits(Offset, Known, DL, 0, nullptr, CI, nullptr);
       Known.Zero.flipAllBits();
-      size_t ArrSize = 
+      uint64_t ArrSize =
              cast<ArrayType>(GEP->getSourceElementType())->getNumElements();
 
-      // KnownZero's bits are flipped, so zeros in KnownZero now represent 
-      // bits known to be zeros in Offset, and ones in KnowZero represent 
+      // KnownZero's bits are flipped, so zeros in KnownZero now represent
+      // bits known to be zeros in Offset, and ones in KnowZero represent
       // bits unknown in Offset. Therefore, Offset is known to be in range
-      // [0, NullTermIdx] when the flipped KnownZero is non-negative and 
+      // [0, NullTermIdx] when the flipped KnownZero is non-negative and
       // unsigned-less-than NullTermIdx.
       //
-      // If Offset is not provably in the range [0, NullTermIdx], we can still 
-      // optimize if we can prove that the program has undefined behavior when 
-      // Offset is outside that range. That is the case when GEP->getOperand(0) 
+      // If Offset is not provably in the range [0, NullTermIdx], we can still
+      // optimize if we can prove that the program has undefined behavior when
+      // Offset is outside that range. That is the case when GEP->getOperand(0)
       // is a pointer to an object whose memory extent is NullTermIdx+1.
-      if ((Known.Zero.isNonNegative() && Known.Zero.ule(NullTermIdx)) || 
+      if ((Known.Zero.isNonNegative() && Known.Zero.ule(NullTermIdx)) ||
           (GEP->isInBounds() && isa<GlobalVariable>(GEP->getOperand(0)) &&
-           NullTermIdx == ArrSize - 1))
-        return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx), 
+           NullTermIdx == ArrSize - 1)) {
+        Offset = B.CreateSExtOrTrunc(Offset, CI->getType());
+        return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx),
                            Offset);
+      }
     }
 
     return nullptr;
@@ -484,8 +497,8 @@
 
   // strlen(x?"foo":"bars") --> x ? 3 : 4
   if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
-    uint64_t LenTrue = GetStringLength(SI->getTrueValue());
-    uint64_t LenFalse = GetStringLength(SI->getFalseValue());
+    uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize);
+    uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize);
     if (LenTrue && LenFalse) {
       Function *Caller = CI->getParent()->getParent();
       emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller,
@@ -505,6 +518,17 @@
   return nullptr;
 }
 
+Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) {
+  return optimizeStringLength(CI, B, 8);
+}
+
+Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) {
+  Module &M = *CI->getParent()->getParent()->getParent();
+  unsigned WCharSize = TLI->getWCharSize(M) * 8;
+
+  return optimizeStringLength(CI, B, WCharSize);
+}
+
 Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilder<> &B) {
   StringRef S1, S2;
   bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1);
@@ -2026,6 +2050,8 @@
       return optimizeMemMove(CI, Builder);
     case LibFunc_memset:
       return optimizeMemSet(CI, Builder);
+    case LibFunc_wcslen:
+      return optimizeWcslen(CI, Builder);
     default:
       break;
     }