blob: 8c7f6e63e1f438d086c3735ab6fc2ca43c63975e [file] [log] [blame]
Artem Belevich7e9c9a62016-07-20 21:44:07 +00001//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
Jingyue Wua2f60272015-06-04 21:28:26 +00002//
Chandler Carruth2946cd72019-01-19 08:50:56 +00003// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Jingyue Wua2f60272015-06-04 21:28:26 +00006//
7//===----------------------------------------------------------------------===//
8//
Jingyue Wua2f60272015-06-04 21:28:26 +00009//
Artem Belevichb2e76a52016-07-20 18:39:47 +000010// Arguments to kernel and device functions are passed via param space,
11// which imposes certain restrictions:
12// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
Jingyue Wua2f60272015-06-04 21:28:26 +000013//
Artem Belevichb2e76a52016-07-20 18:39:47 +000014// Kernel parameters are read-only and accessible only via ld.param
15// instruction, directly or via a pointer. Pointers to kernel
16// arguments can't be converted to generic address space.
Jingyue Wua2f60272015-06-04 21:28:26 +000017//
Artem Belevichb2e76a52016-07-20 18:39:47 +000018// Device function parameters are directly accessible via
19// ld.param/st.param, but taking the address of one returns a pointer
20// to a copy created in local space which *can't* be used with
21// ld.param/st.param.
22//
23// Copying a byval struct into local memory in IR allows us to enforce
24// the param space restrictions, gives the rest of IR a pointer w/o
25// param space restrictions, and gives us an opportunity to eliminate
26// the copy.
27//
28// Pointer arguments to kernel functions need more work to be lowered:
29//
30// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
Jingyue Wua2f60272015-06-04 21:28:26 +000031// global address space. This allows later optimizations to emit
32// ld.global.*/st.global.* for accessing these pointer arguments. For
33// example,
34//
35// define void @foo(float* %input) {
36// %v = load float, float* %input, align 4
37// ...
38// }
39//
40// becomes
41//
42// define void @foo(float* %input) {
43// %input2 = addrspacecast float* %input to float addrspace(1)*
44// %input3 = addrspacecast float addrspace(1)* %input2 to float*
45// %v = load float, float* %input3, align 4
46// ...
47// }
48//
Justin Lebared1e3122016-10-31 21:51:42 +000049// Later, NVPTXInferAddressSpaces will optimize it to
Jingyue Wua2f60272015-06-04 21:28:26 +000050//
51// define void @foo(float* %input) {
52// %input2 = addrspacecast float* %input to float addrspace(1)*
53// %v = load float, float addrspace(1)* %input2, align 4
54// ...
55// }
56//
// 2. Convert pointers in a byval kernel parameter to pointers in the global
//    address space. As #1, it allows NVPTX to emit more ld/st.global. E.g.,
59//
60// struct S {
61// int *x;
62// int *y;
63// };
64// __global__ void foo(S s) {
65// int *b = s.y;
66// // use b
67// }
68//
69// "b" points to the global address space. In the IR level,
70//
71// define void @foo({i32*, i32*}* byval %input) {
72// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
73// %b = load i32*, i32** %b_ptr
74// ; use %b
75// }
76//
77// becomes
78//
79// define void @foo({i32*, i32*}* byval %input) {
80// %b_ptr = getelementptr {i32*, i32*}, {i32*, i32*}* %input, i64 0, i32 1
81// %b = load i32*, i32** %b_ptr
82// %b_global = addrspacecast i32* %b to i32 addrspace(1)*
83// %b_generic = addrspacecast i32 addrspace(1)* %b_global to i32*
84// ; use %b_generic
85// }
86//
Justin Lebared1e3122016-10-31 21:51:42 +000087// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
88// cancel the addrspacecast pair this pass emits.
Jingyue Wua2f60272015-06-04 21:28:26 +000089//===----------------------------------------------------------------------===//
90
91#include "NVPTX.h"
Jingyue Wua2f60272015-06-04 21:28:26 +000092#include "NVPTXTargetMachine.h"
Chandler Carruth6bda14b2017-06-06 11:49:48 +000093#include "NVPTXUtilities.h"
Jingyue Wucf700532015-07-31 21:44:14 +000094#include "llvm/Analysis/ValueTracking.h"
Jingyue Wua2f60272015-06-04 21:28:26 +000095#include "llvm/IR/Function.h"
96#include "llvm/IR/Instructions.h"
97#include "llvm/IR/Module.h"
98#include "llvm/IR/Type.h"
99#include "llvm/Pass.h"
100
101using namespace llvm;
102
103namespace llvm {
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000104void initializeNVPTXLowerArgsPass(PassRegistry &);
Jingyue Wua2f60272015-06-04 21:28:26 +0000105}
106
107namespace {
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000108class NVPTXLowerArgs : public FunctionPass {
Jingyue Wua2f60272015-06-04 21:28:26 +0000109 bool runOnFunction(Function &F) override;
110
Artem Belevichb2e76a52016-07-20 18:39:47 +0000111 bool runOnKernelFunction(Function &F);
112 bool runOnDeviceFunction(Function &F);
113
Jingyue Wua2f60272015-06-04 21:28:26 +0000114 // handle byval parameters
Jingyue Wucf700532015-07-31 21:44:14 +0000115 void handleByValParam(Argument *Arg);
116 // Knowing Ptr must point to the global address space, this function
117 // addrspacecasts Ptr to global and then back to generic. This allows
Justin Lebared1e3122016-10-31 21:51:42 +0000118 // NVPTXInferAddressSpaces to fold the global-to-generic cast into
Jingyue Wucf700532015-07-31 21:44:14 +0000119 // loads/stores that appear later.
120 void markPointerAsGlobal(Value *Ptr);
Jingyue Wua2f60272015-06-04 21:28:26 +0000121
122public:
123 static char ID; // Pass identification, replacement for typeid
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000124 NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
Jingyue Wua2f60272015-06-04 21:28:26 +0000125 : FunctionPass(ID), TM(TM) {}
Mehdi Amini117296c2016-10-01 02:56:57 +0000126 StringRef getPassName() const override {
Jingyue Wua2f60272015-06-04 21:28:26 +0000127 return "Lower pointer arguments of CUDA kernels";
128 }
129
130private:
131 const NVPTXTargetMachine *TM;
132};
133} // namespace
134
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000135char NVPTXLowerArgs::ID = 1;
Jingyue Wua2f60272015-06-04 21:28:26 +0000136
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000137INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
138 "Lower arguments (NVPTX)", false, false)
Jingyue Wua2f60272015-06-04 21:28:26 +0000139
140// =============================================================================
Manuel Jacob45cc9bb2016-01-23 05:47:34 +0000141// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
Jingyue Wua2f60272015-06-04 21:28:26 +0000142// then add the following instructions to the first basic block:
143//
144// %temp = alloca %struct.x, align 8
145// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
146// %tv = load %struct.x addrspace(101)* %tempd
147// store %struct.x %tv, %struct.x* %temp, align 8
148//
149// The above code allocates some space in the stack and copies the incoming
150// struct from param space to local space.
Benjamin Kramerdf005cb2015-08-08 18:27:36 +0000151// Then replace all occurrences of %d by %temp.
Jingyue Wua2f60272015-06-04 21:28:26 +0000152// =============================================================================
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000153void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
Jingyue Wua2f60272015-06-04 21:28:26 +0000154 Function *Func = Arg->getParent();
155 Instruction *FirstInst = &(Func->getEntryBlock().front());
156 PointerType *PType = dyn_cast<PointerType>(Arg->getType());
157
158 assert(PType && "Expecting pointer type in handleByValParam");
159
160 Type *StructType = PType->getElementType();
Matt Arsenault3c1fc762017-04-10 22:27:50 +0000161 unsigned AS = Func->getParent()->getDataLayout().getAllocaAddrSpace();
162 AllocaInst *AllocA = new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
Jingyue Wua2f60272015-06-04 21:28:26 +0000163 // Set the alignment to alignment of the byval parameter. This is because,
164 // later load/stores assume that alignment, and we are going to replace
165 // the use of the byval parameter with this alloca instruction.
Reid Kleckner859f8b52017-04-28 20:34:27 +0000166 AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo()));
Jingyue Wua2f60272015-06-04 21:28:26 +0000167 Arg->replaceAllUsesWith(AllocA);
168
169 Value *ArgInParam = new AddrSpaceCastInst(
170 Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
171 FirstInst);
James Y Knight14359ef2019-02-01 20:44:24 +0000172 LoadInst *LI =
173 new LoadInst(StructType, ArgInParam, Arg->getName(), FirstInst);
Jingyue Wua2f60272015-06-04 21:28:26 +0000174 new StoreInst(LI, AllocA, FirstInst);
175}
176
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000177void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
Jingyue Wucf700532015-07-31 21:44:14 +0000178 if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
Jingyue Wu32038182015-06-26 22:35:43 +0000179 return;
180
Jingyue Wucf700532015-07-31 21:44:14 +0000181 // Deciding where to emit the addrspacecast pair.
182 BasicBlock::iterator InsertPt;
183 if (Argument *Arg = dyn_cast<Argument>(Ptr)) {
184 // Insert at the functon entry if Ptr is an argument.
185 InsertPt = Arg->getParent()->getEntryBlock().begin();
186 } else {
187 // Insert right after Ptr if Ptr is an instruction.
Duncan P. N. Exon Smith61149b82015-10-20 00:54:09 +0000188 InsertPt = ++cast<Instruction>(Ptr)->getIterator();
Jingyue Wucf700532015-07-31 21:44:14 +0000189 assert(InsertPt != InsertPt->getParent()->end() &&
190 "We don't call this function with Ptr being a terminator.");
191 }
Jingyue Wua2f60272015-06-04 21:28:26 +0000192
Jingyue Wucf700532015-07-31 21:44:14 +0000193 Instruction *PtrInGlobal = new AddrSpaceCastInst(
194 Ptr, PointerType::get(Ptr->getType()->getPointerElementType(),
195 ADDRESS_SPACE_GLOBAL),
Duncan P. N. Exon Smith61149b82015-10-20 00:54:09 +0000196 Ptr->getName(), &*InsertPt);
Jingyue Wucf700532015-07-31 21:44:14 +0000197 Value *PtrInGeneric = new AddrSpaceCastInst(PtrInGlobal, Ptr->getType(),
Duncan P. N. Exon Smith61149b82015-10-20 00:54:09 +0000198 Ptr->getName(), &*InsertPt);
Jingyue Wucf700532015-07-31 21:44:14 +0000199 // Replace with PtrInGeneric all uses of Ptr except PtrInGlobal.
200 Ptr->replaceAllUsesWith(PtrInGeneric);
201 PtrInGlobal->setOperand(0, Ptr);
202}
Jingyue Wua2f60272015-06-04 21:28:26 +0000203
204// =============================================================================
205// Main function for this pass.
206// =============================================================================
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000207bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
Jingyue Wucf700532015-07-31 21:44:14 +0000208 if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
209 // Mark pointers in byval structs as global.
210 for (auto &B : F) {
211 for (auto &I : B) {
212 if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
213 if (LI->getType()->isPointerTy()) {
214 Value *UO = GetUnderlyingObject(LI->getPointerOperand(),
215 F.getParent()->getDataLayout());
216 if (Argument *Arg = dyn_cast<Argument>(UO)) {
217 if (Arg->hasByValAttr()) {
218 // LI is a load from a pointer within a byval kernel parameter.
219 markPointerAsGlobal(LI);
220 }
221 }
222 }
223 }
224 }
225 }
226 }
227
Jingyue Wua2f60272015-06-04 21:28:26 +0000228 for (Argument &Arg : F.args()) {
229 if (Arg.getType()->isPointerTy()) {
230 if (Arg.hasByValAttr())
231 handleByValParam(&Arg);
232 else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
Jingyue Wucf700532015-07-31 21:44:14 +0000233 markPointerAsGlobal(&Arg);
Jingyue Wua2f60272015-06-04 21:28:26 +0000234 }
235 }
236 return true;
237}
238
Artem Belevichb2e76a52016-07-20 18:39:47 +0000239// Device functions only need to copy byval args into local memory.
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000240bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
Artem Belevichb2e76a52016-07-20 18:39:47 +0000241 for (Argument &Arg : F.args())
242 if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
243 handleByValParam(&Arg);
244 return true;
245}
246
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000247bool NVPTXLowerArgs::runOnFunction(Function &F) {
Artem Belevichb2e76a52016-07-20 18:39:47 +0000248 return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
249}
250
Jingyue Wua2f60272015-06-04 21:28:26 +0000251FunctionPass *
Artem Belevich7e9c9a62016-07-20 21:44:07 +0000252llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
253 return new NVPTXLowerArgs(TM);
Jingyue Wua2f60272015-06-04 21:28:26 +0000254}