blob: 99d5e82e9c37848151701f6722381e0c41cc6751 [file] [log] [blame]
Yang Ni6749f542016-11-07 20:20:49 -08001/*
2 * Copyright 2016, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "RSSPIRVWriter.h"
18
19#include "SPIRVModule.h"
20#include "bcinfo/MetadataExtractor.h"
21
22#include "llvm/ADT/StringMap.h"
23#include "llvm/ADT/Triple.h"
24#include "llvm/IR/LegacyPassManager.h"
25#include "llvm/IR/Module.h"
26#include "llvm/Support/CommandLine.h"
27#include "llvm/Support/Debug.h"
28#include "llvm/Support/SPIRV.h"
29#include "llvm/Support/raw_ostream.h"
30#include "llvm/Transforms/IPO.h"
31
32#include "GlobalMergePass.h"
Yang Niec9360e2016-11-21 15:50:27 -080033#include "InlinePreparationPass.h"
Yang Ni6749f542016-11-07 20:20:49 -080034#include "LinkerModule.h"
35#include "ReflectionPass.h"
Yang Ni6749f542016-11-07 20:20:49 -080036
37#include <fstream>
38#include <sstream>
39
40#define DEBUG_TYPE "rs2spirv-writer"
41
42using namespace llvm;
43using namespace SPIRV;
44
45namespace llvm {
46FunctionPass *createPromoteMemoryToRegisterPass();
47}
48
49namespace rs2spirv {
50
51static cl::opt<std::string> WrapperOutputFile("wo",
52 cl::desc("Wrapper output file"),
53 cl::value_desc("filename.spt"));
54
55static bool FixMain(LinkerModule &LM, MainFunBlock &MB, StringRef KernelName);
56static bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB);
57static bool FuseTypesAndConstants(LinkerModule &LM);
58static bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L);
59static bool FixVectorShuffles(MainFunBlock &MB);
60static void FixModuleStorageClass(LinkerModule &M);
61
62static void HandleTargetTriple(Module &M) {
63 Triple TT(M.getTargetTriple());
64 auto Arch = TT.getArch();
65
66 StringRef NewTriple;
67 switch (Arch) {
68 default:
69 llvm_unreachable("Unrecognized architecture");
70 break;
71 case Triple::arm:
72 NewTriple = "spir-unknown-unknown";
73 break;
74 case Triple::aarch64:
75 NewTriple = "spir64-unknown-unknown";
76 break;
77 case Triple::spir:
78 case Triple::spir64:
79 DEBUG(dbgs() << "!!! Already a spir triple !!!\n");
80 }
81
82 DEBUG(dbgs() << "New triple:\t" << NewTriple << "\n");
83 M.setTargetTriple(NewTriple);
84}
85
86void addPassesForRS2SPIRV(llvm::legacy::PassManager &PassMgr) {
87 PassMgr.add(createGlobalMergePass());
88 PassMgr.add(createPromoteMemoryToRegisterPass());
89 PassMgr.add(createTransOCLMD());
90 // TODO: investigate removal of OCLTypeToSPIRV pass.
91 PassMgr.add(createOCLTypeToSPIRV());
92 PassMgr.add(createSPIRVRegularizeLLVM());
93 PassMgr.add(createSPIRVLowerConstExpr());
94 PassMgr.add(createSPIRVLowerBool());
95 PassMgr.add(createAlwaysInlinerPass());
96}
97
98bool WriteSPIRV(Module *M, llvm::raw_ostream &OS, std::string &ErrMsg) {
99 std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule());
100
101 HandleTargetTriple(*M);
102
103 bcinfo::MetadataExtractor ME(M);
104 if (!ME.extract()) {
105 errs() << "Could not extract metadata\n";
106 return false;
107 }
108 DEBUG(dbgs() << "Metadata extracted\n");
109
110 llvm::legacy::PassManager PassMgr;
111 PassMgr.add(createInlinePreparationPass(ME));
112 addPassesForRS2SPIRV(PassMgr);
113
114 std::ofstream WrapperF;
115 if (!WrapperOutputFile.empty()) {
116 WrapperF.open(WrapperOutputFile, std::ios::trunc);
117 if (!WrapperF.good()) {
118 errs() << "Could not create/open file:\t" << WrapperOutputFile << "\n";
119 return false;
120 }
121 DEBUG(dbgs() << "Wrapper output:\t" << WrapperOutputFile << "\n");
122 PassMgr.add(createReflectionPass(WrapperF, ME));
123 }
124
125 PassMgr.add(createLLVMToSPIRV(BM.get()));
126 PassMgr.run(*M);
127 DEBUG(M->dump());
128
129 if (BM->getError(ErrMsg) != SPIRVEC_Success)
130 return false;
131
132 OS << *BM;
133
134 return true;
135}
136
137bool Link(llvm::StringRef KernelFilename, llvm::StringRef WrapperFilename,
138 llvm::StringRef OutputFilename) {
139 DEBUG(dbgs() << "Linking...\n");
140
141 std::ifstream WrapperF(WrapperFilename);
142 if (!WrapperF.good()) {
143 errs() << "Cannot open file: " << WrapperFilename << "\n";
144 }
145 std::ifstream KernelF(KernelFilename);
146 if (!KernelF.good()) {
147 errs() << "Cannot open file: " << KernelFilename << "\n";
148 }
149
150 LinkerModule WrapperM(WrapperF);
151 LinkerModule KernelM(KernelF);
152
153 WrapperF.close();
154 KernelF.close();
155
156 DEBUG(dbgs() << "WrapperF:\n");
157 DEBUG(WrapperM.dump());
158 DEBUG(dbgs() << "\n~~~~~~~~~~~~~~~~~~~~~~\n\nKernelF:\n");
159 DEBUG(KernelM.dump());
160 DEBUG(dbgs() << "\n======================\n\n");
161
162 const char *const Prefix = "%rs_linker_";
163
164 for (auto *LPtr : KernelM.lines()) {
165 assert(LPtr);
166 auto &L = *LPtr;
167 size_t Pos = 0;
168 while ((Pos = L.str().find("%", Pos)) != std::string::npos) {
169 L.str().replace(Pos, 1, Prefix);
170 Pos += strlen(Prefix);
171 }
172 }
173
174 FixModuleStorageClass(KernelM);
175 DEBUG(KernelM.dump());
176
177 auto WBlocks = WrapperM.blocks();
178 auto WIt = WBlocks.begin();
179 const auto WEnd = WBlocks.end();
180
181 auto KBlocks = KernelM.blocks();
182 auto KIt = KBlocks.begin();
183 const auto KEnd = KBlocks.end();
184
185 LinkerModule OutM;
186
187 if (WIt == WEnd || KIt == KEnd)
188 return false;
189
190 const auto *HeaderB = dyn_cast<HeaderBlock>(WIt->get());
191 if (!HeaderB || !isa<HeaderBlock>(KIt->get()))
192 return false;
193
194 SmallVector<StringRef, 2> KernelNames;
195 const bool KernelsFound = HeaderB->getRSKernelNames(KernelNames);
196
197 if (!KernelsFound) {
198 errs() << "RS kernel names not found in wrapper\n";
199 return false;
200 }
201
202 // TODO: Support more than one kernel.
203 if (KernelNames.size() != 1) {
204 errs() << "Unsupported number of kernels: " << KernelNames.size() << '\n';
205 return false;
206 }
207
208 const std::string KernelName =
209 Prefix + KernelNames.front().drop_front().str();
210 DEBUG(dbgs() << "Kernel name: " << KernelName << '\n');
211
212 // Kernel's HeaderBlock is skipped - it has OpenCL-specific code that
213 // is replaced here with compute shader code.
214
215 OutM.addBlock<HeaderBlock>(*HeaderB);
216
217 if (++WIt == WEnd || ++KIt == KEnd)
218 return false;
219
220 const auto *DecorBW = dyn_cast<DecorBlock>(WIt->get());
221 if (!DecorBW || !isa<DecorBlock>(KIt->get()))
222 return false;
223
224 // Kernel's DecorBlock is skipped, because it contains OpenCL-specific code
225 // that is not needed (eg. linkage type information).
226
227 OutM.addBlock<DecorBlock>(*DecorBW);
228
229 if (++WIt == WEnd || ++KIt == KEnd)
230 return false;
231
232 const auto *TypeAndConstBW = dyn_cast<TypeAndConstBlock>(WIt->get());
233 auto *TypeAndConstBK = dyn_cast<TypeAndConstBlock>(KIt->get());
234 if (!TypeAndConstBW || !TypeAndConstBK)
235 return false;
236
237 OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBW);
238 OutM.addBlock<TypeAndConstBlock>(*TypeAndConstBK);
239
240 if (++WIt == WEnd || ++KIt == KEnd)
241 return false;
242
243 const auto *VarBW = dyn_cast<VarBlock>(WIt->get());
244 auto *VarBK = dyn_cast<VarBlock>(KIt->get());
245 if (!VarBW)
246 return false;
247
248 OutM.addBlock<VarBlock>(*VarBW);
249
250 if (VarBK)
251 OutM.addBlock<VarBlock>(*VarBK);
252 else
253 --KIt;
254
255 MainFunBlock *MainB = nullptr;
256
257 while (++WIt != WEnd) {
258 auto *FunB = dyn_cast<FunctionBlock>(WIt->get());
259 if (!FunB)
260 return false;
261
262 if (auto *MB = dyn_cast<MainFunBlock>(WIt->get())) {
263 if (MainB) {
264 errs() << "More than one main function found in wrapper module\n";
265 return false;
266 }
267
268 MainB = &OutM.addBlock<MainFunBlock>(*MB);
269 } else {
270 OutM.addBlock<FunctionBlock>(*FunB);
271 }
272 }
273
274 if (!MainB) {
275 errs() << "Wrapper module has no main function\n";
276 return false;
277 }
278
279 while (++KIt != KEnd) {
280 // TODO: Check if FunDecl is a known runtime function.
281 if (isa<FunDeclBlock>(KIt->get()))
282 continue;
283
284 auto *FunB = dyn_cast<FunctionBlock>(KIt->get());
285 if (!FunB)
286 return false;
287
288 // TODO: Detect also indirect recurion.
289 if (FunB->isDirectlyRecursive()) {
290 errs() << "Function: " << FunB->getFunctionName().str()
291 << " is recursive\n";
292 return false;
293 }
294
295 OutM.addBlock<FunctionBlock>(*FunB);
296 }
297
298 OutM.fixBlockOrder();
299 if (!FixMain(OutM, *MainB, KernelName))
300 return false;
301
302 if (!FixVectorShuffles(*MainB))
303 return false;
304
305 OutM.removeUnusedFunctions();
306
307 DEBUG(dbgs() << ">>>>>>>>>>>> Output module after prelink:\n\n");
308 DEBUG(OutM.dump());
309
310 if (!FuseTypesAndConstants(OutM)) {
311 errs() << "Type fusion failed\n";
312 return false;
313 }
314
315 DEBUG(dbgs() << ">>>>>>>>>>>> Output module after value fusion:\n\n");
316 DEBUG(OutM.dump());
317
318 if (!OutM.saveToFile(OutputFilename)) {
319 errs() << "Could not save to file: " << OutputFilename << "\n";
320 return false;
321 }
322
323 return true;
324}
325
326bool FixMain(LinkerModule &LM, MainFunBlock &MainB, StringRef KernelName) {
327 MainB.replaceAllIds("%RS_SPIRV_DUMMY_", KernelName);
328
329 while (MainB.hasFunctionCalls())
330 if (!InlineFunctionCalls(LM, MainB)) {
331 errs() << "Could not inline function calls in main\n";
332 return false;
333 }
334
335 for (auto &L : MainB.lines()) {
336 if (!L.contains("OpInBoundsPtrAccessChain"))
337 continue;
338
339 if (!TranslateInBoundsPtrAccessToAccess(L))
340 return false;
341 }
342
343 return true;
344}
345
346struct FunctionCallInfo {
347 StringRef RetValName;
348 StringRef RetTy;
349 StringRef FName;
350 SmallVector<StringRef, 4> ArgNames;
351};
352
353static FunctionCallInfo GetFunctionCallInfo(const SPIRVLine &L) {
354 assert(L.contains("OpFunctionCall"));
355
356 const Optional<StringRef> Ret = L.getLHSIdentifier();
357 assert(Ret);
358
359 SmallVector<StringRef, 6> Ids;
360 L.getRHSIdentifiers(Ids);
361 assert(Ids.size() >= 2 && "No return type and function name");
362
363 const StringRef RetTy = Ids[0];
364 const StringRef FName = Ids[1];
365 SmallVector<StringRef, 4> Args(Ids.begin() + 2, Ids.end());
366
367 return {*Ret, RetTy, FName, std::move(Args)};
368}
369
370bool InlineFunctionCalls(LinkerModule &LM, MainFunBlock &MB) {
371 DEBUG(dbgs() << "InlineFunctionCalls\n");
372 MainFunBlock NewMB;
373
374 auto MLines = MB.lines();
375 auto MIt = MLines.begin();
376 const auto MEnd = MLines.end();
377 using iter_ty = decltype(MIt);
378
379 auto SkipToFunctionCall = [&MEnd, &NewMB](iter_ty &It) {
380 while (++It != MEnd && !It->contains("OpFunctionCall"))
381 NewMB.addLine(*It);
382
383 return It != MEnd;
384 };
385
386 NewMB.addLine(*MIt);
387
388 std::vector<std::pair<std::string, std::string>> NameMapping;
389
390 while (SkipToFunctionCall(MIt)) {
391 assert(MIt->contains("OpFunctionCall"));
392 const auto FInfo = GetFunctionCallInfo(*MIt);
393 DEBUG(dbgs() << "Found function call:\t" << MIt->str() << '\n');
394
395 SmallVector<Block *, 1> Callee;
396 LM.getBlocksIf(Callee, [&FInfo](Block &B) {
397 auto *FB = dyn_cast<FunctionBlock>(&B);
398 if (!FB)
399 return false;
400
401 return FB->getFunctionName() == FInfo.FName;
402 });
403
404 if (Callee.size() != 1) {
405 errs() << "Callee not found\n";
406 return false;
407 }
408
409 auto *FB = cast<FunctionBlock>(Callee.front());
410
411 if (FB->getArity() != FInfo.ArgNames.size()) {
412 errs() << "Arity mismatch (caller: " << FInfo.ArgNames.size()
413 << ", callee: " << FB->getArity() << ")\n";
414 return false;
415 }
416
417 Optional<StringRef> RetValName = FB->getRetValName();
418 if (!RetValName && !FB->isReturnTypeVoid()) {
419 errs() << "Return value not found for a function with non-void "
420 "return type.\n";
421 return false;
422 }
423
424 SmallVector<StringRef, 4> Params;
425 FB->getArgNames(Params);
426
427 if (Params.size() != FInfo.ArgNames.size()) {
428 errs() << "Params size mismatch\n";
429 return false;
430 }
431
432 for (size_t i = 0, e = FInfo.ArgNames.size(); i < e; ++i) {
433 DEBUG(dbgs() << "New param mapping: " << Params[i] << " -> "
434 << FInfo.ArgNames[i] << "\n");
435 NameMapping.emplace_back(Params[i].str(), FInfo.ArgNames[i].str());
436 }
437
438 if (RetValName) {
439 DEBUG(dbgs() << "New ret-val mapping: " << FInfo.RetValName << " -> "
440 << *RetValName << "\n");
441 NameMapping.emplace_back(FInfo.RetValName.str(), RetValName->str());
442 }
443
444 const auto Body = FB->body();
445 for (const auto &L : Body)
446 NewMB.addLine(L);
447 }
448
449 while (MIt != MEnd) {
450 NewMB.addLine(*MIt);
451 ++MIt;
452 }
453
454 std::reverse(NameMapping.begin(), NameMapping.end());
455 for (const auto &P : NameMapping) {
456 DEBUG(dbgs() << "Replace " << P.first << " with " << P.second << "\n");
457 NewMB.replaceAllIds(P.first, P.second);
458 }
459
460 MB = NewMB;
461
462 return true;
463}
464
465bool FuseTypesAndConstants(LinkerModule &LM) {
466 StringMap<std::string> TypesAndConstDefs;
467 StringMap<std::string> NameReps;
468
469 for (auto *LPtr : LM.lines()) {
470 assert(LPtr);
471 auto &L = *LPtr;
472 if (!L.contains("="))
473 continue;
474
475 SmallVector<StringRef, 4> IdsRefs;
476 L.getRHSIdentifiers(IdsRefs);
477
478 SmallVector<std::string, 4> Ids;
479 Ids.reserve(IdsRefs.size());
480 for (const auto &I : IdsRefs)
481 Ids.push_back(I.str());
482
483 for (auto &I : Ids)
484 if (NameReps.count(I) != 0) {
485 const bool Res = L.replaceId(I, NameReps[I]);
486 (void)Res;
487 assert(Res);
488 }
489
490 if (L.contains("OpType") || L.contains("OpConstant")) {
491 const auto LHS = L.getLHSIdentifier();
492 const auto RHS = L.getRHS();
493 assert(LHS);
494 assert(RHS);
495
496 if (TypesAndConstDefs.count(*RHS) != 0) {
497 NameReps.insert(
498 std::make_pair(LHS->str(), TypesAndConstDefs[RHS->str()]));
499 DEBUG(dbgs() << "New mapping: [" << LHS->str() << ", "
500 << TypesAndConstDefs[RHS->str()] << "]\n");
501 L.markAsEmpty();
502 } else {
503 TypesAndConstDefs.insert(std::make_pair(RHS->str(), LHS->str()));
504 DEBUG(dbgs() << "New val:\t" << RHS->str() << " : " << LHS->str()
505 << '\n');
506 }
507 };
508 }
509
510 LM.removeNonCode();
511
512 return true;
513}
514
515bool TranslateInBoundsPtrAccessToAccess(SPIRVLine &L) {
516 assert(L.contains(" OpInBoundsPtrAccessChain "));
517
518 SmallVector<StringRef, 4> Ids;
519 L.getRHSIdentifiers(Ids);
520
521 if (Ids.size() < 4) {
522 errs() << "OpInBoundsPtrAccessChain has not enough parameters:\n\t"
523 << L.str();
524 return false;
525 }
526
527 std::istringstream SS(L.str());
528 std::string LHS, Eq, Op;
529 SS >> LHS >> Eq >> Op;
530
531 if (LHS.empty() || Eq != "=" || Op != "OpInBoundsPtrAccessChain") {
532 errs() << "Could not decompose OpInBoundsPtrAccessChain:\n\t" << L.str();
533 return false;
534 }
535
536 constexpr size_t ElementArgPosition = 2;
537
538 std::ostringstream NewLine;
539 NewLine << LHS << " " << Eq << " OpAccessChain ";
540 for (size_t i = 0, e = Ids.size(); i != e; ++i)
541 if (i != ElementArgPosition)
542 NewLine << Ids[i].str() << " ";
543
544 L.str() = NewLine.str();
545 L.trim();
546
547 return true;
548}
549
550// Replaces UndefValues in VectorShuffles with zeros, which is always
551// safe, as the result for components marked as Undef is unused.
552// Ex. 1) OpVectorShuffle %v4uchar %a %b 0 1 2 4294967295 -->
553// OpVectorShuffle %v4uchar %a %b 0 1 2 0.
554//
555// Ex. 2) OpVectorShuffle %v4uchar %a %b 0 4294967295 3 4294967295 -->
556// OpVectorShuffle %v4uchar %a %b 0 0 3 0.
557//
558// Fix needed for the current Vulkan driver, which crashed during
559// backend compilation when case is not handled.
560bool FixVectorShuffles(MainFunBlock &MB) {
561 const StringRef UndefStr = " 4294967295 ";
562
563 for (auto &L : MB.lines()) {
564 if (!L.contains("OpVectorShuffle"))
565 continue;
566
567 L.str().push_back(' ');
568 while (L.contains(UndefStr))
569 L.replaceStr(UndefStr, " 0 ");
570
571 L.trim();
572 }
573
574 return true;
575}
576
577// This function changes all Function StorageClass use into Uniform.
578// It's needed, because llvm-spirv converter emits wrong StorageClass
579// for globals.
580// The transfromation, however, breaks legitimate uses of Function StorageClass
581// inside functions.
582//
583// Ex. 1. %ptr_Function_uint = OpTypePointer Function %uint
584// --> %ptr_Uniform_uint = OpTypePointer Uniform %uint
585//
586// Ex. 2. %gep = OpAccessChain %ptr_Function_uchar %G %uint_zero
587// --> %gep = OpAccessChain %ptr_Uniform_uchar %G %uint_zero
588//
589// TODO: Consider a better way of fixing this.
590void FixModuleStorageClass(LinkerModule &M) {
591 for (auto *LPtr : M.lines()) {
592 assert(LPtr);
593 auto &L = *LPtr;
594
595 while (L.contains(" Function"))
596 L.replaceStr(" Function", " Uniform");
597
598 while (L.contains("_Function_"))
599 L.replaceStr("_Function_", "_Uniform_");
600 }
601}
602
603} // namespace rs2spirv