AMDGPU: Look through a bitcast user of an out argument
This allows handling many more of the interesting
cases in Blender. Most of the large functions that are
unlikely to be inlined have this pattern.
This is a special case for what clang emits for OpenCL 3-element
vectors. Annoyingly, these are emitted as <3 x elt>* pointers, but
accessed through a bitcast as <4 x elt>*.
This also needs to handle cases where a struct containing
a single vector is used.
llvm-svn: 309419
diff --git a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
index 4d03040..0983126 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPURewriteOutArguments.cpp
@@ -83,8 +83,10 @@
const DataLayout *DL = nullptr;
MemoryDependenceResults *MDA = nullptr;
+ bool checkArgumentUses(Value &Arg) const;
bool isOutArgumentCandidate(Argument &Arg) const;
+ bool isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const;
public:
static char ID;
@@ -110,27 +112,49 @@
char AMDGPURewriteOutArguments::ID = 0;
-bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
+bool AMDGPURewriteOutArguments::checkArgumentUses(Value &Arg) const {
const int MaxUses = 10;
- const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
int UseCount = 0;
- PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
-
- // TODO: It might be useful for any out arguments, not just privates.
- if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
- !AnyAddressSpace) ||
- Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
- DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
- return false;
- }
-
for (Use &U : Arg.uses()) {
StoreInst *SI = dyn_cast<StoreInst>(U.getUser());
if (UseCount > MaxUses)
return false;
- if (!SI || !SI->isSimple() ||
+ if (!SI) {
+ auto *BCI = dyn_cast<BitCastInst>(U.getUser());
+ if (!BCI || !BCI->hasOneUse())
+ return false;
+
+ // We don't handle multiple stores currently, so stores to aggregate
+ // pointers aren't worth the trouble since they are canonically split up.
+ Type *DestEltTy = BCI->getType()->getPointerElementType();
+ if (DestEltTy->isAggregateType())
+ return false;
+
+ // We could handle these if we had a convenient way to bitcast between
+ // them.
+ Type *SrcEltTy = Arg.getType()->getPointerElementType();
+ if (SrcEltTy->isArrayTy())
+ return false;
+
+ // Special case handle structs with single members. It is useful to handle
+ // some casts between structs and non-structs, but we can't bitcast
+ // directly between them. Blender uses some casts that look like
+ // { <3 x float> }* to <4 x float>*
+ if ((SrcEltTy->isStructTy() && (SrcEltTy->getNumContainedTypes() != 1)))
+ return false;
+
+ // Clang emits OpenCL 3-vector type accesses with a bitcast to the
+ // equivalent 4-element vector and accesses that, and we're looking for
+ // this pointer cast.
+ if (DL->getTypeAllocSize(SrcEltTy) != DL->getTypeAllocSize(DestEltTy))
+ return false;
+
+ return checkArgumentUses(*BCI);
+ }
+
+ if (!SI->isSimple() ||
U.getOperandNo() != StoreInst::getPointerOperandIndex())
return false;
@@ -141,11 +165,40 @@
return UseCount > 0;
}
+bool AMDGPURewriteOutArguments::isOutArgumentCandidate(Argument &Arg) const {
+ const unsigned MaxOutArgSizeBytes = 4 * MaxNumRetRegs;
+ PointerType *ArgTy = dyn_cast<PointerType>(Arg.getType());
+
+ // TODO: It might be useful for any out arguments, not just privates.
+ if (!ArgTy || (ArgTy->getAddressSpace() != DL->getAllocaAddrSpace() &&
+ !AnyAddressSpace) ||
+ Arg.hasByValAttr() || Arg.hasStructRetAttr() ||
+ DL->getTypeStoreSize(ArgTy->getPointerElementType()) > MaxOutArgSizeBytes) {
+ return false;
+ }
+
+ return checkArgumentUses(Arg);
+}
+
bool AMDGPURewriteOutArguments::doInitialization(Module &M) {
DL = &M.getDataLayout();
return false;
}
+bool AMDGPURewriteOutArguments::isVec3ToVec4Shuffle(Type *Ty0, Type* Ty1) const {
+ VectorType *VT0 = dyn_cast<VectorType>(Ty0);
+ VectorType *VT1 = dyn_cast<VectorType>(Ty1);
+ if (!VT0 || !VT1)
+ return false;
+
+ if (VT0->getNumElements() != 3 ||
+ VT1->getNumElements() != 4)
+ return false;
+
+ return DL->getTypeSizeInBits(VT0->getElementType()) ==
+ DL->getTypeSizeInBits(VT1->getElementType());
+}
+
bool AMDGPURewriteOutArguments::runOnFunction(Function &F) {
if (skipFunction(F))
return false;
@@ -316,8 +369,33 @@
if (RetVal)
NewRetVal = B.CreateInsertValue(NewRetVal, RetVal, RetIdx++);
+
for (std::pair<Argument *, Value *> ReturnPoint : Replacement.second) {
- NewRetVal = B.CreateInsertValue(NewRetVal, ReturnPoint.second, RetIdx++);
+ Argument *Arg = ReturnPoint.first;
+ Value *Val = ReturnPoint.second;
+ Type *EltTy = Arg->getType()->getPointerElementType();
+ if (Val->getType() != EltTy) {
+ Type *EffectiveEltTy = EltTy;
+ if (StructType *CT = dyn_cast<StructType>(EltTy)) {
+ assert(CT->getNumContainedTypes() == 1);
+ EffectiveEltTy = CT->getContainedType(0);
+ }
+
+ if (DL->getTypeSizeInBits(EffectiveEltTy) !=
+ DL->getTypeSizeInBits(Val->getType())) {
+ assert(isVec3ToVec4Shuffle(EffectiveEltTy, Val->getType()));
+ Val = B.CreateShuffleVector(Val, UndefValue::get(Val->getType()),
+ { 0, 1, 2 });
+ }
+
+ Val = B.CreateBitCast(Val, EffectiveEltTy);
+
+ // Re-create single element composite.
+ if (EltTy != EffectiveEltTy)
+ Val = B.CreateInsertValue(UndefValue::get(EltTy), Val, 0);
+ }
+
+ NewRetVal = B.CreateInsertValue(NewRetVal, Val, RetIdx++);
}
if (RetVal)
@@ -348,13 +426,20 @@
if (!OutArgIndexes.count(Arg.getArgNo()))
continue;
- auto *EltTy = Arg.getType()->getPointerElementType();
+ PointerType *ArgType = cast<PointerType>(Arg.getType());
+
+ auto *EltTy = ArgType->getElementType();
unsigned Align = Arg.getParamAlignment();
if (Align == 0)
Align = DL->getABITypeAlignment(EltTy);
Value *Val = B.CreateExtractValue(StubCall, RetIdx++);
- B.CreateAlignedStore(Val, &Arg, Align);
+ Type *PtrTy = Val->getType()->getPointerTo(ArgType->getAddressSpace());
+
+ // We can peek through bitcasts, so the type may not match.
+ Value *PtrVal = B.CreateBitCast(&Arg, PtrTy);
+
+ B.CreateAlignedStore(Val, PtrVal, Align);
}
if (!RetTy->isVoidTy()) {