blob: b7fc65401fc48594c3830bb863d7d790b0094a3a [file] [log] [blame]
Dan Gohman1b637452017-01-07 00:34:54 +00001//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
2//
Chandler Carruth2946cd72019-01-19 08:50:56 +00003// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Dan Gohman1b637452017-01-07 00:34:54 +00006//
7//===----------------------------------------------------------------------===//
8///
9/// \file
Adrian Prantl5f8f34e42018-05-01 15:54:18 +000010/// Fix bitcasted functions.
Dan Gohman1b637452017-01-07 00:34:54 +000011///
12/// WebAssembly requires caller and callee signatures to match, however in LLVM,
13/// some amount of slop is vaguely permitted. Detect mismatch by looking for
14/// bitcasts of functions and rewrite them to use wrapper functions instead.
15///
16/// This doesn't catch all cases, such as when a function's address is taken in
17/// one place and casted in another, but it works for many common cases.
18///
19/// Note that LLVM already optimizes away function bitcasts in common cases by
20/// dropping arguments as needed, so this pass only ends up getting used in less
21/// common cases.
22///
23//===----------------------------------------------------------------------===//
24
25#include "WebAssembly.h"
Jacob Gravelle37af00e2017-10-10 16:20:18 +000026#include "llvm/IR/CallSite.h"
Dan Gohman1b637452017-01-07 00:34:54 +000027#include "llvm/IR/Constants.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/IR/Module.h"
30#include "llvm/IR/Operator.h"
31#include "llvm/Pass.h"
32#include "llvm/Support/Debug.h"
33#include "llvm/Support/raw_ostream.h"
34using namespace llvm;
35
36#define DEBUG_TYPE "wasm-fix-function-bitcasts"
37
38namespace {
39class FixFunctionBitcasts final : public ModulePass {
40 StringRef getPassName() const override {
41 return "WebAssembly Fix Function Bitcasts";
42 }
43
44 void getAnalysisUsage(AnalysisUsage &AU) const override {
45 AU.setPreservesCFG();
46 ModulePass::getAnalysisUsage(AU);
47 }
48
49 bool runOnModule(Module &M) override;
50
Dan Gohman1b637452017-01-07 00:34:54 +000051public:
52 static char ID;
53 FixFunctionBitcasts() : ModulePass(ID) {}
54};
55} // End anonymous namespace
56
57char FixFunctionBitcasts::ID = 0;
Jacob Gravelle40926452018-03-30 20:36:58 +000058INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
59 "Fix mismatching bitcasts for WebAssembly", false, false)
60
Dan Gohman1b637452017-01-07 00:34:54 +000061ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
62 return new FixFunctionBitcasts();
63}
64
65// Recursively descend the def-use lists from V to find non-bitcast users of
66// bitcasts of V.
Heejin Ahn18c56a02019-02-04 19:13:39 +000067static void findUses(Value *V, Function &F,
Derek Schuff7acb42a2017-01-10 21:59:53 +000068 SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
69 SmallPtrSetImpl<Constant *> &ConstantBCs) {
Dan Gohman1b637452017-01-07 00:34:54 +000070 for (Use &U : V->uses()) {
Heejin Ahn18c56a02019-02-04 19:13:39 +000071 if (auto *BC = dyn_cast<BitCastOperator>(U.getUser()))
72 findUses(BC, F, Uses, ConstantBCs);
Derek Schuff7acb42a2017-01-10 21:59:53 +000073 else if (U.get()->getType() != F.getType()) {
Jacob Gravelle37af00e2017-10-10 16:20:18 +000074 CallSite CS(U.getUser());
75 if (!CS)
76 // Skip uses that aren't immediately called
77 continue;
78 Value *Callee = CS.getCalledValue();
79 if (Callee != V)
80 // Skip calls where the function isn't the callee
81 continue;
Derek Schuff7acb42a2017-01-10 21:59:53 +000082 if (isa<Constant>(U.get())) {
83 // Only add constant bitcasts to the list once; they get RAUW'd
Heejin Ahn18c56a02019-02-04 19:13:39 +000084 auto C = ConstantBCs.insert(cast<Constant>(U.get()));
85 if (!C.second)
Jacob Gravelle37af00e2017-10-10 16:20:18 +000086 continue;
Derek Schuff7acb42a2017-01-10 21:59:53 +000087 }
Dan Gohman1b637452017-01-07 00:34:54 +000088 Uses.push_back(std::make_pair(&U, &F));
Derek Schuff7acb42a2017-01-10 21:59:53 +000089 }
Dan Gohman1b637452017-01-07 00:34:54 +000090 }
91}
92
93// Create a wrapper function with type Ty that calls F (which may have a
94// different type). Attempt to support common bitcasted function idioms:
95// - Call with more arguments than needed: arguments are dropped
96// - Call with fewer arguments than needed: arguments are filled in with undef
97// - Return value is not needed: drop it
98// - Return value needed but not present: supply an undef
Dan Gohmand37dc2f2017-02-27 22:41:39 +000099//
Sam Clegg41d70472018-08-02 17:38:06 +0000100// If the all the argument types of trivially castable to one another (i.e.
101// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
102// instead).
103//
Sam Clegg88599bf2018-08-30 01:01:30 +0000104// If there is a type mismatch that we know would result in an invalid wasm
Heejin Ahnf208f632018-09-05 01:27:38 +0000105// module then generate wrapper that contains unreachable (i.e. abort at
Sam Clegg88599bf2018-08-30 01:01:30 +0000106// runtime). Such programs are deep into undefined behaviour territory,
Sam Clegg41d70472018-08-02 17:38:06 +0000107// but we choose to fail at runtime rather than generate and invalid module
108// or fail at compiler time. The reason we delay the error is that we want
109// to support the CMake which expects to be able to compile and link programs
110// that refer to functions with entirely incorrect signatures (this is how
111// CMake detects the existence of a function in a toolchain).
Sam Clegg88599bf2018-08-30 01:01:30 +0000112//
113// For bitcasts that involve struct types we don't know at this stage if they
Heejin Ahnf208f632018-09-05 01:27:38 +0000114// would be equivalent at the wasm level and so we can't know if we need to
Sam Clegg88599bf2018-08-30 01:01:30 +0000115// generate a wrapper.
Heejin Ahn18c56a02019-02-04 19:13:39 +0000116static Function *createWrapper(Function *F, FunctionType *Ty) {
Dan Gohman1b637452017-01-07 00:34:54 +0000117 Module *M = F->getParent();
118
Sam Clegg41d70472018-08-02 17:38:06 +0000119 Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
120 F->getName() + "_bitcast", M);
Dan Gohman1b637452017-01-07 00:34:54 +0000121 BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
Sam Clegg41d70472018-08-02 17:38:06 +0000122 const DataLayout &DL = BB->getModule()->getDataLayout();
Dan Gohman1b637452017-01-07 00:34:54 +0000123
124 // Determine what arguments to pass.
125 SmallVector<Value *, 4> Args;
126 Function::arg_iterator AI = Wrapper->arg_begin();
Dan Gohman2803bfa2017-11-28 17:15:03 +0000127 Function::arg_iterator AE = Wrapper->arg_end();
Dan Gohman1b637452017-01-07 00:34:54 +0000128 FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
129 FunctionType::param_iterator PE = F->getFunctionType()->param_end();
Sam Clegg41d70472018-08-02 17:38:06 +0000130 bool TypeMismatch = false;
131 bool WrapperNeeded = false;
132
Sam Clegg88599bf2018-08-30 01:01:30 +0000133 Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
134 Type *RtnType = Ty->getReturnType();
135
Sam Clegg41d70472018-08-02 17:38:06 +0000136 if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
Sam Clegg88599bf2018-08-30 01:01:30 +0000137 (F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
138 (ExpectedRtnType != RtnType))
Sam Clegg41d70472018-08-02 17:38:06 +0000139 WrapperNeeded = true;
140
Dan Gohman2803bfa2017-11-28 17:15:03 +0000141 for (; AI != AE && PI != PE; ++AI, ++PI) {
Sam Clegg41d70472018-08-02 17:38:06 +0000142 Type *ArgType = AI->getType();
143 Type *ParamType = *PI;
144
145 if (ArgType == ParamType) {
Dan Gohman2803bfa2017-11-28 17:15:03 +0000146 Args.push_back(&*AI);
Sam Clegg41d70472018-08-02 17:38:06 +0000147 } else {
148 if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
149 Instruction *PtrCast =
150 CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
151 BB->getInstList().push_back(PtrCast);
152 Args.push_back(PtrCast);
Sam Clegg88599bf2018-08-30 01:01:30 +0000153 } else if (ArgType->isStructTy() || ParamType->isStructTy()) {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000154 LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
Sam Clegg88599bf2018-08-30 01:01:30 +0000155 << F->getName() << "\n");
156 WrapperNeeded = false;
Sam Clegg41d70472018-08-02 17:38:06 +0000157 } else {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000158 LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
Sam Clegg41d70472018-08-02 17:38:06 +0000159 << F->getName() << "\n");
160 LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
161 << *ParamType << " Got: " << *ArgType << "\n");
162 TypeMismatch = true;
163 break;
164 }
165 }
166 }
Dan Gohman1b637452017-01-07 00:34:54 +0000167
Sam Clegg88599bf2018-08-30 01:01:30 +0000168 if (WrapperNeeded && !TypeMismatch) {
Sam Clegg41d70472018-08-02 17:38:06 +0000169 for (; PI != PE; ++PI)
170 Args.push_back(UndefValue::get(*PI));
171 if (F->isVarArg())
172 for (; AI != AE; ++AI)
173 Args.push_back(&*AI);
Dan Gohman1b637452017-01-07 00:34:54 +0000174
Sam Clegg41d70472018-08-02 17:38:06 +0000175 CallInst *Call = CallInst::Create(F, Args, "", BB);
176
177 Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
178 Type *RtnType = Ty->getReturnType();
179 // Determine what value to return.
180 if (RtnType->isVoidTy()) {
181 ReturnInst::Create(M->getContext(), BB);
Sam Clegg41d70472018-08-02 17:38:06 +0000182 } else if (ExpectedRtnType->isVoidTy()) {
Sam Clegg88599bf2018-08-30 01:01:30 +0000183 LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
Sam Clegg41d70472018-08-02 17:38:06 +0000184 ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
Sam Clegg41d70472018-08-02 17:38:06 +0000185 } else if (RtnType == ExpectedRtnType) {
186 ReturnInst::Create(M->getContext(), Call, BB);
187 } else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
188 DL)) {
189 Instruction *Cast =
190 CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
191 BB->getInstList().push_back(Cast);
192 ReturnInst::Create(M->getContext(), Cast, BB);
Sam Clegg88599bf2018-08-30 01:01:30 +0000193 } else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000194 LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
Sam Clegg88599bf2018-08-30 01:01:30 +0000195 << F->getName() << "\n");
196 WrapperNeeded = false;
Sam Clegg41d70472018-08-02 17:38:06 +0000197 } else {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000198 LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
Sam Clegg41d70472018-08-02 17:38:06 +0000199 << F->getName() << "\n");
200 LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
201 << " Got: " << *RtnType << "\n");
202 TypeMismatch = true;
203 }
204 }
205
206 if (TypeMismatch) {
Sam Clegg88599bf2018-08-30 01:01:30 +0000207 // Create a new wrapper that simply contains `unreachable`.
208 Wrapper->eraseFromParent();
Heejin Ahnf208f632018-09-05 01:27:38 +0000209 Wrapper = Function::Create(Ty, Function::PrivateLinkage,
210 F->getName() + "_bitcast_invalid", M);
Sam Clegg88599bf2018-08-30 01:01:30 +0000211 BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
Sam Clegg41d70472018-08-02 17:38:06 +0000212 new UnreachableInst(M->getContext(), BB);
213 Wrapper->setName(F->getName() + "_bitcast_invalid");
214 } else if (!WrapperNeeded) {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000215 LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
Sam Clegg41d70472018-08-02 17:38:06 +0000216 << "\n");
Dan Gohman0e2ceb82017-01-07 01:50:01 +0000217 Wrapper->eraseFromParent();
218 return nullptr;
219 }
Heejin Ahn18c56a02019-02-04 19:13:39 +0000220 LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
Dan Gohman1b637452017-01-07 00:34:54 +0000221 return Wrapper;
222}
223
Dan Gohman4684f822019-01-29 10:53:42 +0000224// Test whether a main function with type FuncTy should be rewritten to have
225// type MainTy.
Benjamin Kramer711950c2019-02-11 15:16:21 +0000226static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
Dan Gohman4684f822019-01-29 10:53:42 +0000227 // Only fix the main function if it's the standard zero-arg form. That way,
228 // the standard cases will work as expected, and users will see signature
229 // mismatches from the linker for non-standard cases.
230 return FuncTy->getReturnType() == MainTy->getReturnType() &&
231 FuncTy->getNumParams() == 0 &&
232 !FuncTy->isVarArg();
233}
234
Dan Gohman1b637452017-01-07 00:34:54 +0000235bool FixFunctionBitcasts::runOnModule(Module &M) {
Heejin Ahn569f0902019-01-09 23:05:21 +0000236 LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");
237
Dan Gohman6736f592017-12-08 21:18:21 +0000238 Function *Main = nullptr;
239 CallInst *CallMain = nullptr;
Dan Gohmand5eda352017-01-07 01:31:18 +0000240 SmallVector<std::pair<Use *, Function *>, 0> Uses;
Derek Schuff7acb42a2017-01-10 21:59:53 +0000241 SmallPtrSet<Constant *, 2> ConstantBCs;
Dan Gohmand5eda352017-01-07 01:31:18 +0000242
Dan Gohman1b637452017-01-07 00:34:54 +0000243 // Collect all the places that need wrappers.
Dan Gohman6736f592017-12-08 21:18:21 +0000244 for (Function &F : M) {
Heejin Ahn18c56a02019-02-04 19:13:39 +0000245 findUses(&F, F, Uses, ConstantBCs);
Dan Gohman6736f592017-12-08 21:18:21 +0000246
247 // If we have a "main" function, and its type isn't
248 // "int main(int argc, char *argv[])", create an artificial call with it
249 // bitcasted to that type so that we generate a wrapper for it, so that
250 // the C runtime can call it.
Dan Gohman4684f822019-01-29 10:53:42 +0000251 if (F.getName() == "main") {
Dan Gohman6736f592017-12-08 21:18:21 +0000252 Main = &F;
253 LLVMContext &C = M.getContext();
Sam Clegg79c054f2018-09-13 17:13:10 +0000254 Type *MainArgTys[] = {Type::getInt32Ty(C),
255 PointerType::get(Type::getInt8PtrTy(C), 0)};
Dan Gohman6736f592017-12-08 21:18:21 +0000256 FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
257 /*isVarArg=*/false);
Heejin Ahn18c56a02019-02-04 19:13:39 +0000258 if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {
Sam Clegg79c054f2018-09-13 17:13:10 +0000259 LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
260 << *F.getFunctionType() << "\n");
Heejin Ahnf208f632018-09-05 01:27:38 +0000261 Value *Args[] = {UndefValue::get(MainArgTys[0]),
262 UndefValue::get(MainArgTys[1])};
263 Value *Casted =
264 ConstantExpr::getBitCast(Main, PointerType::get(MainTy, 0));
James Y Knight7976eb52019-02-01 20:43:25 +0000265 CallMain = CallInst::Create(MainTy, Casted, Args, "call_main");
Dan Gohman6736f592017-12-08 21:18:21 +0000266 Use *UseMain = &CallMain->getOperandUse(2);
267 Uses.push_back(std::make_pair(UseMain, &F));
268 }
269 }
270 }
Dan Gohman1b637452017-01-07 00:34:54 +0000271
272 DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
273
274 for (auto &UseFunc : Uses) {
275 Use *U = UseFunc.first;
276 Function *F = UseFunc.second;
Heejin Ahn18c56a02019-02-04 19:13:39 +0000277 auto *PTy = cast<PointerType>(U->get()->getType());
278 auto *Ty = dyn_cast<FunctionType>(PTy->getElementType());
Dan Gohman1b637452017-01-07 00:34:54 +0000279
280 // If the function is casted to something like i8* as a "generic pointer"
281 // to be later casted to something else, we can't generate a wrapper for it.
282 // Just ignore such casts for now.
283 if (!Ty)
284 continue;
285
286 auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));
287 if (Pair.second)
Heejin Ahn18c56a02019-02-04 19:13:39 +0000288 Pair.first->second = createWrapper(F, Ty);
Dan Gohman1b637452017-01-07 00:34:54 +0000289
Dan Gohman0e2ceb82017-01-07 01:50:01 +0000290 Function *Wrapper = Pair.first->second;
291 if (!Wrapper)
292 continue;
293
Dan Gohman1b637452017-01-07 00:34:54 +0000294 if (isa<Constant>(U->get()))
Dan Gohman0e2ceb82017-01-07 01:50:01 +0000295 U->get()->replaceAllUsesWith(Wrapper);
Dan Gohman1b637452017-01-07 00:34:54 +0000296 else
Dan Gohman0e2ceb82017-01-07 01:50:01 +0000297 U->set(Wrapper);
Dan Gohman1b637452017-01-07 00:34:54 +0000298 }
299
Dan Gohman6736f592017-12-08 21:18:21 +0000300 // If we created a wrapper for main, rename the wrapper so that it's the
301 // one that gets called from startup.
302 if (CallMain) {
303 Main->setName("__original_main");
Heejin Ahn18c56a02019-02-04 19:13:39 +0000304 auto *MainWrapper =
Dan Gohman6736f592017-12-08 21:18:21 +0000305 cast<Function>(CallMain->getCalledValue()->stripPointerCasts());
Dan Gohman6736f592017-12-08 21:18:21 +0000306 delete CallMain;
Dan Gohman4684f822019-01-29 10:53:42 +0000307 if (Main->isDeclaration()) {
308 // The wrapper is not needed in this case as we don't need to export
309 // it to anyone else.
310 MainWrapper->eraseFromParent();
311 } else {
312 // Otherwise give the wrapper the same linkage as the original main
313 // function, so that it can be called from the same places.
314 MainWrapper->setName("main");
315 MainWrapper->setLinkage(Main->getLinkage());
316 MainWrapper->setVisibility(Main->getVisibility());
317 }
Dan Gohman6736f592017-12-08 21:18:21 +0000318 }
319
Dan Gohman1b637452017-01-07 00:34:54 +0000320 return true;
321}