blob: ab32c43163b930cd945488f4bc9036875d3efdad [file] [log] [blame]
Yang Nifb40ee22015-10-13 20:34:06 +00001/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_foreach_lowering.h"
18
19#include "clang/AST/ASTContext.h"
20#include "llvm/Support/raw_ostream.h"
21#include "slang_rs_context.h"
22#include "slang_rs_export_foreach.h"
23
24namespace slang {
25
namespace {

// Name of the user-facing kernel launch function whose calls this pass
// rewrites.
constexpr char KERNEL_LAUNCH_FUNCTION_NAME[] = "rsParallelFor";

// Itanium-mangled symbol of the runtime entry point that rsParallelFor calls
// are rewritten to; it demangles to
//   rsForEachInternal(int, rs_allocation, rs_allocation)
constexpr char INTERNAL_LAUNCH_FUNCTION_NAME[] =
    "_Z17rsForEachInternali13rs_allocationS_";

}  // namespace
33
34RSForEachLowering::RSForEachLowering(RSContext* ctxt)
35 : mCtxt(ctxt), mASTCtxt(ctxt->getASTContext()) {}
36
37// Check if the passed-in expr references a kernel function in the following
38// pattern in the AST.
39//
40// ImplicitCastExpr 'void *' <BitCast>
41// `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
42// `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
43const clang::FunctionDecl* RSForEachLowering::matchFunctionDesignator(
44 clang::Expr* expr) {
45 clang::ImplicitCastExpr* ToVoidPtr =
46 clang::dyn_cast<clang::ImplicitCastExpr>(expr);
47 if (ToVoidPtr == nullptr) {
48 return nullptr;
49 }
50
51 clang::ImplicitCastExpr* Decay =
52 clang::dyn_cast<clang::ImplicitCastExpr>(ToVoidPtr->getSubExpr());
53
54 if (Decay == nullptr) {
55 return nullptr;
56 }
57
58 clang::DeclRefExpr* DRE =
59 clang::dyn_cast<clang::DeclRefExpr>(Decay->getSubExpr());
60
61 if (DRE == nullptr) {
62 return nullptr;
63 }
64
65 const clang::FunctionDecl* FD =
66 clang::dyn_cast<clang::FunctionDecl>(DRE->getDecl());
67
68 if (FD == nullptr) {
69 return nullptr;
70 }
71
72 // TODO: Verify the launch has the expected number of input allocations
73
74 return FD;
75}
76
77// Checks if the call expression is a legal rsParallelFor call by looking for the
78// following pattern in the AST. On success, returns the first argument that is
79// a FunctionDecl of a kernel function.
80//
81// CallExpr 'void'
82// |
83// |-ImplicitCastExpr 'void (*)(void *, ...)' <FunctionToPointerDecay>
84// | `-DeclRefExpr 'void (void *, ...)' 'rsParallelFor' 'void (void *, ...)'
85// |
86// |-ImplicitCastExpr 'void *' <BitCast>
87// | `-ImplicitCastExpr 'int (*)(int)' <FunctionToPointerDecay>
88// | `-DeclRefExpr 'int (int)' Function 'foo' 'int (int)'
89// |
90// |-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
91// | `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'in' 'rs_allocation':'rs_allocation'
92// |
93// `-ImplicitCastExpr 'rs_allocation':'rs_allocation' <LValueToRValue>
94// `-DeclRefExpr 'rs_allocation':'rs_allocation' lvalue ParmVar 'out' 'rs_allocation':'rs_allocation'
95const clang::FunctionDecl* RSForEachLowering::matchKernelLaunchCall(
96 clang::CallExpr* CE) {
97 const clang::Decl* D = CE->getCalleeDecl();
98 const clang::FunctionDecl* FD = clang::dyn_cast<clang::FunctionDecl>(D);
99
100 if (FD == nullptr) {
101 return nullptr;
102 }
103
104 const clang::StringRef& funcName = FD->getName();
105
106 if (!funcName.equals(KERNEL_LAUNCH_FUNCTION_NAME)) {
107 return nullptr;
108 }
109
110 const clang::FunctionDecl* kernel = matchFunctionDesignator(CE->getArg(0));
111
112 if (kernel == nullptr ||
113 CE->getNumArgs() < 3) { // TODO: Make argument check more accurate
114 mCtxt->ReportError(CE->getExprLoc(), "Invalid kernel launch call.");
115 }
116
117 return kernel;
118}
119
120// Create an AST node for the declaration of rsForEachInternal
121clang::FunctionDecl* RSForEachLowering::CreateForEachInternalFunctionDecl() {
122 const clang::QualType& AllocTy = mCtxt->getAllocationType();
123 clang::DeclContext* DC = mASTCtxt.getTranslationUnitDecl();
124 clang::SourceLocation Loc;
125
126 llvm::StringRef SR(INTERNAL_LAUNCH_FUNCTION_NAME);
127 clang::IdentifierInfo& II = mASTCtxt.Idents.get(SR);
128 clang::DeclarationName N(&II);
129
130 clang::FunctionProtoType::ExtProtoInfo EPI;
131
132 clang::QualType T = mASTCtxt.getFunctionType(
133 mASTCtxt.VoidTy, // Return type
134 {mASTCtxt.IntTy, AllocTy, AllocTy}, // Argument types
135 EPI);
136
137 clang::FunctionDecl* FD = clang::FunctionDecl::Create(
138 mASTCtxt, DC, Loc, Loc, N, T, nullptr, clang::SC_Extern);
139 return FD;
140}
141
142// Create an expression like the following that references the rsForEachInternal to
143// replace the callee in the original call expression that references rsParallelFor.
144//
145// ImplicitCastExpr 'void (*)(int, rs_allocation, rs_allocation)' <FunctionToPointerDecay>
146// `-DeclRefExpr 'void' Function '_Z17rsForEachInternali13rs_allocationS_' 'void (int, rs_allocation, rs_allocation)'
147clang::Expr* RSForEachLowering::CreateCalleeExprForInternalForEach() {
148 clang::FunctionDecl* FDNew = CreateForEachInternalFunctionDecl();
149
150 clang::DeclRefExpr* refExpr = clang::DeclRefExpr::Create(
151 mASTCtxt, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), FDNew,
152 false, clang::SourceLocation(), mASTCtxt.VoidTy, clang::VK_RValue);
153
154 const clang::QualType FDNewType = FDNew->getType();
155
156 clang::Expr* calleeNew = clang::ImplicitCastExpr::Create(
157 mASTCtxt, mASTCtxt.getPointerType(FDNewType),
158 clang::CK_FunctionToPointerDecay, refExpr, nullptr, clang::VK_RValue);
159
160 return calleeNew;
161}
162
163// This visit method checks (via pattern matching) if the call expression is to
164// rsParallelFor, and the arguments satisfy the restrictions on the
165// rsParallelFor API. If so, replace the call with a rsForEachInternal call
166// with the first argument replaced by the slot number of the kernel function
167// referenced in the original first argument.
168//
169// See comments to the helper methods defined above for details.
170void RSForEachLowering::VisitCallExpr(clang::CallExpr* CE) {
171 const clang::FunctionDecl* kernel = matchKernelLaunchCall(CE);
172 if (kernel == nullptr) {
173 return;
174 }
175
176 clang::Expr* calleeNew = CreateCalleeExprForInternalForEach();
177 CE->setCallee(calleeNew);
178
179 const int slot = mCtxt->getForEachSlotNumber(kernel);
180 const llvm::APInt APIntSlot(mASTCtxt.getTypeSize(mASTCtxt.IntTy), slot);
181 const clang::Expr* arg0 = CE->getArg(0);
182 const clang::SourceLocation Loc(arg0->getLocStart());
183 clang::Expr* IntSlotNum =
184 clang::IntegerLiteral::Create(mASTCtxt, APIntSlot, mASTCtxt.IntTy, Loc);
185 CE->setArg(0, IntSlotNum);
186}
187
188void RSForEachLowering::VisitStmt(clang::Stmt* S) {
189 for (clang::Stmt* Child : S->children()) {
190 if (Child) {
191 Visit(Child);
192 }
193 }
194}
195
196} // namespace slang