blob: b533f316d8a989594b18f9c93f3d4e76dd133584 [file] [log] [blame]
Jingyue Wua2f60272015-06-04 21:28:26 +00001//===-- NVPTXLowerKernelArgs.cpp - Lower kernel arguments -----------------===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Pointer arguments to kernel functions need to be lowered specially.
11//
12// 1. Copy byval struct args to local memory. This is a preparation for handling
13// cases like
14//
15// kernel void foo(struct A arg, ...)
16// {
17// struct A *p = &arg;
18// ...
19// ... = p->field1 ... (there is no generic address for .param)
20// p->field2 = ... (there is no write access to .param)
21// }
22//
23// 2. Convert non-byval pointer arguments of CUDA kernels to pointers in the
24// global address space. This allows later optimizations to emit
25// ld.global.*/st.global.* for accessing these pointer arguments. For
26// example,
27//
28// define void @foo(float* %input) {
29// %v = load float, float* %input, align 4
30// ...
31// }
32//
33// becomes
34//
35// define void @foo(float* %input) {
36// %input2 = addrspacecast float* %input to float addrspace(1)*
37// %input3 = addrspacecast float addrspace(1)* %input2 to float*
38// %v = load float, float* %input3, align 4
39// ...
40// }
41//
42// Later, NVPTXFavorNonGenericAddrSpaces will optimize it to
43//
44// define void @foo(float* %input) {
45// %input2 = addrspacecast float* %input to float addrspace(1)*
46// %v = load float, float addrspace(1)* %input2, align 4
47// ...
48// }
49//
50// TODO: merge this pass with NVPTXFavorNonGenericAddrSpaces so that other passes
51// don't cancel the addrspacecast pair this pass emits.
52//===----------------------------------------------------------------------===//
53
54#include "NVPTX.h"
55#include "NVPTXUtilities.h"
56#include "NVPTXTargetMachine.h"
57#include "llvm/IR/Function.h"
58#include "llvm/IR/Instructions.h"
59#include "llvm/IR/Module.h"
60#include "llvm/IR/Type.h"
61#include "llvm/Pass.h"
62
63using namespace llvm;
64
65namespace llvm {
66void initializeNVPTXLowerKernelArgsPass(PassRegistry &);
67}
68
69namespace {
70class NVPTXLowerKernelArgs : public FunctionPass {
71 bool runOnFunction(Function &F) override;
72
73 // handle byval parameters
74 void handleByValParam(Argument *);
75 // handle non-byval pointer parameters
76 void handlePointerParam(Argument *);
77
78public:
79 static char ID; // Pass identification, replacement for typeid
80 NVPTXLowerKernelArgs(const NVPTXTargetMachine *TM = nullptr)
81 : FunctionPass(ID), TM(TM) {}
82 const char *getPassName() const override {
83 return "Lower pointer arguments of CUDA kernels";
84 }
85
86private:
87 const NVPTXTargetMachine *TM;
88};
89} // namespace
90
91char NVPTXLowerKernelArgs::ID = 1;
92
93INITIALIZE_PASS(NVPTXLowerKernelArgs, "nvptx-lower-kernel-args",
94 "Lower kernel arguments (NVPTX)", false, false)
95
96// =============================================================================
97// If the function had a byval struct ptr arg, say foo(%struct.x *byval %d),
98// then add the following instructions to the first basic block:
99//
100// %temp = alloca %struct.x, align 8
101// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
102// %tv = load %struct.x addrspace(101)* %tempd
103// store %struct.x %tv, %struct.x* %temp, align 8
104//
105// The above code allocates some space in the stack and copies the incoming
106// struct from param space to local space.
107// Then replace all occurences of %d by %temp.
108// =============================================================================
109void NVPTXLowerKernelArgs::handleByValParam(Argument *Arg) {
110 Function *Func = Arg->getParent();
111 Instruction *FirstInst = &(Func->getEntryBlock().front());
112 PointerType *PType = dyn_cast<PointerType>(Arg->getType());
113
114 assert(PType && "Expecting pointer type in handleByValParam");
115
116 Type *StructType = PType->getElementType();
117 AllocaInst *AllocA = new AllocaInst(StructType, Arg->getName(), FirstInst);
118 // Set the alignment to alignment of the byval parameter. This is because,
119 // later load/stores assume that alignment, and we are going to replace
120 // the use of the byval parameter with this alloca instruction.
121 AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1));
122 Arg->replaceAllUsesWith(AllocA);
123
124 Value *ArgInParam = new AddrSpaceCastInst(
125 Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
126 FirstInst);
127 LoadInst *LI = new LoadInst(ArgInParam, Arg->getName(), FirstInst);
128 new StoreInst(LI, AllocA, FirstInst);
129}
130
131void NVPTXLowerKernelArgs::handlePointerParam(Argument *Arg) {
132 assert(!Arg->hasByValAttr() &&
133 "byval params should be handled by handleByValParam");
134
Jingyue Wu32038182015-06-26 22:35:43 +0000135 // Do nothing if the argument already points to the global address space.
136 if (Arg->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
137 return;
138
Jingyue Wua2f60272015-06-04 21:28:26 +0000139 Instruction *FirstInst = Arg->getParent()->getEntryBlock().begin();
140 Instruction *ArgInGlobal = new AddrSpaceCastInst(
141 Arg, PointerType::get(Arg->getType()->getPointerElementType(),
142 ADDRESS_SPACE_GLOBAL),
143 Arg->getName(), FirstInst);
144 Value *ArgInGeneric = new AddrSpaceCastInst(ArgInGlobal, Arg->getType(),
145 Arg->getName(), FirstInst);
146 // Replace with ArgInGeneric all uses of Args except ArgInGlobal.
147 Arg->replaceAllUsesWith(ArgInGeneric);
148 ArgInGlobal->setOperand(0, Arg);
149}
150
151
152// =============================================================================
153// Main function for this pass.
154// =============================================================================
155bool NVPTXLowerKernelArgs::runOnFunction(Function &F) {
156 // Skip non-kernels. See the comments at the top of this file.
157 if (!isKernelFunction(F))
158 return false;
159
160 for (Argument &Arg : F.args()) {
161 if (Arg.getType()->isPointerTy()) {
162 if (Arg.hasByValAttr())
163 handleByValParam(&Arg);
164 else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
165 handlePointerParam(&Arg);
166 }
167 }
168 return true;
169}
170
171FunctionPass *
172llvm::createNVPTXLowerKernelArgsPass(const NVPTXTargetMachine *TM) {
173 return new NVPTXLowerKernelArgs(TM);
174}