SimplifyLibCalls: Optimize wcslen

Refactor the strlen optimization code to work for both strlen and wcslen.

This especially helps with programs in the wild where people pass
L"string"s to const std::wstring& function parameters and the wstring
constructor gets inlined.

This also fixes a lingering API problem/bug in getConstantStringInfo()
where zeroinitializers would always give you an empty string (without a
length) back regardless of the actual length of the initializer, which
did not work well in the TrimAtNul==false case, causing the PR mentioned
below.

Note that the fixed getConstantStringInfo() needed fixes to SelectionDAG
memcpy lowering and may lead to some cases of out-of-bounds
zeroinitializer accesses not getting optimized anymore. So some code
with UB may now produce out-of-bounds memory reads instead of just
producing zeros.

The refactoring "accidentally" fixes http://llvm.org/PR32124

Differential Revision: https://reviews.llvm.org/D32839

llvm-svn: 303461
diff --git a/llvm/lib/Analysis/TargetLibraryInfo.cpp b/llvm/lib/Analysis/TargetLibraryInfo.cpp
index 3cf1bbc..7a07ab7 100644
--- a/llvm/lib/Analysis/TargetLibraryInfo.cpp
+++ b/llvm/lib/Analysis/TargetLibraryInfo.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/ADT/Triple.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/Support/CommandLine.h"
 using namespace llvm;
 
@@ -1518,6 +1519,17 @@
   return *Impl;
 }
 
+unsigned TargetLibraryInfoImpl::getTargetWCharSize(const Triple &T) {
+  // See also clang/lib/Basic/Targets.cpp.
+  return T.isPS4() || T.isOSWindows() || T.getArch() == Triple::xcore ? 2 : 4;
+}
+
+unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const {
+  if (auto *ShortWChar = cast_or_null<ConstantAsMetadata>(
+      M.getModuleFlag("wchar_size")))
+    return cast<ConstantInt>(ShortWChar->getValue())->getZExtValue();
+  return getTargetWCharSize(Triple(M.getTargetTriple()));
+}
 
 TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
     : ImmutablePass(ID), TLIImpl(), TLI(TLIImpl) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index cba7363..8e6c109 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -26,6 +26,7 @@
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/GlobalAlias.h"
@@ -2953,14 +2954,16 @@
   return Ptr;
 }
 
-bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP) {
+bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP,
+                                       unsigned CharSize) {
   // Make sure the GEP has exactly three arguments.
   if (GEP->getNumOperands() != 3)
     return false;
 
-  // Make sure the index-ee is a pointer to array of i8.
+  // Make sure the index-ee is a pointer to an array of integers that are
+  // \p CharSize bits wide.
   ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType());
-  if (!AT || !AT->getElementType()->isIntegerTy(8))
+  if (!AT || !AT->getElementType()->isIntegerTy(CharSize))
     return false;
 
   // Check to make sure that the first operand of the GEP is an integer and
@@ -2972,11 +2975,9 @@
   return true;
 }
 
-/// This function computes the length of a null-terminated C string pointed to
-/// by V. If successful, it returns true and returns the string in Str.
-/// If unsuccessful, it returns false.
-bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
-                                 uint64_t Offset, bool TrimAtNul) {
+bool llvm::getConstantDataArrayInfo(const Value *V,
+                                    ConstantDataArraySlice &Slice,
+                                    unsigned ElementSize, uint64_t Offset) {
   assert(V);
 
   // Look through bitcast instructions and geps.
@@ -2987,7 +2988,7 @@
   if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) {
     // The GEP operator should be based on a pointer to string constant, and is
     // indexing into the string constant.
-    if (!isGEPBasedOnPointerToString(GEP))
+    if (!isGEPBasedOnPointerToString(GEP, ElementSize))
       return false;
 
     // If the second index isn't a ConstantInt, then this is a variable index
@@ -2998,8 +2999,8 @@
       StartIdx = CI->getZExtValue();
     else
       return false;
-    return getConstantStringInfo(GEP->getOperand(0), Str, StartIdx + Offset,
-                                 TrimAtNul);
+    return getConstantDataArrayInfo(GEP->getOperand(0), Slice, ElementSize,
+                                    StartIdx + Offset);
   }
 
   // The GEP instruction, constant or instruction, must reference a global
@@ -3009,30 +3010,72 @@
   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
     return false;
 
-  // Handle the all-zeros case.
+  const ConstantDataArray *Array;
+  ArrayType *ArrayTy;
   if (GV->getInitializer()->isNullValue()) {
-    // This is a degenerate case. The initializer is constant zero so the
-    // length of the string must be zero.
-    Str = "";
-    return true;
-  }
+    Type *GVTy = GV->getValueType();
+    if ( (ArrayTy = dyn_cast<ArrayType>(GVTy)) ) {
+      // A zeroinitializer for the array; There is no ConstantDataArray.
+      Array = nullptr;
+    } else {
+      const DataLayout &DL = GV->getParent()->getDataLayout();
+      uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy);
+      uint64_t Length = SizeInBytes / (ElementSize / 8);
+      if (Length <= Offset)
+        return false;
 
-  // This must be a ConstantDataArray.
-  const auto *Array = dyn_cast<ConstantDataArray>(GV->getInitializer());
-  if (!Array || !Array->isString())
+      Slice.Array = nullptr;
+      Slice.Offset = 0;
+      Slice.Length = Length - Offset;
+      return true;
+    }
+  } else {
+    // This must be a ConstantDataArray.
+    Array = dyn_cast<ConstantDataArray>(GV->getInitializer());
+    if (!Array)
+      return false;
+    ArrayTy = Array->getType();
+  }
+  if (!ArrayTy->getElementType()->isIntegerTy(ElementSize))
     return false;
 
-  // Get the number of elements in the array.
-  uint64_t NumElts = Array->getType()->getArrayNumElements();
-
-  // Start out with the entire array in the StringRef.
-  Str = Array->getAsString();
-
+  uint64_t NumElts = ArrayTy->getArrayNumElements();
   if (Offset > NumElts)
     return false;
 
+  Slice.Array = Array;
+  Slice.Offset = Offset;
+  Slice.Length = NumElts - Offset;
+  return true;
+}
+
+/// This function computes the length of a null-terminated C string pointed to
+/// by V. If successful, it returns true and returns the string in Str.
+/// If unsuccessful, it returns false.
+bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
+                                 uint64_t Offset, bool TrimAtNul) {
+  ConstantDataArraySlice Slice;
+  if (!getConstantDataArrayInfo(V, Slice, 8, Offset))
+    return false;
+
+  if (Slice.Array == nullptr) {
+    if (TrimAtNul) {
+      Str = StringRef();
+      return true;
+    }
+    if (Slice.Length == 1) {
+      Str = StringRef("", 1);
+      return true;
+    }
+    // We cannot instantiate a StringRef as we do not have an appropriate
+    // string of 0s at hand.
+    return false;
+  }
+
+  // Start out with the entire array in the StringRef.
+  Str = Slice.Array->getAsString();
   // Skip over 'offset' bytes.
-  Str = Str.substr(Offset);
+  Str = Str.substr(Slice.Offset);
 
   if (TrimAtNul) {
     // Trim off the \0 and anything after it.  If the array is not nul
@@ -3050,7 +3093,8 @@
 /// If we can compute the length of the string pointed to by
 /// the specified pointer, return 'len+1'.  If we can't, return 0.
 static uint64_t GetStringLengthH(const Value *V,
-                                 SmallPtrSetImpl<const PHINode*> &PHIs) {
+                                 SmallPtrSetImpl<const PHINode*> &PHIs,
+                                 unsigned CharSize) {
   // Look through noop bitcast instructions.
   V = V->stripPointerCasts();
 
@@ -3063,7 +3107,7 @@
     // If it was new, see if all the input strings are the same length.
     uint64_t LenSoFar = ~0ULL;
     for (Value *IncValue : PN->incoming_values()) {
-      uint64_t Len = GetStringLengthH(IncValue, PHIs);
+      uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize);
       if (Len == 0) return 0; // Unknown length -> unknown.
 
       if (Len == ~0ULL) continue;
@@ -3079,9 +3123,9 @@
 
   // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
   if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
-    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs);
+    uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize);
     if (Len1 == 0) return 0;
-    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs);
+    uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize);
     if (Len2 == 0) return 0;
     if (Len1 == ~0ULL) return Len2;
     if (Len2 == ~0ULL) return Len1;
@@ -3090,20 +3134,30 @@
   }
 
   // Otherwise, see if we can read the string.
-  StringRef StrData;
-  if (!getConstantStringInfo(V, StrData))
+  ConstantDataArraySlice Slice;
+  if (!getConstantDataArrayInfo(V, Slice, CharSize))
     return 0;
 
-  return StrData.size()+1;
+  if (Slice.Array == nullptr)
+    return 1;
+
+  // Search for nul characters
+  unsigned NullIndex = 0;
+  for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
+    if (Slice.Array->getElementAsInteger(Slice.Offset + NullIndex) == 0)
+      break;
+  }
+
+  return NullIndex + 1;
 }
 
 /// If we can compute the length of the string pointed to by
 /// the specified pointer, return 'len+1'.  If we can't, return 0.
-uint64_t llvm::GetStringLength(const Value *V) {
+uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
   if (!V->getType()->isPointerTy()) return 0;
 
   SmallPtrSet<const PHINode*, 32> PHIs;
-  uint64_t Len = GetStringLengthH(V, PHIs);
+  uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
   // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
   // an empty string as a length.
   return Len == ~0ULL ? 1 : Len;