AMDGPU: Look through a bitcast user of an out argument

This allows handling of a lot more of the interesting
cases in Blender. Most of the large functions unlikely
to be inlined have this pattern.

This is a special case for what clang emits for OpenCL
3-element vectors. Annoyingly, these are emitted as
<3 x elt>* pointers, but accessed through <4 x elt>* pointers.
This also needs to handle cases where a struct containing
a single vector is used.

llvm-svn: 309419
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
index 4d03040..0983126 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
@@ -83,8 +83,10 @@
   const DataLayout *DL = nullptr;
   MemoryDependenceResults *MDA = nullptr;
 
+  bool checkArgumentUses(Value &Arg) const;
   bool isOutArgumentCandidate(Argument &Arg) const;
 
+  bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const;
 public:
   static char ID;
 
@@ -110,27 +112,49 @@
 
 char AMDGPURewriteOutArguments::ID = 0;
 
-bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
+bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
   const int MaxUses = 10;
-  const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
   int UseCount = 0;
 
-  PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
-
-  // TODO: It might be useful for any out arguments, not just privates.
-  if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
-                 !AnyAddressSpace) ||
-      Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
-      DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
-    return false;
-  }
-
   for (Use &U : Arg.uses()) {
     StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
     if (UseCount > MaxUses)
       return false;
 
-    if (!SI || !SI->isSimple() ||
+    if (!SI) {
+      auto *BCI = dyn_cast<BitCastInst>(U.getUser());
+      if (!BCI || !BCI->hasOneUse())
+        return false;
+
+      // We don't handle multiple stores currently, so stores to aggregate
+      // pointers aren't worth the trouble since they are canonically split up.
+      Type *DestEltTy = BCI->getType()->getPointerElementType();
+      if (DestEltTy->isAggregateType())
+        return false;
+
+      // We could handle these if we had a convenient way to bitcast between
+      // them.
+      Type *SrcEltTy = Arg.getType()->getPointerElementType();
+      if (SrcEltTy->isArrayTy())
+        return false;
+
+      // Special-case structs with a single member. It is useful to handle
+      // some casts between structs and non-structs, but we can't bitcast
+      // directly between them. Blender uses some casts that look like
+      // { <3 x float> }* to <4 x float>*
+      if ((SrcEltTy->isStructTy() && (SrcEltTy->getNumContainedTypes() != 1)))
+        return false;
+
+      // Clang emits OpenCL 3-vector type accesses with a bitcast to the
+      // equivalent 4-element vector and accesses that, and we're looking for
+      // this pointer cast.
+      if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
+        return false;
+
+      return checkArgumentUses(*BCI);
+    }
+
+    if (!SI->isSimple() ||
         U.getOperandNo() != StoreInst::getPointerOperandIndex())
       return false;
 
@@ -141,11 +165,40 @@
   return UseCount > 0;
 }
 
+bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
+  const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
+  PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
+
+  // TODO: It might be useful for any out arguments, not just privates.
+  if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
+                 !AnyAddressSpace) ||
+      Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
+      DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
+    return false;
+  }
+
+  return checkArgumentUses(Arg);
+}
+
 bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
   DL = &M.getDataLayout();
   return false;
 }
 
+bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const {
+  VectorType *VT0 = dyn_cast<VectorType>(Ty0);
+  VectorType *VT1 = dyn_cast<VectorType>(Ty1);
+  if (!VT0 || !VT1)
+    return false;
+
+  if (VT0->getNumElements() != 3 ||
+      VT1->getNumElements() != 4)
+    return false;
+
+  return DL->getTypeSizeInBits(VT0->getElementType()) ==
+         DL->getTypeSizeInBits(VT1->getElementType());
+}
+
 bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
   if (skipFunction(F))
     return false;
@@ -316,8 +369,33 @@
     if (RetVal)
       NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
 
+
     for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
-      NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);
+      Argument *Arg = ReturnPoint.first;
+      Value *Val = ReturnPoint.second;
+      Type *EltTy = Arg->getType()->getPointerElementType();
+      if (Val->getType() != EltTy) {
+        Type *EffectiveEltTy = EltTy;
+        if (StructType *CT = dyn_cast<StructType>(EltTy)) {
+          assert(CT->getNumContainedTypes() == 1);
+          EffectiveEltTy = CT->getContainedType(0);
+        }
+
+        if (DL->getTypeSizeInBits(EffectiveEltTy) !=
+            DL->getTypeSizeInBits(Val->getType())) {
+          assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
+          Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
+                                      { 0, 1, 2 });
+        }
+
+        Val = B.CreateBitCast(Val, EffectiveEltTy);
+
+        // Re-create single element composite.
+        if (EltTy != EffectiveEltTy)
+          Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
+      }
+
+      NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
     }
 
     if (RetVal)
@@ -348,13 +426,20 @@
     if (!OutArgIndexes.count(Arg.getArgNo()))
       continue;
 
-    auto *EltTy = Arg.getType()->getPointerElementType();
+    PointerType *ArgType = cast<PointerType>(Arg.getType());
+
+    auto *EltTy = ArgType->getElementType();
     unsigned Align = Arg.getParamAlignment();
     if (Align == 0)
       Align = DL->getABITypeAlignment(EltTy);
 
     Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
-    B.CreateAlignedStore(Val, &Arg, Align);
+    Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
+
+    // We can peek through bitcasts, so the type may not match.
+    Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
+
+    B.CreateAlignedStore(Val, PtrVal, Align);
   }
 
   if (!RetTy->isVoidTy()) {