Teach ConstantFolding about pointer address spaces

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@188831 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp
index c063d06..0319962 100644
--- a/lib/Analysis/ConstantFolding.cpp
+++ b/lib/Analysis/ConstantFolding.cpp
@@ -367,7 +367,7 @@
 
   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
     if (CE->getOpcode() == Instruction::IntToPtr &&
-        CE->getOperand(0)->getType() == TD.getIntPtrType(CE->getContext())) {
+        CE->getOperand(0)->getType() == TD.getIntPtrType(CE->getType())) {
       return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr,
                                 BytesLeft, TD);
     }
@@ -379,26 +379,29 @@
 
 static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
                                                  const DataLayout &TD) {
-  Type *LoadTy = cast<PointerType>(C->getType())->getElementType();
+  PointerType *PTy = cast<PointerType>(C->getType());
+  Type *LoadTy = PTy->getElementType();
   IntegerType *IntType = dyn_cast<IntegerType>(LoadTy);
 
   // If this isn't an integer load we can't fold it directly.
   if (!IntType) {
+    unsigned AS = PTy->getAddressSpace();
+
     // If this is a float/double load, we can try folding it as an int32/64 load
     // and then bitcast the result.  This can be useful for union cases.  Note
     // that address spaces don't matter here since we're not going to result in
     // an actual new load.
     Type *MapTy;
     if (LoadTy->isHalfTy())
-      MapTy = Type::getInt16PtrTy(C->getContext());
+      MapTy = Type::getInt16PtrTy(C->getContext(), AS);
     else if (LoadTy->isFloatTy())
-      MapTy = Type::getInt32PtrTy(C->getContext());
+      MapTy = Type::getInt32PtrTy(C->getContext(), AS);
     else if (LoadTy->isDoubleTy())
-      MapTy = Type::getInt64PtrTy(C->getContext());
+      MapTy = Type::getInt64PtrTy(C->getContext(), AS);
     else if (LoadTy->isVectorTy()) {
-      MapTy = IntegerType::get(C->getContext(),
-                               TD.getTypeAllocSizeInBits(LoadTy));
-      MapTy = PointerType::getUnqual(MapTy);
+      MapTy = PointerType::getIntNPtrTy(C->getContext(),
+                                        TD.getTypeAllocSizeInBits(LoadTy),
+                                        AS);
     } else
       return 0;
 
@@ -413,7 +416,7 @@
     return 0;
 
   GlobalValue *GVal;
-  APInt Offset(TD.getPointerSizeInBits(), 0);
+  APInt Offset(TD.getPointerTypeSizeInBits(PTy), 0);
   if (!IsConstantOffsetFromGlobal(C, GVal, Offset, TD))
     return 0;
 
@@ -606,8 +609,10 @@
 static Constant *CastGEPIndices(ArrayRef<Constant *> Ops,
                                 Type *ResultTy, const DataLayout *TD,
                                 const TargetLibraryInfo *TLI) {
-  if (!TD) return 0;
-  Type *IntPtrTy = TD->getIntPtrType(ResultTy->getContext());
+  if (!TD)
+    return 0;
+
+  Type *IntPtrTy = TD->getIntPtrType(ResultTy);
 
   bool Any = false;
   SmallVector<Constant*, 32> NewIdxs;
@@ -665,7 +670,7 @@
       !Ptr->getType()->isPointerTy())
     return 0;
 
-  Type *IntPtrTy = TD->getIntPtrType(Ptr->getContext());
+  Type *IntPtrTy = TD->getIntPtrType(Ptr->getType());
   Type *ResultElementTy = ResultTy->getPointerElementType();
 
   // If this is a constant expr gep that is effectively computing an
@@ -741,7 +746,8 @@
   // Also, this helps GlobalOpt do SROA on GlobalVariables.
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Forming regular GEP of non-pointer type");
-  SmallVector<Constant*, 32> NewIdxs;
+  SmallVector<Constant *, 32> NewIdxs;
+
   do {
     if (SequentialType *ATy = dyn_cast<SequentialType>(Ty)) {
       if (ATy->isPointerTy()) {
@@ -756,7 +762,6 @@
 
       // Determine which element of the array the offset points into.
       APInt ElemSize(BitWidth, TD->getTypeAllocSize(ATy->getElementType()));
-      IntegerType *IntPtrTy = TD->getIntPtrType(Ty->getContext());
       if (ElemSize == 0)
         // The element size is 0. This may be [0 x Ty]*, so just use a zero
         // index for this level and proceed to the next level to see if it can
@@ -968,7 +973,7 @@
       if (TD && CE->getOpcode() == Instruction::IntToPtr) {
         Constant *Input = CE->getOperand(0);
         unsigned InWidth = Input->getType()->getScalarSizeInBits();
-        if (TD->getPointerSizeInBits() < InWidth) {
+        if (TD->getPointerTypeSizeInBits(CE->getType()) < InWidth) {
           Constant *Mask =
             ConstantInt::get(CE->getContext(), APInt::getLowBitsSet(InWidth,
                                                   TD->getPointerSizeInBits()));
@@ -983,11 +988,19 @@
     // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
     // the int size is >= the ptr size.  This requires knowing the width of a
     // pointer, so it can't be done in ConstantExpr::getCast.
-    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0]))
-      if (TD &&
-          TD->getPointerSizeInBits() <= CE->getType()->getScalarSizeInBits() &&
-          CE->getOpcode() == Instruction::PtrToInt)
-        return FoldBitCast(CE->getOperand(0), DestTy, *TD);
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
+      if (TD && CE->getOpcode() == Instruction::PtrToInt) {
+        Constant *SrcPtr = CE->getOperand(0);
+        unsigned SrcPtrSize = TD->getPointerTypeSizeInBits(SrcPtr->getType());
+        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
+
+        if (MidIntSize >= SrcPtrSize) {
+          unsigned DestPtrSize = TD->getPointerTypeSizeInBits(DestTy);
+          if (SrcPtrSize == DestPtrSize)
+            return FoldBitCast(CE->getOperand(0), DestTy, *TD);
+        }
+      }
+    }
 
     return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::Trunc:
@@ -1039,8 +1052,8 @@
   // around to know if bit truncation is happening.
   if (ConstantExpr *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
     if (TD && Ops1->isNullValue()) {
-      Type *IntPtrTy = TD->getIntPtrType(CE0->getContext());
       if (CE0->getOpcode() == Instruction::IntToPtr) {
+        Type *IntPtrTy = TD->getIntPtrType(CE0->getType());
         // Convert the integer value to the right size to ensure we get the
         // proper extension or truncation.
         Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
@@ -1051,19 +1064,21 @@
 
       // Only do this transformation if the int is intptrty in size, otherwise
       // there is a truncation or extension that we aren't modeling.
-      if (CE0->getOpcode() == Instruction::PtrToInt &&
-          CE0->getType() == IntPtrTy) {
-        Constant *C = CE0->getOperand(0);
-        Constant *Null = Constant::getNullValue(C->getType());
-        return ConstantFoldCompareInstOperands(Predicate, C, Null, TD, TLI);
+      if (CE0->getOpcode() == Instruction::PtrToInt) {
+        Type *IntPtrTy = TD->getIntPtrType(CE0->getOperand(0)->getType());
+        if (CE0->getType() == IntPtrTy) {
+          Constant *C = CE0->getOperand(0);
+          Constant *Null = Constant::getNullValue(C->getType());
+          return ConstantFoldCompareInstOperands(Predicate, C, Null, TD, TLI);
+        }
       }
     }
 
     if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
       if (TD && CE0->getOpcode() == CE1->getOpcode()) {
-        Type *IntPtrTy = TD->getIntPtrType(CE0->getContext());
-
         if (CE0->getOpcode() == Instruction::IntToPtr) {
+          Type *IntPtrTy = TD->getIntPtrType(CE0->getType());
+
           // Convert the integer value to the right size to ensure we get the
           // proper extension or truncation.
           Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
@@ -1075,11 +1090,17 @@
 
         // Only do this transformation if the int is intptrty in size, otherwise
         // there is a truncation or extension that we aren't modeling.
-        if ((CE0->getOpcode() == Instruction::PtrToInt &&
-             CE0->getType() == IntPtrTy &&
-             CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()))
-          return ConstantFoldCompareInstOperands(Predicate, CE0->getOperand(0),
-                                                 CE1->getOperand(0), TD, TLI);
+        if (CE0->getOpcode() == Instruction::PtrToInt) {
+          Type *IntPtrTy = TD->getIntPtrType(CE0->getOperand(0)->getType());
+          if (CE0->getType() == IntPtrTy &&
+              CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()) {
+            return ConstantFoldCompareInstOperands(Predicate,
+                                                   CE0->getOperand(0),
+                                                   CE1->getOperand(0),
+                                                   TD,
+                                                   TLI);
+          }
+        }
       }
     }