blob: 539c1e4215690c0ca5ed5d2929b682bd7904737b [file] [log] [blame]
Yang Ni6749f542016-11-07 20:20:49 -08001/*
2 * Copyright 2016, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "ReflectionPass.h"
18
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -080019#include "KernelSignature.h"
20
Yang Ni6749f542016-11-07 20:20:49 -080021#include "RSAllocationUtils.h"
22#include "bcinfo/MetadataExtractor.h"
23#include "llvm/ADT/StringSwitch.h"
Yang Ni6749f542016-11-07 20:20:49 -080024#include "llvm/IR/Instructions.h"
Yang Niec9360e2016-11-21 15:50:27 -080025#include "llvm/IR/Module.h"
Yang Ni6749f542016-11-07 20:20:49 -080026#include "llvm/IR/PassManager.h"
27#include "llvm/Pass.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SPIRV.h"
30
31#include <map>
32#include <sstream>
33#include <string>
34#include <type_traits>
35#include <unordered_map>
36#include <unordered_set>
37
38#define DEBUG_TYPE "rs2spirv-reflection"
39
40using namespace llvm;
41
Yang Ni6749f542016-11-07 20:20:49 -080042namespace {
43
Yang Ni6749f542016-11-07 20:20:49 -080044static const StringRef CoordsNames[] = {"x", "y", "z"};
45
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -080046}
Yang Ni6749f542016-11-07 20:20:49 -080047
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -080048namespace rs2spirv {
Yang Ni6749f542016-11-07 20:20:49 -080049
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -080050void KernelSignature::dump() const {
51 dbgs() << returnType << ' ' << name << '(' << argumentType;
52 const auto CoordsNum = size_t(coordsKind);
53 for (size_t i = 0; i != CoordsNum; ++i)
54 dbgs() << ", " << CoordsNames[i];
55
56 dbgs() << ")\n";
57}
58
59const std::string KernelSignature::wrapperPrefix = "%__rsov_entry_";
60
61namespace {
Yang Ni6749f542016-11-07 20:20:49 -080062
63std::string TypeToString(const Type *Ty) {
64 assert(Ty);
65 if (Ty->isVoidTy())
66 return "void";
67
68 if (auto *IT = dyn_cast<IntegerType>(Ty)) {
69 if (IT->getBitWidth() == 32)
70 return "int";
71 else if (IT->getBitWidth() == 8)
72 return "uchar";
Yang Ni6749f542016-11-07 20:20:49 -080073 }
I-Jui (Ray) Sung333659c2016-12-02 15:34:13 -080074
Yang Ni6749f542016-11-07 20:20:49 -080075 if (Ty->isFloatTy())
76 return "float";
77
78 if (auto *VT = dyn_cast<VectorType>(Ty)) {
79 auto *ET = VT->getElementType();
80 if (auto *IT = dyn_cast<IntegerType>(ET)) {
81 if (IT->getBitWidth() == 32)
82 return "int4";
83 else if (IT->getBitWidth() == 8)
84 return "uchar4";
Yang Ni6749f542016-11-07 20:20:49 -080085 }
86 if (ET->isFloatTy())
87 return "float4";
Yang Ni6749f542016-11-07 20:20:49 -080088 }
89
David Gross5802e622016-11-28 11:50:17 -080090 std::string badNameString;
91 raw_string_ostream badNameStream(badNameString);
92 badNameStream << '[';
93 Ty->print(badNameStream);
94 badNameStream << ']';
95 return badNameStream.str();
Yang Ni6749f542016-11-07 20:20:49 -080096}
97
// RenderScript element types understood by the reflection pass.
// rs_bad marks a type name that could not be recognized.
enum class RSType : int {
  rs_bad = -1,
  rs_void = 0,
  rs_uchar = 1,
  rs_int = 2,
  rs_float = 3,
  rs_uchar4 = 4,
  rs_int4 = 5,
  rs_float4 = 6,
};
108
109RSType StrToRsTy(StringRef S) {
110 RSType Ty = StringSwitch<RSType>(S)
111 .Case("void", RSType::rs_void)
112 .Case("uchar", RSType::rs_uchar)
113 .Case("int", RSType::rs_int)
114 .Case("float", RSType::rs_float)
115 .Case("uchar4", RSType::rs_uchar4)
116 .Case("int4", RSType::rs_int4)
David Gross5802e622016-11-28 11:50:17 -0800117 .Case("float4", RSType::rs_float4)
118 .Default(RSType::rs_bad);
Yang Ni6749f542016-11-07 20:20:49 -0800119 return Ty;
120}
121
122struct TypeMapping {
123 RSType RSTy;
124 bool isVectorTy;
125 // Scalar types are accessed (loaded/stored) using wider (vector) types.
126 // 'vecLen' corresponds to width of such vector type.
127 // As for vector types, 'vectorWidth' is just width of such type.
128 size_t vectorWidth;
129 std::string SPIRVTy;
130 std::string SPIRVScalarTy;
131 std::string SPIRVImageFormat;
132 // TODO: Handle different image formats for read and write.
133 std::string SPIRVImageReadType;
134
135 TypeMapping(RSType RSTy, bool IsVectorTy, size_t VectorLen,
136 StringRef SPIRVScalarTy, StringRef SPIRVImageFormat)
137 : RSTy(RSTy), isVectorTy(IsVectorTy), vectorWidth(VectorLen),
138 SPIRVScalarTy(SPIRVScalarTy), SPIRVImageFormat(SPIRVImageFormat) {
139 assert(vectorWidth != 0);
140
141 if (isVectorTy) {
142 std::ostringstream OSS;
143 OSS << "%v" << vectorWidth << SPIRVScalarTy.drop_front().str();
144 SPIRVTy = OSS.str();
145 SPIRVImageReadType = SPIRVTy;
146 return;
147 }
148
149 SPIRVTy = SPIRVScalarTy;
150 std::ostringstream OSS;
151 OSS << "%v" << vectorWidth << SPIRVScalarTy.drop_front().str();
152 SPIRVImageReadType = OSS.str();
153 }
154};
155
156class ReflectionPass : public ModulePass {
157 std::ostream &OS;
158 bcinfo::MetadataExtractor &ME;
159
160 static const std::map<RSType, TypeMapping> TypeMappings;
161
162 static const TypeMapping *getMapping(RSType RsTy) {
163 auto it = TypeMappings.find(RsTy);
164 if (it != TypeMappings.end())
165 return &it->second;
166
167 return nullptr;
168 };
169
170 static const TypeMapping *getMapping(StringRef Str) {
171 auto Ty = StrToRsTy(Str);
172 return getMapping(Ty);
173 }
174
175 static const TypeMapping *getMappingOrPrintError(StringRef Str) {
176 const auto *TM = ReflectionPass::getMapping(Str);
177 if (!TM)
178 errs() << "LLVM to SPIRV type mapping for type:\t" << Str
179 << " not found\n";
180
181 return TM;
182 }
183
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -0800184 bool emitHeader(const Module &M, const KernelSignature &Kernel);
Yang Ni6749f542016-11-07 20:20:49 -0800185 bool emitDecorations(const Module &M,
186 const SmallVectorImpl<RSAllocationInfo> &RSAllocs);
187 void emitCommonTypes();
188 bool extractKernelSignatures(const Module &M,
189 SmallVectorImpl<KernelSignature> &Out);
190 bool emitKernelTypes(const KernelSignature &Kernel);
191 bool emitInputImage(const KernelSignature &Kernel);
192 void emitGLGlobalInput();
193 bool emitOutputImage(const KernelSignature &Kernel);
194 bool emitRSAllocImages(const SmallVectorImpl<RSAllocationInfo> &RSAllocs);
195 bool emitConstants(const KernelSignature &Kernel);
196 void emitRTFunctions();
197 bool emitRSAllocFunctions(
198 Module &M, const SmallVectorImpl<RSAllocationInfo> &RSAllocs,
199 const SmallVectorImpl<RSAllocationCallInfo> &RSAllocAccesses);
200 bool emitMain(const KernelSignature &Kernel,
201 const SmallVectorImpl<RSAllocationInfo> &RSAllocs);
202
203public:
204 static char ID;
205 explicit ReflectionPass(std::ostream &OS, bcinfo::MetadataExtractor &ME)
206 : ModulePass(ID), OS(OS), ME(ME) {}
207
208 const char *getPassName() const override { return "ReflectionPass"; }
209
210 bool runOnModule(Module &M) override {
211 DEBUG(dbgs() << "ReflectionPass\n");
212
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -0800213 SmallVector<KernelSignature, 4> Kernels;
214 if (!extractKernelSignatures(M, Kernels)) {
215 errs() << "Extraction of kernels failed\n";
216 return false;
217 }
218
219 if (Kernels.size() != 1) {
220 errs() << "Non single-kernel modules are not supported\n";
221 return false;
222 }
223 const auto &Kernel = Kernels.front();
224
225 if (!emitHeader(M, Kernel)) {
Yang Ni6749f542016-11-07 20:20:49 -0800226 errs() << "Emiting header failed\n";
227 return false;
228 }
229
230 SmallVector<RSAllocationInfo, 2> RSAllocs;
231 if (!getRSAllocationInfo(M, RSAllocs)) {
232 errs() << "Extracting rs_allocation info failed\n";
233 return false;
234 }
235
236 SmallVector<RSAllocationCallInfo, 4> RSAllocAccesses;
237 if (!getRSAllocAccesses(RSAllocs, RSAllocAccesses)) {
238 errs() << "Extracting rsGEA/rsSEA info failed\n";
239 return false;
240 }
241
242 if (!emitDecorations(M, RSAllocs)) {
243 errs() << "Emiting decorations failed\n";
244 return false;
245 }
246
247 emitCommonTypes();
248
Yang Ni6749f542016-11-07 20:20:49 -0800249
250 if (!emitKernelTypes(Kernel)) {
251 errs() << "Emitting kernel types failed\n";
252 return false;
253 }
254
255 if (!emitInputImage(Kernel)) {
256 errs() << "Emitting input image failed\n";
257 return false;
258 }
259
260 emitGLGlobalInput();
261
262 if (!emitOutputImage(Kernel)) {
263 errs() << "Emitting output image failed\n";
264 return false;
265 }
266
267 if (!emitRSAllocImages(RSAllocs)) {
268 errs() << "Emitting rs_allocation images failed\n";
269 return false;
270 }
271
272 if (!emitConstants(Kernel)) {
273 errs() << "Emitting constants failed\n";
274 return false;
275 }
276
277 emitRTFunctions();
278
279 if (!emitRSAllocFunctions(M, RSAllocs, RSAllocAccesses)) {
280 errs() << "Emitting rs_allocation runtime functions failed\n";
281 return false;
282 }
283
284 if (!emitMain(Kernel, RSAllocs)) {
285 errs() << "Emitting main failed\n";
286 return false;
287 }
288
289 // Return false, as the module is not modified.
290 return false;
291 }
292};
293
294// TODO: Add other types: bool, double, char, uchar, long, ulong
295// and their vector counterparts.
296// TODO: Support vector types of width different than 4. eg. float3.
297const std::map<RSType, TypeMapping> ReflectionPass::TypeMappings = {
298 {RSType::rs_void, {RSType::rs_void, false, 1, "%void", ""}},
299 {RSType::rs_uchar, {RSType::rs_uchar, false, 4, "%uchar", "R8ui"}},
300 {RSType::rs_int, {RSType::rs_void, false, 4, "%int", "R32i"}},
301 {RSType::rs_float, {RSType::rs_float, false, 4, "%float", "R32f"}},
302 {RSType::rs_uchar4, {RSType::rs_uchar4, true, 4, "%uchar", "Rgba8ui"}},
303 {RSType::rs_int4, {RSType::rs_int4, true, 4, "%int", "Rgba32i"}},
304 {RSType::rs_float4, {RSType::rs_float4, true, 4, "%float", "Rgba32f"}}};
305};
306
307char ReflectionPass::ID = 0;
308
309ModulePass *createReflectionPass(std::ostream &OS,
310 bcinfo::MetadataExtractor &ME) {
311 return new ReflectionPass(OS, ME);
312}
313
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -0800314bool ReflectionPass::emitHeader(const Module &M,
315 const KernelSignature &Kernel) {
Yang Ni6749f542016-11-07 20:20:49 -0800316 DEBUG(dbgs() << "emitHeader\n");
317
318 OS << "; SPIR-V\n"
319 "; Version: 1.0\n"
320 "; Generator: rs2spirv;\n"
321 "; Bound: 1024\n"
322 "; Schema: 0\n"
323 " OpCapability Shader\n"
324 " OpCapability StorageImageWriteWithoutFormat\n"
325 " OpCapability Addresses\n"
326 " %glsl_ext_ins = OpExtInstImport \"GLSL.std.450\"\n"
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -0800327 " OpMemoryModel Physical32 GLSL450\n";
328 OS << " OpEntryPoint GLCompute " << Kernel.getWrapperName() << " ";
329 OS << "\"main\" %global_invocation_id\n"
330 " OpExecutionMode ";
331 OS << Kernel.getWrapperName() << " LocalSize 1 1 1\n"
Yang Ni6749f542016-11-07 20:20:49 -0800332 " OpSource GLSL 450\n"
333 " OpSourceExtension \"GL_ARB_separate_shader_objects\"\n"
334 " OpSourceExtension \"GL_ARB_shading_language_420pack\"\n"
335 " OpSourceExtension \"GL_GOOGLE_cpp_style_line_directive\"\n"
336 " OpSourceExtension \"GL_GOOGLE_include_directive\"\n";
337
338 const size_t RSKernelNum = ME.getExportForEachSignatureCount();
339
340 if (RSKernelNum == 0)
341 return false;
342
343 const char **RSKernelNames = ME.getExportForEachNameList();
344
345 OS << " %RS_KERNELS = OpString \"";
346
347 for (size_t i = 0; i < RSKernelNum; ++i)
348 if (RSKernelNames[i] != StringRef("root"))
349 OS << '%' << RSKernelNames[i] << " ";
350
351 OS << "\"\n";
352
353 return true;
354}
355
356bool ReflectionPass::emitDecorations(
357 const Module &M, const SmallVectorImpl<RSAllocationInfo> &RSAllocs) {
358 DEBUG(dbgs() << "emitDecorations\n");
359
360 OS << "\n"
361 " OpDecorate %global_invocation_id BuiltIn GlobalInvocationId\n"
362 " OpDecorate %input_image DescriptorSet 0\n"
363 " OpDecorate %input_image Binding 0\n"
364 " OpDecorate %input_image NonWritable\n"
365 " OpDecorate %output_image DescriptorSet 0\n"
366 " OpDecorate %output_image Binding 1\n"
367 " OpDecorate %output_image NonReadable\n";
368
369 const auto GlobalsB = M.globals().begin();
370 const auto GlobalsE = M.globals().end();
371 const auto Found =
372 std::find_if(GlobalsB, GlobalsE, [](const GlobalVariable &GV) {
373 return GV.getName() == "__GPUBlock";
374 });
375
376 if (Found == GlobalsE)
377 return true; // GPUBlock not found - not an error by itself.
378
379 const GlobalVariable &G = *Found;
380
381 DEBUG(dbgs() << "Found GPUBlock:\t");
382 DEBUG(G.dump());
383
384 bool IsCorrectTy = false;
385 if (const auto *PtrTy = dyn_cast<PointerType>(G.getType())) {
386 if (auto *StructTy = dyn_cast<StructType>(PtrTy->getElementType())) {
387 IsCorrectTy = true;
388
389 const auto &DLayout = M.getDataLayout();
390 const auto *SLayout = DLayout.getStructLayout(StructTy);
391 assert(SLayout);
392
393 for (size_t i = 0, e = StructTy->getNumElements(); i != e; ++i)
394 OS << " OpMemberDecorate %rs_linker_struct___GPUBuffer " << i
395 << " Offset " << SLayout->getElementOffset(i) << '\n';
396 }
397 }
398
399 if (!IsCorrectTy) {
400 errs() << "GPUBlock is not of expected type:\t";
401 G.print(errs());
402 G.getType()->print(errs());
403 return false;
404 }
405
406 OS << " OpDecorate %rs_linker_struct___GPUBuffer BufferBlock\n";
407 OS << " OpDecorate %rs_linker___GPUBlock DescriptorSet 0\n";
408 OS << " OpDecorate %rs_linker___GPUBlock Binding 2\n";
409
410 size_t BindingNum = 3;
411
412 for (const auto &A : RSAllocs) {
413 OS << " OpDecorate " << A.VarName << "_var DescriptorSet 0\n";
414 OS << " OpDecorate " << A.VarName << "_var Binding " << BindingNum
415 << '\n';
416 ++BindingNum;
417 }
418
419 return true;
420}
421
422void ReflectionPass::emitCommonTypes() {
423 DEBUG(dbgs() << "emitCommonTypes\n");
424
425 OS << "\n\n"
426 "%void = OpTypeVoid\n"
427 "%fun_void = OpTypeFunction %void\n"
428 "%float = OpTypeFloat 32\n"
429 "%v2float = OpTypeVector %float 2\n"
430 "%v3float = OpTypeVector %float 3\n"
431 "%v4float = OpTypeVector %float 4\n"
432 "%int = OpTypeInt 32 1\n"
433 "%v2int = OpTypeVector %int 2\n"
434 "%v4int = OpTypeVector %int 4\n"
435 "%uchar = OpTypeInt 8 0\n"
436 "%v2uchar = OpTypeVector %uchar 2\n"
437 "%v3uchar = OpTypeVector %uchar 3\n"
438 "%v4uchar = OpTypeVector %uchar 4\n"
439 "%uint = OpTypeInt 32 0\n"
440 "%v2uint = OpTypeVector %uint 2\n"
441 "%v3uint = OpTypeVector %uint 3\n"
442 "%v4uint = OpTypeVector %uint 4\n"
443 "%fun_f3_uc3 = OpTypeFunction %v3float %v3uchar\n"
444 "%fun_f3_u3 = OpTypeFunction %v3float %v3uint\n"
445 "%fun_f4_uc4 = OpTypeFunction %v4float %v4uchar\n"
446 "%fun_uc3_f3 = OpTypeFunction %v3uchar %v3float\n"
447 "%fun_u3_f3 = OpTypeFunction %v3uint %v3float\n"
448 "%fun_uc4_f4 = OpTypeFunction %v4uchar %v4float\n"
449 "%fun_uc4_u4 = OpTypeFunction %v4uchar %v4uint\n"
450 "%fun_u4_uc4 = OpTypeFunction %v4uint %v4uchar\n"
451 "%fun_f_f = OpTypeFunction %float %float\n"
452 "%fun_f_ff = OpTypeFunction %float %float %float\n"
453 "%fun_f_fff = OpTypeFunction %float %float %float %float\n"
454 "%fun_f_f2f2 = OpTypeFunction %float %v2float %v2float\n"
455 "%fun_f_f3f3 = OpTypeFunction %float %v3float %v3float\n"
456 "%fun_f3_f3ff = OpTypeFunction %v3float %v3float %float %float\n"
457 "%fun_i_iii = OpTypeFunction %int %int %int %int\n"
458 "%fun_uc_uu = OpTypeFunction %uchar %uint %uint\n"
459 "%fun_u_uu = OpTypeFunction %uint %uint %uint\n"
460 "%fun_u_uuu = OpTypeFunction %uint %uint %uint %uint\n"
461 "%fun_u3_u3uu = OpTypeFunction %v3uint %v3uint %uint %uint\n";
462}
463
464static Coords GetCoordsKind(const Function &F) {
465 if (F.arg_size() <= 1)
466 return Coords::None;
467
468 DEBUG(F.getFunctionType()->dump());
469
470 SmallVector<const Argument *, 4> Args;
471 Args.reserve(F.arg_size());
472 for (const auto &Arg : F.args())
473 Args.push_back(&Arg);
474
475 auto IsInt32 = [](const Argument *Arg) {
476 assert(Arg);
477 auto *Ty = Arg->getType();
478 auto IntTy = dyn_cast<IntegerType>(Ty);
479 if (!IntTy)
480 return false;
481
482 return IntTy->getBitWidth() == 32;
483 };
484
485 size_t LastInt32Num = 0;
486 size_t XPos = -1; // npos - not found.
487 auto RIt = Args.rbegin();
488 const auto REnd = Args.rend();
489 while (RIt != REnd && IsInt32(*RIt)) {
490 if ((*RIt)->getName() == "x")
491 XPos = Args.size() - 1 - LastInt32Num;
492
493 ++LastInt32Num;
494 ++RIt;
495 }
496
497 DEBUG(dbgs() << "Original number of last i32's: " << LastInt32Num << '\n');
498 DEBUG(dbgs() << "X found at position: " << XPos << '\n');
499 if (XPos == size_t(-1) || Args.size() - XPos > size_t(Coords::Last))
500 return Coords::None;
501
502 // Check remaining coordinate names.
503 for (size_t i = 1, c = XPos + 1, e = Args.size(); c != e; ++i, ++c)
504 if (Args[c]->getName() != CoordsNames[i])
505 return Coords::None;
506
507 DEBUG(dbgs() << "Coords: not none!\n");
508
509 return Coords(Args.size() - XPos);
510}
511
512bool ReflectionPass::extractKernelSignatures(
513 const Module &M, SmallVectorImpl<KernelSignature> &Out) {
514 DEBUG(dbgs() << "extractKernelSignatures\n");
515
516 for (const auto &F : M.functions()) {
517 if (F.isDeclaration())
518 continue;
519
520 const auto CoordsKind = GetCoordsKind(F);
521 const auto CoordsNum = unsigned(CoordsKind);
522 if (F.arg_size() != CoordsNum + 1) {
523 // TODO: Handle different arrities (and lack of return value).
524 errs() << "Unsupported kernel signature.\n";
525 return false;
526 }
527
528 const auto *FT = F.getFunctionType();
529 const auto *RT = FT->getReturnType();
530 const auto *ArgT = FT->params()[0];
531 Out.push_back(
532 {TypeToString(RT), F.getName(), TypeToString(ArgT), GetCoordsKind(F)});
533 DEBUG(Out.back().dump());
534 }
535
536 if (Out.size() != 1) {
537 // TODO: recognize non-kernel functions and don't bail out here.
538 errs() << "Unsupported number of kernels\n";
539 return false;
540 }
541
542 return true;
543}
544
545bool ReflectionPass::emitKernelTypes(const KernelSignature &Kernel) {
546 DEBUG(dbgs() << "emitKernelTypes\n");
547
548 const auto *RTMapping = getMappingOrPrintError(Kernel.returnType);
549 const auto *ArgTMapping = getMappingOrPrintError(Kernel.argumentType);
550
551 if (!RTMapping || !ArgTMapping)
552 return false;
553
Yang Niec9360e2016-11-21 15:50:27 -0800554 OS << '\n'
555 << "%kernel_function_ty = OpTypeFunction " << RTMapping->SPIRVTy << ' '
556 << ArgTMapping->SPIRVTy;
Yang Ni6749f542016-11-07 20:20:49 -0800557
558 const auto CoordsNum = unsigned(Kernel.coordsKind);
559 for (size_t i = 0; i != CoordsNum; ++i)
560 OS << " %uint";
561
562 OS << '\n';
563
564 OS << "%ptr_function_ty = OpTypePointer Function " << RTMapping->SPIRVTy
565 << "\n";
566 OS << "%ptr_function_access_ty = OpTypePointer Function "
567 << RTMapping->SPIRVImageReadType << "\n\n";
568
569 return true;
570}
571
572bool ReflectionPass::emitInputImage(const KernelSignature &Kernel) {
573 DEBUG(dbgs() << "emitInputImage\n");
574
575 const auto *ArgTMapping = getMappingOrPrintError(Kernel.argumentType);
576 if (!ArgTMapping)
577 return false;
578
579 OS << "%input_image_ty = OpTypeImage " << ArgTMapping->SPIRVScalarTy
580 << " 2D 0 0 0 2 " << ArgTMapping->SPIRVImageFormat << '\n';
581
582 OS << "%input_image_ptr_ty = OpTypePointer UniformConstant "
583 << "%input_image_ty\n";
584
585 OS << "%input_image = OpVariable %input_image_ptr_ty UniformConstant\n";
586
587 return true;
588}
589
590void ReflectionPass::emitGLGlobalInput() {
591 DEBUG(dbgs() << "emitGLGlobalInput\n");
592
Yang Niec9360e2016-11-21 15:50:27 -0800593 OS << '\n'
594 << "%global_input_ptr_ty = OpTypePointer Input %v3uint\n"
Yang Ni6749f542016-11-07 20:20:49 -0800595 << "%global_invocation_id = OpVariable %global_input_ptr_ty Input\n";
596}
597
598bool ReflectionPass::emitOutputImage(const KernelSignature &Kernel) {
599 DEBUG(dbgs() << "emitOutputImage\n");
600
601 const auto *RTMapping = getMappingOrPrintError(Kernel.returnType);
602 if (!RTMapping)
603 return false;
604
605 OS << '\n';
606 OS << "%output_image_ty = OpTypeImage " << RTMapping->SPIRVScalarTy
607 << " 2D 0 0 0 2 " << RTMapping->SPIRVImageFormat << '\n'
608 << "%output_image_ptr_ty = OpTypePointer UniformConstant "
609 << "%output_image_ty\n";
610
611 OS << "%output_image = OpVariable %output_image_ptr_ty Image\n";
612
613 return true;
614}
615
616bool ReflectionPass::emitRSAllocImages(
617 const SmallVectorImpl<RSAllocationInfo> &RSAllocs) {
618 DEBUG(dbgs() << "emitRSAllocImages\n");
619
620 for (const auto &A : RSAllocs) {
621 if (!A.RSElementType) {
622 errs() << "Type of variable " << A.VarName << " not infered.\n";
623 return false;
624 }
625
626 const auto *AMapping = getMappingOrPrintError(*A.RSElementType);
627 if (!AMapping)
628 return false;
629
Yang Niec9360e2016-11-21 15:50:27 -0800630 OS << '\n'
631 << A.VarName << "_image_ty"
Yang Ni6749f542016-11-07 20:20:49 -0800632 << " = OpTypeImage " << AMapping->SPIRVScalarTy << " 2D 0 0 0 2 "
Yang Niec9360e2016-11-21 15:50:27 -0800633 << AMapping->SPIRVImageFormat << '\n'
634 << A.VarName << "_image_ptr_ty"
Yang Ni6749f542016-11-07 20:20:49 -0800635 << " = OpTypePointer UniformConstant " << A.VarName << "_image_ty\n";
636
637 OS << A.VarName << "_var = OpVariable " << A.VarName
638 << "_image_ptr_ty Image\n";
639 }
640
641 return true;
642}
643
644bool ReflectionPass::emitConstants(const KernelSignature &Kernel) {
645 DEBUG(dbgs() << "emitConstants\n");
646
647 OS << "\n"
648 "%uint_zero = OpConstant %uint 0\n"
649 "%float_zero = OpConstant %float 0\n";
650
651 return true;
652}
653
// Builds a one-parameter SPIR-V function "%rs_linker_<Name>" whose body is a
// single conversion instruction.
//   Name         - suffix used for the function and all of its local IDs
//   FType        - OpTypeFunction ID describing the signature
//   From / To    - parameter type ID and result type ID
//   ConversionOp - conversion opcode, e.g. "OpConvertUToF"
static std::string GenerateConversionFun(const char *Name, const char *FType,
                                         const char *From, const char *To,
                                         const char *ConversionOp) {
  const std::string Suffix(Name);
  const std::string Param = "%param" + Suffix;
  const std::string Result = "%res" + Suffix;

  std::string Out;
  Out += "\n";
  Out += "%rs_linker_" + Suffix + " = OpFunction " + To + " Pure " + FType +
         "\n";
  Out += Param + " = OpFunctionParameter " + From + "\n";
  Out += "%label" + Suffix + " = OpLabel\n";
  Out += Result + " = " + ConversionOp + " " + To + " " + Param + "\n";
  Out += " OpReturnValue " + Result + "\n";
  Out += " OpFunctionEnd\n";
  return Out;
}
671
672static std::string GenerateEISFun(const char *Name, const char *FType,
673 const char *RType,
674 const SmallVector<const char *, 4> &ArgTypes,
675 const char *InstName) {
676 std::ostringstream OS;
677
Yang Niec9360e2016-11-21 15:50:27 -0800678 OS << '\n'
679 << "%rs_linker_" << Name << " = OpFunction " << RType << " Pure " << FType
680 << '\n';
Yang Ni6749f542016-11-07 20:20:49 -0800681
682 for (size_t i = 0, e = ArgTypes.size(); i < e; ++i)
683 OS << "%param" << Name << i << " = OpFunctionParameter " << ArgTypes[i]
684 << "\n";
685
686 OS << "%label" << Name << " = OpLabel\n"
687 << "%res" << Name << " = "
688 << "OpExtInst " << RType << " %glsl_ext_ins " << InstName;
689
690 for (size_t i = 0, e = ArgTypes.size(); i < e; ++i)
691 OS << " %param" << Name << i;
692
Yang Niec9360e2016-11-21 15:50:27 -0800693 OS << '\n'
694 << " OpReturnValue %res" << Name << "\n"
Yang Ni6749f542016-11-07 20:20:49 -0800695 << " OpFunctionEnd\n";
696
697 return OS.str();
698}
699
700// This SPIRV function generator relies heavily on future inlining.
701// Currently, the inliner doesn't perform any type checking - it blindly
702// maps function parameters to supplied parameters at call site.
703// It's non-trivial to generate correct SPIRV function signature based only
704// on the LLVM one, and the current design doesn't allow lazy type generation.
705//
706// TODO: Consider less horrible generator design that doesn't rely on lack of
707// type checking in the inliner.
708static std::string GenerateRSGEA(const char *Name, const char *RType,
709 StringRef LoadName, Coords CoordsKind) {
710 assert(CoordsKind != Coords::None);
711 std::ostringstream OS;
712
713 OS << "\n"
714 << "%rs_linker_" << Name << " = OpFunction " << RType
715 << " None %rs_inliner_placeholder_ty\n";
716
717 // Since the inliner doesn't perform type checking, function and parameter
718 // types can be anything. %rs_inliner_placeholder_ty is just a placeholder
719 // name that will disappear after inlining.
720
721 OS << "%rs_drop_param_" << Name << " = OpFunctionParameter "
722 << "%rs_inliner_placeholder_ty\n";
723
724 for (size_t i = 0, e = size_t(CoordsKind); i != e; ++i)
725 OS << "%param" << Name << '_' << CoordsNames[i].str()
726 << " = OpFunctionParameter %uint\n";
727
728 OS << "%label" << Name << " = OpLabel\n";
729 OS << "%arg" << Name << " = OpCompositeConstruct %v" << size_t(CoordsKind)
730 << "uint ";
731
732 for (size_t i = 0, e = size_t(CoordsKind); i != e; ++i)
733 OS << "%param" << Name << '_' << CoordsNames[i].str() << ' ';
734
735 OS << '\n';
736
737 OS << "%read" << Name << " = OpImageRead " << RType << ' ' << LoadName.str()
738 << " %arg" << Name << '\n';
739 OS << " OpReturnValue %read" << Name << '\n';
740 OS << " OpFunctionEnd\n";
741
742 return OS.str();
743}
744
745// The same remarks as to GenerateRSGEA apply to SEA function generator.
746static std::string GenerateRSSEA(const char *Name, StringRef LoadName,
747 Coords CoordsKind) {
748 assert(CoordsKind != Coords::None);
749 std::ostringstream OS;
750
751 // %rs_inliner_placeholder_ty will disappear after inlining.
752 OS << "\n"
753 << "%rs_linker_" << Name << " = OpFunction %void None "
754 << "%rs_inliner_placeholder_ty\n";
755
756 OS << "%rs_placeholder_param_" << Name << " = OpFunctionParameter "
757 << "%rs_inliner_placeholder_ty\n";
758 OS << "%param" << Name << "_new_val = OpFunctionParameter "
759 << "%rs_inliner_placeholder_ty\n";
760
761 for (size_t i = 0, e = size_t(CoordsKind); i != e; ++i)
762 OS << "%param" << Name << '_' << CoordsNames[i].str()
763 << " = OpFunctionParameter %uint\n";
764
765 OS << "%label" << Name << " = OpLabel\n";
766 OS << "%arg" << Name << " = OpCompositeConstruct %v" << size_t(CoordsKind)
767 << "uint ";
768
769 for (size_t i = 0, e = size_t(CoordsKind); i != e; ++i)
770 OS << "%param" << Name << '_' << CoordsNames[i].str() << ' ';
771
772 OS << '\n';
773
774 OS << "OpImageWrite " << LoadName.str() << " %arg" << Name << " %param"
775 << Name << "_new_val\n";
776 OS << " OpReturn\n";
777 OS << " OpFunctionEnd\n";
778
779 return OS.str();
780}
781
782void ReflectionPass::emitRTFunctions() {
783 DEBUG(dbgs() << "emitRTFunctions\n");
784
785 // TODO: Emit other runtime functions.
786 // TODO: Generate libary file instead of generating functions below
787 // every compilation.
788
789 // Use uints as Khronos' SPIRV converter turns LLVM's i32s into uints.
790
791 OS << GenerateConversionFun("_Z14convert_float4Dv4_h", "%fun_f4_uc4",
792 "%v4uchar", "%v4float", "OpConvertUToF");
793
794 OS << GenerateConversionFun("_Z14convert_uchar4Dv4_f", "%fun_uc4_f4",
795 "%v4float", "%v4uchar", "OpConvertFToU");
796
797 OS << GenerateConversionFun("_Z14convert_float3Dv3_h", "%fun_f3_uc3",
798 "%v3uchar", "%v3float", "OpConvertUToF");
799
800 OS << GenerateConversionFun("_Z14convert_uchar3Dv3_f", "%fun_uc3_f3",
801 "%v3float", "%v3uchar", "OpConvertFToU");
802
803 OS << GenerateConversionFun("_Z12convert_int3Dv3_f", "%fun_u3_f3", "%v3float",
804 "%v3uint", "OpConvertFToU");
805
806 OS << GenerateConversionFun("_Z14convert_uchar3Dv3_i", "%fun_uc3_u3",
807 "%v3uint", "%v3uchar", "OpUConvert");
808
809 OS << GenerateConversionFun("_Z14convert_uchar4Dv4_j", "%fun_uc4_u4",
810 "%v4uint", "%v4uchar", "OpUConvert");
811
812 OS << GenerateConversionFun("_Z13convert_uint4Dv4_h", "%fun_u4_uc4",
813 "%v4uchar", "%v4uint", "OpUConvert");
814
815 OS << GenerateEISFun("_Z3sinf", "%fun_f_f", "%float", {"%float"}, "Sin");
816 OS << GenerateEISFun("_Z4sqrtf", "%fun_f_f", "%float", {"%float"}, "Sqrt");
817 OS << GenerateEISFun("_Z10native_expf", "%fun_f_f", "%float", {"%float"},
818 "Exp");
819 OS << GenerateEISFun("_Z3maxii", "%fun_u_uu", "%uint", {"%uint", "%uint"},
820 "SMax");
821 OS << GenerateEISFun("_Z3minii", "%fun_u_uu", "%uint", {"%uint", "%uint"},
822 "SMin");
823 OS << GenerateEISFun("_Z3maxff", "%fun_f_ff", "%float", {"%float", "%float"},
824 "FMax");
825 OS << GenerateEISFun("_Z3minff", "%fun_f_ff", "%float", {"%float", "%float"},
826 "FMin");
827 OS << GenerateEISFun("_Z5clampfff", "%fun_f_fff", "%float",
828 {"%float", "%float", "%float"}, "FClamp");
829 OS << GenerateEISFun("_Z5clampiii", "%fun_u_uuu", "%uint",
830 {"%uint", "%uint", "%uint"}, "SClamp");
831
832 OS << R"(
833%rs_linker__Z3dotDv2_fS_ = OpFunction %float Pure %fun_f_f2f2
834%param_Z3dotDv2_fS_0 = OpFunctionParameter %v2float
835%param_Z3dotDv2_fS_1 = OpFunctionParameter %v2float
836%label_Z3dotDv2_fS = OpLabel
837%res_Z3dotDv2_fS = OpDot %float %param_Z3dotDv2_fS_0 %param_Z3dotDv2_fS_1
838 OpReturnValue %res_Z3dotDv2_fS
839 OpFunctionEnd
840)";
841
842 OS << R"(
843%rs_linker__Z3dotDv3_fS_ = OpFunction %float Pure %fun_f_f3f3
844%param_Z3dotDv3_fS_0 = OpFunctionParameter %v3float
845%param_Z3dotDv3_fS_1 = OpFunctionParameter %v3float
846%label_Z3dotDv3_fS = OpLabel
847%res_Z3dotDv3_fS = OpDot %float %param_Z3dotDv3_fS_0 %param_Z3dotDv3_fS_1
848 OpReturnValue %res_Z3dotDv3_fS
849 OpFunctionEnd
850)";
851
852 OS << R"(
853%rs_linker_rsUnpackColor8888 = OpFunction %v4float Pure %fun_f4_uc4
854%paramrsUnpackColor88880 = OpFunctionParameter %v4uchar
855%labelrsUnpackColor8888 = OpLabel
856%castedUnpackColor8888 = OpBitcast %uint %paramrsUnpackColor88880
857%resrsUnpackColor8888 = OpExtInst %v4float %glsl_ext_ins UnpackUnorm4x8 %castedUnpackColor8888
858 OpReturnValue %resrsUnpackColor8888
859 OpFunctionEnd
860)";
861
862 OS << R"(
863%rs_linker__Z17rsPackColorTo8888Dv4_f = OpFunction %v4uchar Pure %fun_uc4_f4
864%param_Z17rsPackColorTo8888Dv4_f0 = OpFunctionParameter %v4float
865%label_Z17rsPackColorTo8888Dv4_f = OpLabel
866%res_Z17rsPackColorTo8888Dv4_f = OpExtInst %uint %glsl_ext_ins PackUnorm4x8 %param_Z17rsPackColorTo8888Dv4_f0
867%casted_Z17rsPackColorTo8888Dv4_f = OpBitcast %v4uchar %res_Z17rsPackColorTo8888Dv4_f
868 OpReturnValue %casted_Z17rsPackColorTo8888Dv4_f
869 OpFunctionEnd
870)";
871
872 OS << R"(
873%rs_linker__Z5clampDv3_fff = OpFunction %v3float Pure %fun_f3_f3ff
874%param_Z5clampDv3_fff0 = OpFunctionParameter %v3float
875%param_Z5clampDv3_fff1 = OpFunctionParameter %float
876%param_Z5clampDv3_fff2 = OpFunctionParameter %float
877%label_Z5clampDv3_fff = OpLabel
878%arg1_Z5clampDv3_fff = OpCompositeConstruct %v3float %param_Z5clampDv3_fff1 %param_Z5clampDv3_fff1 %param_Z5clampDv3_fff1
879%arg2_Z5clampDv3_fff = OpCompositeConstruct %v3float %param_Z5clampDv3_fff2 %param_Z5clampDv3_fff2 %param_Z5clampDv3_fff2
880%res_Z5clampDv3_fff = OpExtInst %v3float %glsl_ext_ins FClamp %param_Z5clampDv3_fff0 %arg1_Z5clampDv3_fff %arg2_Z5clampDv3_fff
881 OpReturnValue %res_Z5clampDv3_fff
882 OpFunctionEnd
883)";
884
885 OS << R"(
886%rs_linker__Z5clampDv3_iii = OpFunction %v3uint Pure %fun_u3_u3uu
887%param_Z5clampDv3_iii0 = OpFunctionParameter %v3uint
888%param_Z5clampDv3_iii1 = OpFunctionParameter %uint
889%param_Z5clampDv3_iii2 = OpFunctionParameter %uint
890%label_Z5clampDv3_iii = OpLabel
891%arg1_Z5clampDv3_iii = OpCompositeConstruct %v3uint %param_Z5clampDv3_iii1 %param_Z5clampDv3_iii1 %param_Z5clampDv3_iii1
892%arg2_Z5clampDv3_iii = OpCompositeConstruct %v3uint %param_Z5clampDv3_iii2 %param_Z5clampDv3_iii2 %param_Z5clampDv3_iii2
893%res_Z5clampDv3_iii = OpExtInst %v3uint %glsl_ext_ins UClamp %param_Z5clampDv3_iii0 %arg1_Z5clampDv3_iii %arg2_Z5clampDv3_iii
894 OpReturnValue %res_Z5clampDv3_iii
895 OpFunctionEnd
896)";
897}
898
899bool ReflectionPass::emitRSAllocFunctions(
900 Module &M, const SmallVectorImpl<RSAllocationInfo> &RSAllocs,
901 const SmallVectorImpl<RSAllocationCallInfo> &RSAllocAccesses) {
902 DEBUG(dbgs() << "emitRSAllocFunctions\n");
903
904 for (const auto &Access : RSAllocAccesses) {
905 solidifyRSAllocAccess(M, Access);
906
907 auto *Fun = Access.FCall->getCalledFunction();
908 if (!Fun)
909 return false;
910
911 const auto FName = Fun->getName();
912 auto *ETMapping = getMappingOrPrintError(Access.RSElementTy);
913 if (!ETMapping)
914 return false;
915
916 const auto ElementTy = ETMapping->SPIRVTy;
917 const std::string LoadName = Access.RSAlloc.VarName + "_load";
918
919 if (Access.Kind == RSAllocAccessKind::GEA)
920 OS << GenerateRSGEA(FName.str().c_str(), ElementTy.c_str(),
921 LoadName.c_str(), Coords::XY);
922 else
923 OS << GenerateRSSEA(FName.str().c_str(), LoadName.c_str(), Coords::XY);
924 }
925
926 return true;
927}
928
929bool ReflectionPass::emitMain(
930 const KernelSignature &Kernel,
931 const SmallVectorImpl<RSAllocationInfo> &RSAllocs) {
932 DEBUG(dbgs() << "emitMain\n");
933
934 const auto *RTMapping = getMappingOrPrintError(Kernel.returnType);
935 const auto *ArgTMapping = getMappingOrPrintError(Kernel.argumentType);
936
937 if (!RTMapping || !ArgTMapping)
938 return false;
939
940 OS << '\n';
I-Jui (Ray) Sung3f1f6a12016-12-07 18:04:08 -0800941 OS << Kernel.getWrapperName();
942 OS << " = OpFunction %void None %fun_void\n";
943 OS << "%label_main = OpLabel\n"
Yang Ni6749f542016-11-07 20:20:49 -0800944 "%input_pixel = OpVariable %ptr_function_access_ty Function\n"
945 " %res = OpVariable %ptr_function_ty Function\n"
946 " %image_load = OpLoad %input_image_ty %input_image\n"
947 "%coords_load = OpLoad %v3uint %global_invocation_id\n"
948 " %coords_x = OpCompositeExtract %uint %coords_load 0\n"
949 " %coords_y = OpCompositeExtract %uint %coords_load 1\n"
950 " %coords_z = OpCompositeExtract %uint %coords_load 2\n"
951 " %shuffled = OpVectorShuffle %v2uint %coords_load %coords_load 0 1\n"
952 " %bitcasted = OpBitcast %v2int %shuffled\n";
953
954 OS << " %image_read = OpImageRead " << ArgTMapping->SPIRVImageReadType
955 << " %image_load %bitcasted\n"
956 " OpStore %input_pixel %image_read\n";
957
958 // TODO: Handle vector types of width different than 4.
959 if (RTMapping->isVectorTy) {
960 OS << " %input_load = OpLoad " << ArgTMapping->SPIRVTy << " %input_pixel\n";
961 } else {
962 OS << "%input_access_chain = OpAccessChain %ptr_function_ty "
963 "%input_pixel %uint_zero\n"
964 << " %input_load = OpLoad " << ArgTMapping->SPIRVTy
965 << " %input_access_chain\n";
966 }
967
968 for (const auto &A : RSAllocs)
969 OS << A.VarName << "_load = OpLoad " << A.VarName << "_image_ty "
970 << A.VarName << "_var\n";
971
972 OS << "%kernel_call = OpFunctionCall " << ArgTMapping->SPIRVTy
973 << " %RS_SPIRV_DUMMY_ %input_load";
974
975 const auto CoordsNum = size_t(Kernel.coordsKind);
976 for (size_t i = 0; i != CoordsNum; ++i)
977 OS << " %coords_" << CoordsNames[i].str();
978
979 OS << '\n';
980
981 OS << " OpStore %res %kernel_call\n"
982 "%output_load = OpLoad %output_image_ty %output_image\n";
983 OS << " %res_load = OpLoad " << RTMapping->SPIRVTy << " %res\n";
984
985 if (!RTMapping->isVectorTy) {
986 OS << "%composite_constructed = OpCompositeConstruct "
987 << RTMapping->SPIRVImageReadType;
988 for (size_t i = 0; i < RTMapping->vectorWidth; ++i)
989 OS << " %res_load";
990
991 OS << "\n"
992 " OpImageWrite %output_load %bitcasted "
993 "%composite_constructed\n";
994
995 } else {
996 OS << " OpImageWrite %output_load %bitcasted %res_load\n";
997 }
998
999 OS << " OpReturn\n"
1000 " OpFunctionEnd\n";
1001
1002 OS << "%RS_SPIRV_DUMMY_ = OpFunction " << RTMapping->SPIRVTy
1003 << " None %kernel_function_ty\n";
1004
1005 OS << " %p = OpFunctionParameter " << ArgTMapping->SPIRVTy << '\n';
1006
1007 for (size_t i = 0; i != CoordsNum; ++i)
1008 OS << " %coords_param_" << CoordsNames[i].str()
1009 << " = OpFunctionParameter %uint\n";
1010
1011 OS << " %11 = OpLabel\n"
1012 " OpReturnValue %p\n"
1013 " OpFunctionEnd\n";
1014
1015 return true;
1016}
1017
1018} // namespace rs2spirv