SROA: Allow eliminating addrspacecasted allocas

There is a circular dependency between SROA and InferAddressSpaces
today that requires running both multiple times in order to be able to
eliminate all simple allocas and addrspacecasts. InferAddressSpaces
can't remove addrspacecasts whose results are written to memory, and
SROA helps move pointers out of memory.

This should avoid inserting new commuting addrspacecasts with GEPs,
since there are unresolved questions about pointer wrapping between
different address spaces.

For now, don't replace volatile operations that don't match the alloca
addrspace, as it would change the address space of the access. It may
still be OK to insert an addrspacecast from the new alloca, but be
more conservative for now.

llvm-svn: 363462
diff --git a/llvm/lib/Analysis/PtrUseVisitor.cpp b/llvm/lib/Analysis/PtrUseVisitor.cpp
index f78446c..9a834ba 100644
--- a/llvm/lib/Analysis/PtrUseVisitor.cpp
+++ b/llvm/lib/Analysis/PtrUseVisitor.cpp
@@ -34,5 +34,11 @@
   if (!IsOffsetKnown)
     return false;
 
-  return GEPI.accumulateConstantOffset(DL, Offset);
+  APInt TmpOffset(DL.getIndexTypeSizeInBits(GEPI.getType()), 0);
+  if (GEPI.accumulateConstantOffset(DL, TmpOffset)) {
+    Offset += TmpOffset.sextOrTrunc(Offset.getBitWidth());
+    return true;
+  }
+
+  return false;
 }
diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp
index 790a16d..6454f1e 100644
--- a/llvm/lib/Transforms/Scalar/SROA.cpp
+++ b/llvm/lib/Transforms/Scalar/SROA.cpp
@@ -713,6 +713,13 @@
     return Base::visitBitCastInst(BC);
   }
 
+  void visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
+    if (ASC.use_empty())
+      return markAsDead(ASC);
+
+    return Base::visitAddrSpaceCastInst(ASC);
+  }
+
   void visitGetElementPtrInst(GetElementPtrInst &GEPI) {
     if (GEPI.use_empty())
       return markAsDead(GEPI);
@@ -776,7 +783,10 @@
     if (!IsOffsetKnown)
       return PI.setAborted(&LI);
 
-    const DataLayout &DL = LI.getModule()->getDataLayout();
+    if (LI.isVolatile() &&
+        LI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
+      return PI.setAborted(&LI);
+
     uint64_t Size = DL.getTypeStoreSize(LI.getType());
     return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());
   }
@@ -788,7 +798,10 @@
     if (!IsOffsetKnown)
       return PI.setAborted(&SI);
 
-    const DataLayout &DL = SI.getModule()->getDataLayout();
+    if (SI.isVolatile() &&
+        SI.getPointerAddressSpace() != DL.getAllocaAddrSpace())
+      return PI.setAborted(&SI);
+
     uint64_t Size = DL.getTypeStoreSize(ValOp->getType());
 
     // If this memory access can be shown to *statically* extend outside the
@@ -823,6 +836,11 @@
     if (!IsOffsetKnown)
       return PI.setAborted(&II);
 
+    // Don't replace this with a store with a different address space.  TODO:
+    // Use a store with the casted new alloca?
+    if (II.isVolatile() && II.getDestAddressSpace() != DL.getAllocaAddrSpace())
+      return PI.setAborted(&II);
+
     insertUse(II, Offset, Length ? Length->getLimitedValue()
                                  : AllocSize - Offset.getLimitedValue(),
               (bool)Length);
@@ -842,6 +860,13 @@
     if (!IsOffsetKnown)
       return PI.setAborted(&II);
 
+    // Don't replace this with a load/store with a different address space.
+    // TODO: Use a store with the casted new alloca?
+    if (II.isVolatile() &&
+        (II.getDestAddressSpace() != DL.getAllocaAddrSpace() ||
+         II.getSourceAddressSpace() != DL.getAllocaAddrSpace()))
+      return PI.setAborted(&II);
+
     // This side of the transfer is completely out-of-bounds, and so we can
     // nuke the entire transfer. However, we also need to nuke the other side
     // if already added to our partitions.
@@ -949,7 +974,7 @@
         if (!GEP->hasAllZeroIndices())
           return GEP;
       } else if (!isa<BitCastInst>(I) && !isa<PHINode>(I) &&
-                 !isa<SelectInst>(I)) {
+                 !isa<SelectInst>(I) && !isa<AddrSpaceCastInst>(I)) {
         return I;
       }
 
@@ -1561,7 +1586,8 @@
   Value *Int8Ptr = nullptr;
   APInt Int8PtrOffset(Offset.getBitWidth(), 0);
 
-  Type *TargetTy = PointerTy->getPointerElementType();
+  PointerType *TargetPtrTy = cast<PointerType>(PointerTy);
+  Type *TargetTy = TargetPtrTy->getElementType();
 
   do {
     // First fold any existing GEPs into the offset.
@@ -1630,8 +1656,11 @@
   Ptr = OffsetPtr;
 
   // On the off chance we were targeting i8*, guard the bitcast here.
-  if (Ptr->getType() != PointerTy)
-    Ptr = IRB.CreateBitCast(Ptr, PointerTy, NamePrefix + "sroa_cast");
+  if (cast<PointerType>(Ptr->getType()) != TargetPtrTy) {
+    Ptr = IRB.CreatePointerBitCastOrAddrSpaceCast(Ptr,
+                                                  TargetPtrTy,
+                                                  NamePrefix + "sroa_cast");
+  }
 
   return Ptr;
 }
@@ -3082,8 +3111,9 @@
         continue;
       }
 
-      assert(isa<BitCastInst>(I) || isa<PHINode>(I) ||
-             isa<SelectInst>(I) || isa<GetElementPtrInst>(I));
+      assert(isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I) ||
+             isa<PHINode>(I) || isa<SelectInst>(I) ||
+             isa<GetElementPtrInst>(I));
       for (User *U : I->users())
         if (Visited.insert(cast<Instruction>(U)).second)
           Uses.push_back(cast<Instruction>(U));
@@ -3384,6 +3414,11 @@
     return false;
   }
 
+  bool visitAddrSpaceCastInst(AddrSpaceCastInst &ASC) {
+    enqueueUsers(ASC);
+    return false;
+  }
+
   bool visitGetElementPtrInst(GetElementPtrInst &GEPI) {
     enqueueUsers(GEPI);
     return false;