blob: 6c1cc80a9fd8fa1d77d894945cdd6f2549a11125 [file] [log] [blame]
//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Copy struct args to local memory. This is needed for kernel functions only.
// This is a preparation for handling cases like
//
// kernel void foo(struct A arg, ...)
// {
//   struct A *p = &arg;
//   ...
//   ... = p->field1 ...  (there is no generic address for .param)
//   p->field2 = ...      (there is no write access to .param)
// }
//
//===----------------------------------------------------------------------===//
22
23#include "NVPTX.h"
24#include "NVPTXUtilities.h"
25#include "llvm/IR/Function.h"
26#include "llvm/IR/Instructions.h"
27#include "llvm/IR/IntrinsicInst.h"
28#include "llvm/IR/Module.h"
29#include "llvm/IR/Type.h"
30#include "llvm/Pass.h"
31
32using namespace llvm;
33
34namespace llvm {
35void initializeNVPTXLowerStructArgsPass(PassRegistry &);
36}
37
38class LLVM_LIBRARY_VISIBILITY NVPTXLowerStructArgs : public FunctionPass {
39 bool runOnFunction(Function &F) override;
40
41 void handleStructPtrArgs(Function &);
42 void handleParam(Argument *);
43
44public:
45 static char ID; // Pass identification, replacement for typeid
46 NVPTXLowerStructArgs() : FunctionPass(ID) {}
47 const char *getPassName() const override {
48 return "Copy structure (byval *) arguments to stack";
49 }
50};
51
52char NVPTXLowerStructArgs::ID = 1;
53
54INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
55 "Lower structure arguments (NVPTX)", false, false)
56
57void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
58 Function *Func = Arg->getParent();
59 Instruction *FirstInst = &(Func->getEntryBlock().front());
60 const PointerType *PType = dyn_cast<PointerType>(Arg->getType());
61
62 assert(PType && "Expecting pointer type in handleParam");
63
64 const Type *StructType = PType->getElementType();
65
66 AllocaInst *AllocA =
67 new AllocaInst((Type *)StructType, Arg->getName(), FirstInst);
68
69 /* Set the alignment to alignment of the byval parameter. This is because,
70 * later load/stores assume that alignment, and we are going to replace
71 * the use of the byval parameter with this alloca instruction.
72 */
73 AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
74
75 Arg->replaceAllUsesWith(AllocA);
76
77 // Get the cvt.gen.to.param intrinsic
78 const Type *CvtTypes[2] = {
79 Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM),
80 Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_GENERIC)
81 };
82 Function *CvtFunc = (Function *)Intrinsic::getDeclaration(
83 Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param,
84 ArrayRef<Type *>((Type **)CvtTypes, 2));
85 std::vector<Value *> BC1;
86 BC1.push_back(
87 new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(),
88 ADDRESS_SPACE_GENERIC),
89 Arg->getName(), FirstInst));
90 CallInst *CallCVT = CallInst::Create(CvtFunc, ArrayRef<Value *>(BC1),
91 "cvt_to_param", FirstInst);
92
93 BitCastInst *BitCast = new BitCastInst(
94 CallCVT, PointerType::get((Type *)StructType, ADDRESS_SPACE_PARAM),
95 Arg->getName(), FirstInst);
96 LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst);
97 new StoreInst(LI, AllocA, FirstInst);
98}
99
100/// =============================================================================
101/// If the function had a struct ptr arg, say foo(%struct.x *byval %d), then
102/// add the following instructions to the first basic block :
103///
104/// %temp = alloca %struct.x, align 8
105/// %tt1 = bitcast %struct.x * %d to i8 *
106/// %tt2 = llvm.nvvm.cvt.gen.to.param %tt2
107/// %tempd = bitcast i8 addrspace(101) * to %struct.x addrspace(101) *
108/// %tv = load %struct.x addrspace(101) * %tempd
109/// store %struct.x %tv, %struct.x * %temp, align 8
110///
111/// The above code allocates some space in the stack and copies the incoming
112/// struct from param space to local space.
113/// Then replace all occurences of %d by %temp.
114/// =============================================================================
115void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
116 const AttributeSet &PAL = F.getAttributes();
117
118 unsigned Idx = 1;
119
120 for (Argument &Arg : F.args()) {
121 const Type *Ty = Arg.getType();
122
123 const PointerType *PTy = dyn_cast<PointerType>(Ty);
124
125 if (PTy) {
126 if (PAL.hasAttribute(Idx, Attribute::ByVal)) {
127 // cout << "Has struct ptr args" << std::endl;
128 handleParam(&Arg);
129 }
130 }
131 Idx++;
132 }
133}
134
135/// =============================================================================
136/// Main function for this pass.
137/// =============================================================================
138bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
139 // Skip non-kernels. See the comments at the top of this file.
140 if (!isKernelFunction(F))
141 return false;
142
143 handleStructPtrArgs(F);
144
145 return true;
146}
147
148FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
149 return new NVPTXLowerStructArgs();
150}