blob: 68dfbb716139010c910a0727ffc71577985109a2 [file] [log] [blame]
Justin Holewinski3d140fc2014-11-05 18:19:30 +00001//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Copy struct args to local memory. This is needed for kernel functions only.
11// This is a preparation for handling cases like
12//
13// kernel void foo(struct A arg, ...)
14// {
15// struct A *p = &arg;
16// ...
17// ... = p->field1 ... (there is no generic address for .param)
18// p->field2 = ... (there is no write access to .param)
19// }
20//
21//===----------------------------------------------------------------------===//
22
23#include "NVPTX.h"
24#include "NVPTXUtilities.h"
25#include "llvm/IR/Function.h"
26#include "llvm/IR/Instructions.h"
27#include "llvm/IR/IntrinsicInst.h"
28#include "llvm/IR/Module.h"
29#include "llvm/IR/Type.h"
30#include "llvm/Pass.h"
31
32using namespace llvm;
33
34namespace llvm {
35void initializeNVPTXLowerStructArgsPass(PassRegistry &);
36}
37
Benjamin Kramerd58792f2015-03-10 20:07:44 +000038namespace {
39class NVPTXLowerStructArgs : public FunctionPass {
Justin Holewinski3d140fc2014-11-05 18:19:30 +000040 bool runOnFunction(Function &F) override;
41
42 void handleStructPtrArgs(Function &);
43 void handleParam(Argument *);
44
45public:
46 static char ID; // Pass identification, replacement for typeid
47 NVPTXLowerStructArgs() : FunctionPass(ID) {}
48 const char *getPassName() const override {
49 return "Copy structure (byval *) arguments to stack";
50 }
51};
Benjamin Kramerd58792f2015-03-10 20:07:44 +000052} // namespace
Justin Holewinski3d140fc2014-11-05 18:19:30 +000053
54char NVPTXLowerStructArgs::ID = 1;
55
56INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
57 "Lower structure arguments (NVPTX)", false, false)
58
59void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
60 Function *Func = Arg->getParent();
61 Instruction *FirstInst = &(Func->getEntryBlock().front());
Eli Bendersky799c5642014-11-06 17:05:49 +000062 PointerType *PType = dyn_cast<PointerType>(Arg->getType());
Justin Holewinski3d140fc2014-11-05 18:19:30 +000063
64 assert(PType && "Expecting pointer type in handleParam");
65
Eli Bendersky799c5642014-11-06 17:05:49 +000066 Type *StructType = PType->getElementType();
67 AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst);
Justin Holewinski3d140fc2014-11-05 18:19:30 +000068
69 /* Set the alignment to alignment of the byval parameter. This is because,
70 * later load/stores assume that alignment, and we are going to replace
71 * the use of the byval parameter with this alloca instruction.
72 */
73 AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
74
75 Arg->replaceAllUsesWith(AllocA);
76
77 // Get the cvt.gen.to.param intrinsic
Eli Bendersky799c5642014-11-06 17:05:49 +000078 Type *CvtTypes[] = {
79 Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM),
80 Type::getInt8PtrTy(Func->getParent()->getContext(),
81 ADDRESS_SPACE_GENERIC)};
Aaron Ballmane77ffe32014-11-06 14:32:30 +000082 Function *CvtFunc = Intrinsic::getDeclaration(
Eli Bendersky799c5642014-11-06 17:05:49 +000083 Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param, CvtTypes);
84
85 Value *BitcastArgs[] = {
Justin Holewinski3d140fc2014-11-05 18:19:30 +000086 new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(),
87 ADDRESS_SPACE_GENERIC),
Eli Bendersky799c5642014-11-06 17:05:49 +000088 Arg->getName(), FirstInst)};
89 CallInst *CallCVT =
90 CallInst::Create(CvtFunc, BitcastArgs, "cvt_to_param", FirstInst);
Justin Holewinski3d140fc2014-11-05 18:19:30 +000091
Eli Bendersky799c5642014-11-06 17:05:49 +000092 BitCastInst *BitCast = new BitCastInst(
93 CallCVT, PointerType::get(StructType, ADDRESS_SPACE_PARAM),
94 Arg->getName(), FirstInst);
Justin Holewinski3d140fc2014-11-05 18:19:30 +000095 LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst);
96 new StoreInst(LI, AllocA, FirstInst);
97}
98
Eli Bendersky799c5642014-11-06 17:05:49 +000099// =============================================================================
100// If the function had a struct ptr arg, say foo(%struct.x *byval %d), then
101// add the following instructions to the first basic block :
102//
103// %temp = alloca %struct.x, align 8
104// %tt1 = bitcast %struct.x * %d to i8 *
105// %tt2 = llvm.nvvm.cvt.gen.to.param %tt1
106// %tempd = bitcast i8 addrspace(101) * %tt2 to %struct.x addrspace(101) *
107// %tv = load %struct.x addrspace(101) * %tempd
108// store %struct.x %tv, %struct.x * %temp, align 8
109//
110// The above code allocates some space in the stack and copies the incoming
111// struct from param space to local space.
112// Then replace all occurrences of %d by %temp.
113// =============================================================================
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000114void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000115 for (Argument &Arg : F.args()) {
Eli Bendersky799c5642014-11-06 17:05:49 +0000116 if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
117 handleParam(&Arg);
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000118 }
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000119 }
120}
121
Eli Bendersky799c5642014-11-06 17:05:49 +0000122// =============================================================================
123// Main function for this pass.
124// =============================================================================
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000125bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
126 // Skip non-kernels. See the comments at the top of this file.
127 if (!isKernelFunction(F))
128 return false;
129
130 handleStructPtrArgs(F);
Justin Holewinski3d140fc2014-11-05 18:19:30 +0000131 return true;
132}
133
134FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
135 return new NVPTXLowerStructArgs();
136}