Fix the representation of wide strings in the AST and IR so that it uses the
native integer representation for the elements. This fixes a bunch of nastiness
involving treating wide strings as a series of bytes.

Patch by Seth Cantrell.



git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@143417 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp
index 96a8125..3239973 100644
--- a/lib/AST/Expr.cpp
+++ b/lib/AST/Expr.cpp
@@ -29,6 +29,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
+#include <cstring>
 using namespace clang;
 
 /// isKnownToHaveBooleanValue - Return true if this is an integer expression
@@ -482,6 +483,29 @@
   return V.convertToDouble();
 }
 
+int StringLiteral::mapCharByteWidth(TargetInfo const &target,StringKind k) {
+  int CharByteWidth;
+  switch(k) {
+    case Ascii:
+    case UTF8:
+      CharByteWidth = target.getCharWidth();
+      break;
+    case Wide:
+      CharByteWidth = target.getWCharWidth();
+      break;
+    case UTF16:
+      CharByteWidth = target.getChar16Width();
+      break;
+    case UTF32:
+      CharByteWidth = target.getChar32Width();
+  }
+  assert((CharByteWidth & 7) == 0 && "Assumes character size is byte multiple");
+  CharByteWidth /= 8;
+  assert((CharByteWidth==1 || CharByteWidth==2 || CharByteWidth==4)
+         && "character byte widths supported are 1, 2, and 4 only");
+  return CharByteWidth;
+}
+
 StringLiteral *StringLiteral::Create(ASTContext &C, StringRef Str,
                                      StringKind Kind, bool Pascal, QualType Ty,
                                      const SourceLocation *Loc,
@@ -494,12 +518,8 @@
   StringLiteral *SL = new (Mem) StringLiteral(Ty);
 
   // OPTIMIZE: could allocate this appended to the StringLiteral.
-  char *AStrData = new (C, 1) char[Str.size()];
-  memcpy(AStrData, Str.data(), Str.size());
-  SL->StrData = AStrData;
-  SL->ByteLength = Str.size();
-  SL->Kind = Kind;
-  SL->IsPascal = Pascal;
+  SL->setString(C,Str,Kind,Pascal);
+
   SL->TokLocs[0] = Loc[0];
   SL->NumConcatenated = NumStrs;
 
@@ -513,17 +533,46 @@
                          sizeof(SourceLocation)*(NumStrs-1),
                          llvm::alignOf<StringLiteral>());
   StringLiteral *SL = new (Mem) StringLiteral(QualType());
-  SL->StrData = 0;
-  SL->ByteLength = 0;
+  SL->CharByteWidth = 0;
+  SL->Length = 0;
   SL->NumConcatenated = NumStrs;
   return SL;
 }
 
-void StringLiteral::setString(ASTContext &C, StringRef Str) {
-  char *AStrData = new (C, 1) char[Str.size()];
-  memcpy(AStrData, Str.data(), Str.size());
-  StrData = AStrData;
-  ByteLength = Str.size();
+void StringLiteral::setString(ASTContext &C, StringRef Str,
+                              StringKind Kind, bool IsPascal) {
+  //FIXME: we assume that the string data comes from a target that uses the same
+  // code unit size and endianness for the type of string.
+  this->Kind = Kind;
+  this->IsPascal = IsPascal;
+  
+  CharByteWidth = mapCharByteWidth(C.getTargetInfo(),Kind);
+  assert((Str.size()%CharByteWidth == 0)
+         && "size of data must be multiple of CharByteWidth");
+  Length = Str.size()/CharByteWidth;
+
+  switch(CharByteWidth) {
+    case 1: {
+      char *AStrData = new (C) char[Length];
+      std::memcpy(AStrData,Str.data(),Str.size());
+      StrData.asChar = AStrData;
+      break;
+    }
+    case 2: {
+      uint16_t *AStrData = new (C) uint16_t[Length];
+      std::memcpy(AStrData,Str.data(),Str.size());
+      StrData.asUInt16 = AStrData;
+      break;
+    }
+    case 4: {
+      uint32_t *AStrData = new (C) uint32_t[Length];
+      std::memcpy(AStrData,Str.data(),Str.size());
+      StrData.asUInt32 = AStrData;
+      break;
+    }
+    default:
+      assert(false && "unsupported CharByteWidth");
+  }
 }
 
 /// getLocationOfByte - Return a source location that points to the specified
diff --git a/lib/CodeGen/CGExprConstant.cpp b/lib/CodeGen/CGExprConstant.cpp
index 0622c10..889cdd8 100644
--- a/lib/CodeGen/CGExprConstant.cpp
+++ b/lib/CodeGen/CGExprConstant.cpp
@@ -817,13 +817,7 @@
   }
 
   llvm::Constant *VisitStringLiteral(StringLiteral *E) {
-    assert(!E->getType()->isPointerType() && "Strings are always arrays");
-
-    // This must be a string initializing an array in a static initializer.
-    // Don't emit it as the address of the string, emit the string data itself
-    // as an inline array.
-    return llvm::ConstantArray::get(VMContext,
-                                    CGM.GetStringForStringLiteral(E), false);
+    return CGM.GetConstantArrayFromStringLiteral(E);
   }
 
   llvm::Constant *VisitObjCEncodeExpr(ObjCEncodeExpr *E) {
diff --git a/lib/CodeGen/CodeGenModule.cpp b/lib/CodeGen/CodeGenModule.cpp
index c796e0d..0905c4b 100644
--- a/lib/CodeGen/CodeGenModule.cpp
+++ b/lib/CodeGen/CodeGenModule.cpp
@@ -2037,6 +2037,8 @@
 /// GetStringForStringLiteral - Return the appropriate bytes for a
 /// string literal, properly padded to match the literal type.
 std::string CodeGenModule::GetStringForStringLiteral(const StringLiteral *E) {
+  assert((E->isAscii() || E->isUTF8())
+         && "Use GetConstantArrayFromStringLiteral for wide strings");
   const ASTContext &Context = getContext();
   const ConstantArrayType *CAT =
     Context.getAsConstantArrayType(E->getType());
@@ -2045,27 +2047,44 @@
   // Resize the string to the right size.
   uint64_t RealLen = CAT->getSize().getZExtValue();
 
-  switch (E->getKind()) {
-  case StringLiteral::Ascii:
-  case StringLiteral::UTF8:
-    break;
-  case StringLiteral::Wide:
-    RealLen *= Context.getTargetInfo().getWCharWidth() / Context.getCharWidth();
-    break;
-  case StringLiteral::UTF16:
-    RealLen *= Context.getTargetInfo().getChar16Width() / Context.getCharWidth();
-    break;
-  case StringLiteral::UTF32:
-    RealLen *= Context.getTargetInfo().getChar32Width() / Context.getCharWidth();
-    break;
-  }
-
   std::string Str = E->getString().str();
   Str.resize(RealLen, '\0');
 
   return Str;
 }
 
+llvm::Constant *
+CodeGenModule::GetConstantArrayFromStringLiteral(const StringLiteral *E) {
+  assert(!E->getType()->isPointerType() && "Strings are always arrays");
+  
+  // Don't emit it as the address of the string, emit the string data itself
+  // as an inline array.
+  if (E->getCharByteWidth()==1) {
+    return llvm::ConstantArray::get(VMContext,
+                                    GetStringForStringLiteral(E), false);
+  } else {
+    llvm::ArrayType *AType =
+      cast<llvm::ArrayType>(getTypes().ConvertType(E->getType()));
+    llvm::Type *ElemTy = AType->getElementType();
+    unsigned NumElements = AType->getNumElements();
+    std::vector<llvm::Constant*> Elts;
+    Elts.reserve(NumElements);
+    
+    for(unsigned i=0;i<E->getLength();++i) {
+      unsigned value = E->getCodeUnit(i);
+      llvm::Constant *C = llvm::ConstantInt::get(ElemTy,value,false);
+      Elts.push_back(C);
+    }
+    for(unsigned i=E->getLength();i<NumElements;++i) {
+      llvm::Constant *C = llvm::ConstantInt::get(ElemTy,0,false);
+      Elts.push_back(C);
+    }
+    
+    return llvm::ConstantArray::get(AType, Elts);
+  }
+
+}
+
 /// GetAddrOfConstantStringFromLiteral - Return a pointer to a
 /// constant array for the given string literal.
 llvm::Constant *
@@ -2073,15 +2092,23 @@
   // FIXME: This can be more efficient.
   // FIXME: We shouldn't need to bitcast the constant in the wide string case.
   CharUnits Align = getContext().getTypeAlignInChars(S->getType());
-  llvm::Constant *C = GetAddrOfConstantString(GetStringForStringLiteral(S),
-                                              /* GlobalName */ 0,
-                                              Align.getQuantity());
-  if (S->isWide() || S->isUTF16() || S->isUTF32()) {
-    llvm::Type *DestTy =
-        llvm::PointerType::getUnqual(getTypes().ConvertType(S->getType()));
-    C = llvm::ConstantExpr::getBitCast(C, DestTy);
+  if (S->isAscii() || S->isUTF8()) {
+    return GetAddrOfConstantString(GetStringForStringLiteral(S),
+                                   /* GlobalName */ 0,
+                                   Align.getQuantity());
   }
-  return C;
+
+  // FIXME: the following does not memoize wide strings
+  llvm::Constant *C = GetConstantArrayFromStringLiteral(S);
+  llvm::GlobalVariable *GV =
+    new llvm::GlobalVariable(getModule(),C->getType(),
+                             !Features.WritableStrings,
+                             llvm::GlobalValue::PrivateLinkage,
+                             C,".str");
+  GV->setAlignment(Align.getQuantity());
+  GV->setUnnamedAddr(true);
+  
+  return GV;
 }
 
 /// GetAddrOfConstantStringFromObjCEncode - Return a pointer to a constant
diff --git a/lib/CodeGen/CodeGenModule.h b/lib/CodeGen/CodeGenModule.h
index ea2e177..0ce698a 100644
--- a/lib/CodeGen/CodeGenModule.h
+++ b/lib/CodeGen/CodeGenModule.h
@@ -565,6 +565,10 @@
   /// -fconstant-string-class=class_name option.
   llvm::Constant *GetAddrOfConstantString(const StringLiteral *Literal);
 
+  /// GetConstantArrayFromStringLiteral - Return a constant array for the given
+  /// string.
+  llvm::Constant *GetConstantArrayFromStringLiteral(const StringLiteral *E);
+
   /// GetAddrOfConstantStringFromLiteral - Return a pointer to a constant array
   /// for the given string literal.
   llvm::Constant *GetAddrOfConstantStringFromLiteral(const StringLiteral *S);
diff --git a/lib/Sema/SemaExpr.cpp b/lib/Sema/SemaExpr.cpp
index 21d0309..61766a8 100644
--- a/lib/Sema/SemaExpr.cpp
+++ b/lib/Sema/SemaExpr.cpp
@@ -1141,7 +1141,7 @@
     StrTy = Context.Char16Ty;
   else if (Literal.isUTF32())
     StrTy = Context.Char32Ty;
-  else if (Literal.Pascal)
+  else if (Literal.isPascal())
     StrTy = Context.UnsignedCharTy;
 
   StringLiteral::StringKind Kind = StringLiteral::Ascii;
diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp
index 87912af..e57ab19 100644
--- a/lib/Serialization/ASTReaderStmt.cpp
+++ b/lib/Serialization/ASTReaderStmt.cpp
@@ -372,12 +372,13 @@
   assert(Record[Idx] == E->getNumConcatenated() &&
          "Wrong number of concatenated tokens!");
   ++Idx;
-  E->Kind = static_cast<StringLiteral::StringKind>(Record[Idx++]);
-  E->IsPascal = Record[Idx++];
+  StringLiteral::StringKind kind =
+        static_cast<StringLiteral::StringKind>(Record[Idx++]);
+  bool isPascal = Record[Idx++];
 
   // Read string data
   llvm::SmallString<16> Str(&Record[Idx], &Record[Idx] + Len);
-  E->setString(Reader.getContext(), Str.str());
+  E->setString(Reader.getContext(), Str.str(), kind, isPascal);
   Idx += Len;
 
   // Read source locations
diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp
index 0721c29..61570a8 100644
--- a/lib/Serialization/ASTWriterStmt.cpp
+++ b/lib/Serialization/ASTWriterStmt.cpp
@@ -331,7 +331,7 @@
   // StringLiteral. However, we can't do so now because we have no
   // provision for coping with abbreviations when we're jumping around
   // the AST file during deserialization.
-  Record.append(E->getString().begin(), E->getString().end());
+  Record.append(E->getBytes().begin(), E->getBytes().end());
   for (unsigned I = 0, N = E->getNumConcatenated(); I != N; ++I)
     Writer.AddSourceLocation(E->getStrTokenLoc(I), Record);
   Code = serialization::EXPR_STRING_LITERAL;