blob: 997ced7927504890ff95569a1ca5453ba35efd42 [file] [log] [blame]
/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
Jean-Luc Brouilleta2dd52f2017-02-16 20:57:26 -080017#include "Assert.h"
18#include "Log.h"
19#include "RSUtils.h"
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -070020
21#include <algorithm>
22#include <vector>
23
24#include <llvm/IR/CallSite.h>
25#include <llvm/IR/Type.h>
26#include <llvm/IR/Instructions.h>
27#include <llvm/IR/Module.h>
28#include <llvm/IR/Function.h>
29#include <llvm/Pass.h>
30
31namespace { // anonymous namespace
32
33static const bool kDebug = false;
34
35/* RSX86_64CallConvPass: This pass scans for calls to Renderscript functions in
36 * the CPU reference driver. For such calls, it identifies the
37 * pass-by-reference large-object pointer arguments introduced by the frontend
38 * to conform to the AArch64 calling convention (AAPCS). These pointer
39 * arguments are converted to pass-by-value to match the calling convention of
40 * the CPU reference driver.
41 */
42class RSX86_64CallConvPass: public llvm::ModulePass {
43private:
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -070044 bool IsRSFunctionOfInterest(llvm::Function &F) {
45 // Only Renderscript functions that are not defined locally be considered
46 if (!F.empty()) // defined locally
47 return false;
48
49 // llvm intrinsic or internal function
50 llvm::StringRef FName = F.getName();
51 if (FName.startswith("llvm."))
52 return false;
53
54 // All other functions need to be checked for large-object parameters.
55 // Disallowed (non-Renderscript) functions are detected by a different pass.
56 return true;
57 }
58
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -070059 // Test if this argument needs to be converted to pass-by-value.
60 bool IsDerefNeeded(llvm::Function *F, llvm::Argument &Arg) {
61 unsigned ArgNo = Arg.getArgNo();
62 llvm::Type *ArgTy = Arg.getType();
63
64 // Do not consider arguments with 'sret' attribute. Parameters with this
65 // attribute are actually pointers to structure return values.
66 if (Arg.hasStructRetAttr())
67 return false;
68
69 // Dereference needed only if type is a pointer to a struct
70 if (!ArgTy->isPointerTy() || !ArgTy->getPointerElementType()->isStructTy())
71 return false;
72
73 // Dereference needed only for certain RS struct objects.
74 llvm::Type *StructTy = ArgTy->getPointerElementType();
Stephen Hinesabfa7852015-05-22 19:43:10 -070075 if (!isRsObjectType(StructTy))
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -070076 return false;
77
78 // TODO Find a better way to encode exceptions
79 llvm::StringRef FName = F->getName();
80 // rsSetObject's first parameter is a pointer
81 if (FName.find("rsSetObject") != std::string::npos && ArgNo == 0)
82 return false;
83 // rsClearObject's first parameter is a pointer
84 if (FName.find("rsClearObject") != std::string::npos && ArgNo == 0)
85 return false;
Yang Ni02c61f62015-11-18 16:01:10 -080086 // rsForEachInternal's fifth parameter is a pointer
87 if (FName.find("rsForEachInternal") != std::string::npos && ArgNo == 4)
88 return false;
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -070089
90 return true;
91 }
92
93 // Compute which arguments to this function need be converted to pass-by-value
94 bool FillArgsToDeref(llvm::Function *F, std::vector<unsigned> &ArgNums) {
95 bccAssert(ArgNums.size() == 0);
96
97 for (auto &Arg: F->getArgumentList()) {
98 if (IsDerefNeeded(F, Arg)) {
99 ArgNums.push_back(Arg.getArgNo());
100
101 if (kDebug) {
102 ALOGV("Lowering argument %u for function %s\n", Arg.getArgNo(),
103 F->getName().str().c_str());
104 }
105 }
106 }
107 return ArgNums.size() > 0;
108 }
109
110 llvm::Function *RedefineFn(llvm::Function *OrigFn,
111 std::vector<unsigned> &ArgsToDeref) {
112
113 llvm::FunctionType *FTy = OrigFn->getFunctionType();
114 std::vector<llvm::Type *> Params(FTy->param_begin(), FTy->param_end());
115
116 llvm::FunctionType *NewTy = llvm::FunctionType::get(FTy->getReturnType(),
117 Params,
118 FTy->isVarArg());
119 llvm::Function *NewFn = llvm::Function::Create(NewTy,
120 OrigFn->getLinkage(),
121 OrigFn->getName(),
122 OrigFn->getParent());
123
124 // Add the ByVal attribute to the attribute list corresponding to this
125 // argument. The list at index (i+1) corresponds to the i-th argument. The
126 // list at index 0 corresponds to the return value's attribute.
127 for (auto i: ArgsToDeref) {
128 NewFn->addAttribute(i+1, llvm::Attribute::ByVal);
129 }
130
131 NewFn->copyAttributesFrom(OrigFn);
132 NewFn->takeName(OrigFn);
133
134 for (auto AI=OrigFn->arg_begin(), AE=OrigFn->arg_end(),
135 NAI=NewFn->arg_begin();
136 AI != AE; ++ AI, ++NAI) {
Pirama Arumuga Nainarf229c402016-03-06 23:05:45 -0800137 NAI->takeName(&*AI);
Pirama Arumuga Nainar8c24f8d2015-03-17 13:11:25 -0700138 }
139
140 return NewFn;
141 }
142
143 void ReplaceCallInsn(llvm::CallSite &CS,
144 llvm::Function *NewFn,
145 std::vector<unsigned> &ArgsToDeref) {
146
147 llvm::CallInst *CI = llvm::cast<llvm::CallInst>(CS.getInstruction());
148 std::vector<llvm::Value *> Args(CS.arg_begin(), CS.arg_end());
149 auto NewCI = llvm::CallInst::Create(NewFn, Args, "", CI);
150
151 // Add the ByVal attribute to the attribute list corresponding to this
152 // argument. The list at index (i+1) corresponds to the i-th argument. The
153 // list at index 0 corresponds to the return value's attribute.
154 for (auto i: ArgsToDeref) {
155 NewCI->addAttribute(i+1, llvm::Attribute::ByVal);
156 }
157 if (CI->isTailCall())
158 NewCI->setTailCall();
159
160 if (!CI->getType()->isVoidTy())
161 CI->replaceAllUsesWith(NewCI);
162
163 CI->eraseFromParent();
164 }
165
166public:
167 static char ID;
168
169 RSX86_64CallConvPass()
170 : ModulePass (ID) {
171 }
172
173 virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
174 // This pass does not use any other analysis passes, but it does
175 // modify the existing functions in the module (thus altering the CFG).
176 }
177
178 bool runOnModule(llvm::Module &M) override {
179 // Avoid adding Functions and altering FunctionList while iterating over it
180 // by collecting functions and processing them later.
181 std::vector<llvm::Function *> FunctionsToHandle;
182
183 auto &FunctionList = M.getFunctionList();
184 for (auto &OrigFn: FunctionList) {
185 if (!IsRSFunctionOfInterest(OrigFn))
186 continue;
187 FunctionsToHandle.push_back(&OrigFn);
188 }
189
190 for (auto OrigFn: FunctionsToHandle) {
191 std::vector<unsigned> ArgsToDeref;
192 if (!FillArgsToDeref(OrigFn, ArgsToDeref))
193 continue;
194
195 // Replace all calls to OrigFn and erase it from parent.
196 llvm::Function *NewFn = RedefineFn(OrigFn, ArgsToDeref);
197 while (!OrigFn->use_empty()) {
198 llvm::CallSite CS(OrigFn->user_back());
199 ReplaceCallInsn(CS, NewFn, ArgsToDeref);
200 }
201 OrigFn->eraseFromParent();
202 }
203
204 return FunctionsToHandle.size() > 0;
205 }
206
207};
208
209}
210
211char RSX86_64CallConvPass::ID = 0;
212
213static llvm::RegisterPass<RSX86_64CallConvPass> X("X86-64-calling-conv",
214 "remove AArch64 assumptions from calls in X86-64");
215
216namespace bcc {
217
218llvm::ModulePass *
219createRSX86_64CallConvPass() {
220 return new RSX86_64CallConvPass();
221}
222
223}