blob: 57286b53fb29a78443063cb71df057d5762f4d05 [file] [log] [blame]
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001/*
2 * Copyright 2018 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#ifndef SKSL_STANDALONE
9
10#ifdef SK_LLVM_AVAILABLE
11
12#include "SkSLJIT.h"
13
14#include "SkCpu.h"
15#include "SkRasterPipeline.h"
16#include "../jumper/SkJumper.h"
17#include "ir/SkSLExpressionStatement.h"
18#include "ir/SkSLFunctionCall.h"
19#include "ir/SkSLFunctionReference.h"
20#include "ir/SkSLIndexExpression.h"
21#include "ir/SkSLProgram.h"
22#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
23
24static constexpr int MAX_VECTOR_COUNT = 16;
25
26extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
27 p->append((SkRasterPipeline::StockStage) stage, ctx);
28}
29
30#define PTR_SIZE sizeof(void*)
31
32extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
33 p->append(fn, nullptr);
34}
35
36extern "C" void sksl_debug_print(float f) {
37 printf("Debug: %f\n", f);
38}
39
40namespace SkSL {
41
42static constexpr int STAGE_PARAM_COUNT = 12;
43
44static bool ends_with_branch(const Statement& stmt) {
45 switch (stmt.fKind) {
46 case Statement::kBlock_Kind: {
47 const Block& b = (const Block&) stmt;
48 if (b.fStatements.size()) {
49 return ends_with_branch(*b.fStatements.back());
50 }
51 return false;
52 }
53 case Statement::kBreak_Kind: // fall through
54 case Statement::kContinue_Kind: // fall through
55 case Statement::kReturn_Kind: // fall through
56 return true;
57 default:
58 return false;
59 }
60}
61
62JIT::JIT(Compiler* compiler)
63: fCompiler(*compiler) {
64 LLVMInitializeNativeTarget();
65 LLVMInitializeNativeAsmPrinter();
66 LLVMLinkInMCJIT();
Ethan Nicholasd9d33c32018-06-12 11:05:59 -040067 SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
Ethan Nicholas26a9aad2018-03-27 14:10:52 -040068 if (SkCpu::Supports(SkCpu::HSW)) {
69 fVectorCount = 8;
70 fCPU = "haswell";
71 } else if (SkCpu::Supports(SkCpu::AVX)) {
72 fVectorCount = 8;
73 fCPU = "ivybridge";
74 } else {
75 fVectorCount = 4;
76 fCPU = nullptr;
77 }
78 fContext = LLVMContextCreate();
79 fVoidType = LLVMVoidTypeInContext(fContext);
80 fInt1Type = LLVMInt1TypeInContext(fContext);
81 fInt8Type = LLVMInt8TypeInContext(fContext);
82 fInt8PtrType = LLVMPointerType(fInt8Type, 0);
83 fInt32Type = LLVMInt32TypeInContext(fContext);
84 fInt64Type = LLVMInt64TypeInContext(fContext);
85 fSizeTType = LLVMInt64TypeInContext(fContext);
86 fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
87 fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
88 fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
89 fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
90 fFloat32Type = LLVMFloatTypeInContext(fContext);
91 fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
92 fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
93 fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
94 fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
95}
96
97JIT::~JIT() {
98 LLVMOrcDisposeInstance(fJITStack);
99 LLVMContextDispose(fContext);
100}
101
102void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
103 std::vector<LLVMTypeRef> parameters) {
104 for (const auto& pair : *fProgram->fSymbols) {
105 if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
106 const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
107 if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
108 parameters.size() != f.fParameters.size()) {
109 continue;
110 }
111 for (size_t i = 0; i < parameters.size(); ++i) {
112 if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
113 goto next;
114 }
115 }
116 fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
117 parameters.data(),
118 parameters.size(),
119 false));
120 }
121 next:;
122 }
123}
124
125void JIT::loadBuiltinFunctions() {
126 this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
127 this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
128 this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
129 this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
130 this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
131 this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
132}
133
134uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
135 LLVMOrcTargetAddress result;
136 if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
137 if (!strcmp(name, "_sksl_pipeline_append")) {
138 result = (uint64_t) &sksl_pipeline_append;
139 } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
140 result = (uint64_t) &sksl_pipeline_append_callback;
141 } else if (!strcmp(name, "_sksl_debug_print")) {
142 result = (uint64_t) &sksl_debug_print;
143 } else {
144 result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
145 }
146 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400147 SkASSERT(result);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400148 return result;
149}
150
151LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
152 LLVMValueRef func = fFunctions[&fc.fFunction];
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400153 SkASSERT(func);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400154 std::vector<LLVMValueRef> parameters;
155 for (const auto& a : fc.fArguments) {
156 parameters.push_back(this->compileExpression(builder, *a));
157 }
158 return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
159}
160
161LLVMTypeRef JIT::getType(const Type& type) {
162 switch (type.kind()) {
163 case Type::kOther_Kind:
164 if (type.name() == "void") {
165 return fVoidType;
166 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400167 SkASSERT(type.name() == "SkRasterPipeline");
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400168 return fInt8PtrType;
169 case Type::kScalar_Kind:
170 if (type.isSigned() || type.isUnsigned()) {
171 return fInt32Type;
172 }
173 if (type.isUnsigned()) {
174 return fInt32Type;
175 }
176 if (type.isFloat()) {
177 return fFloat32Type;
178 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400179 SkASSERT(type.name() == "bool");
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400180 return fInt1Type;
181 case Type::kArray_Kind:
182 return LLVMPointerType(this->getType(type.componentType()), 0);
183 case Type::kVector_Kind:
184 if (type.name() == "float2" || type.name() == "half2") {
185 return fFloat32Vector2Type;
186 }
187 if (type.name() == "float3" || type.name() == "half3") {
188 return fFloat32Vector3Type;
189 }
190 if (type.name() == "float4" || type.name() == "half4") {
191 return fFloat32Vector4Type;
192 }
193 if (type.name() == "int2" || type.name() == "short2") {
194 return fInt32Vector2Type;
195 }
196 if (type.name() == "int3" || type.name() == "short3") {
197 return fInt32Vector3Type;
198 }
199 if (type.name() == "int4" || type.name() == "short4") {
200 return fInt32Vector4Type;
201 }
202 // fall through
203 default:
204 ABORT("unsupported type");
205 }
206}
207
208void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
209 fCurrentBlock = block;
210 LLVMPositionBuilderAtEnd(builder, block);
211}
212
213std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
214 switch (expr.fKind) {
215 case Expression::kVariableReference_Kind: {
216 class PointerLValue : public LValue {
217 public:
218 PointerLValue(LLVMValueRef ptr)
219 : fPointer(ptr) {}
220
221 LLVMValueRef load(LLVMBuilderRef builder) override {
222 return LLVMBuildLoad(builder, fPointer, "lvalue load");
223 }
224
225 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
226 LLVMBuildStore(builder, value, fPointer);
227 }
228
229 private:
230 LLVMValueRef fPointer;
231 };
232 const Variable* var = &((VariableReference&) expr).fVariable;
233 if (var->fStorage == Variable::kParameter_Storage &&
234 !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
235 fPromotedParameters.find(var) == fPromotedParameters.end()) {
236 // promote parameter to variable
237 fPromotedParameters.insert(var);
238 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
239 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
240 String(var->fName).c_str());
241 LLVMBuildStore(builder, fVariables[var], alloca);
242 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
243 fVariables[var] = alloca;
244 }
245 LLVMValueRef ptr = fVariables[var];
246 return std::unique_ptr<LValue>(new PointerLValue(ptr));
247 }
248 case Expression::kTernary_Kind: {
249 class TernaryLValue : public LValue {
250 public:
251 TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
252 std::unique_ptr<LValue> ifFalse)
253 : fJIT(*jit)
254 , fTest(test)
255 , fIfTrue(std::move(ifTrue))
256 , fIfFalse(std::move(ifFalse)) {}
257
258 LLVMValueRef load(LLVMBuilderRef builder) override {
259 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
260 fJIT.fContext,
261 fJIT.fCurrentFunction,
262 "true ? ...");
263 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
264 fJIT.fContext,
265 fJIT.fCurrentFunction,
266 "false ? ...");
267 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
268 fJIT.fCurrentFunction,
269 "ternary merge");
270 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
271 fJIT.setBlock(builder, trueBlock);
272 LLVMValueRef ifTrue = fIfTrue->load(builder);
273 LLVMBuildBr(builder, merge);
274 fJIT.setBlock(builder, falseBlock);
275 LLVMValueRef ifFalse = fIfTrue->load(builder);
276 LLVMBuildBr(builder, merge);
277 fJIT.setBlock(builder, merge);
278 LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
279 LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
280 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
281 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
282 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
283 return phi;
284 }
285
286 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
287 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
288 fJIT.fContext,
289 fJIT.fCurrentFunction,
290 "true ? ...");
291 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
292 fJIT.fContext,
293 fJIT.fCurrentFunction,
294 "false ? ...");
295 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
296 fJIT.fCurrentFunction,
297 "ternary merge");
298 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
299 fJIT.setBlock(builder, trueBlock);
300 fIfTrue->store(builder, value);
301 LLVMBuildBr(builder, merge);
302 fJIT.setBlock(builder, falseBlock);
303 fIfTrue->store(builder, value);
304 LLVMBuildBr(builder, merge);
305 fJIT.setBlock(builder, merge);
306 }
307
308 private:
309 JIT& fJIT;
310 LLVMValueRef fTest;
311 std::unique_ptr<LValue> fIfTrue;
312 std::unique_ptr<LValue> fIfFalse;
313 };
314 const TernaryExpression& t = (const TernaryExpression&) expr;
315 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
316 return std::unique_ptr<LValue>(new TernaryLValue(this,
317 test,
318 this->getLValue(builder,
319 *t.fIfTrue),
320 this->getLValue(builder,
321 *t.fIfFalse)));
322 }
323 case Expression::kSwizzle_Kind: {
324 class SwizzleLValue : public LValue {
325 public:
326 SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
327 std::vector<int> components)
328 : fJIT(*jit)
329 , fType(type)
330 , fBase(std::move(base))
331 , fComponents(components) {}
332
333 LLVMValueRef load(LLVMBuilderRef builder) override {
334 LLVMValueRef base = fBase->load(builder);
335 if (fComponents.size() > 1) {
336 LLVMValueRef result = LLVMGetUndef(fType);
337 for (size_t i = 0; i < fComponents.size(); ++i) {
338 LLVMValueRef element = LLVMBuildExtractElement(
339 builder,
340 base,
341 LLVMConstInt(fJIT.fInt32Type,
342 fComponents[i],
343 false),
344 "swizzle extract");
345 result = LLVMBuildInsertElement(builder, result, element,
346 LLVMConstInt(fJIT.fInt32Type, i, false),
347 "swizzle insert");
348 }
349 return result;
350 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400351 SkASSERT(fComponents.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400352 return LLVMBuildExtractElement(builder, base,
353 LLVMConstInt(fJIT.fInt32Type,
354 fComponents[0],
355 false),
356 "swizzle extract");
357 }
358
359 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
360 LLVMValueRef result = fBase->load(builder);
361 if (fComponents.size() > 1) {
362 for (size_t i = 0; i < fComponents.size(); ++i) {
363 LLVMValueRef element = LLVMBuildExtractElement(builder, value,
364 LLVMConstInt(
365 fJIT.fInt32Type,
366 i,
367 false),
368 "swizzle extract");
369 result = LLVMBuildInsertElement(builder, result, element,
370 LLVMConstInt(fJIT.fInt32Type,
371 fComponents[i],
372 false),
373 "swizzle insert");
374 }
375 } else {
376 result = LLVMBuildInsertElement(builder, result, value,
377 LLVMConstInt(fJIT.fInt32Type,
378 fComponents[0],
379 false),
380 "swizzle insert");
381 }
382 fBase->store(builder, result);
383 }
384
385 private:
386 JIT& fJIT;
387 LLVMTypeRef fType;
388 std::unique_ptr<LValue> fBase;
389 std::vector<int> fComponents;
390 };
391 const Swizzle& s = (const Swizzle&) expr;
392 return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
393 this->getLValue(builder, *s.fBase),
394 s.fComponents));
395 }
396 default:
397 ABORT("unsupported lvalue");
398 }
399}
400
401JIT::TypeKind JIT::typeKind(const Type& type) {
402 if (type.kind() == Type::kVector_Kind) {
403 return this->typeKind(type.componentType());
404 }
405 if (type.fName == "int" || type.fName == "short") {
406 return JIT::kInt_TypeKind;
407 } else if (type.fName == "uint" || type.fName == "ushort") {
408 return JIT::kUInt_TypeKind;
409 } else if (type.fName == "float" || type.fName == "double") {
410 return JIT::kFloat_TypeKind;
411 }
412 ABORT("unsupported type: %s\n", type.description().c_str());
413}
414
415void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
416 LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
417 for (int i = 0; i < columns; ++i) {
418 result = LLVMBuildInsertElement(builder,
419 result,
420 *value,
421 LLVMConstInt(fInt32Type, i, false),
422 "vectorize");
423 }
424 *value = result;
425}
426
427void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
428 LLVMValueRef* right) {
429 if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
430 b.fRight->fType.kind() == Type::kVector_Kind) {
431 this->vectorize(builder, left, b.fRight->fType.columns());
432 } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
433 b.fRight->fType.kind() == Type::kScalar_Kind) {
434 this->vectorize(builder, right, b.fLeft->fType.columns());
435 }
436}
437
438
439LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
440 #define BINARY(SFunc, UFunc, FFunc) { \
441 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
442 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
443 this->vectorize(builder, b, &left, &right); \
444 switch (this->typeKind(b.fLeft->fType)) { \
445 case kInt_TypeKind: \
446 return SFunc(builder, left, right, "binary"); \
447 case kUInt_TypeKind: \
448 return UFunc(builder, left, right, "binary"); \
449 case kFloat_TypeKind: \
450 return FFunc(builder, left, right, "binary"); \
451 default: \
452 ABORT("unsupported typeKind"); \
453 } \
454 }
455 #define COMPOUND(SFunc, UFunc, FFunc) { \
456 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
457 LLVMValueRef left = lvalue->load(builder); \
458 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
459 this->vectorize(builder, b, &left, &right); \
460 LLVMValueRef result; \
461 switch (this->typeKind(b.fLeft->fType)) { \
462 case kInt_TypeKind: \
463 result = SFunc(builder, left, right, "binary"); \
464 break; \
465 case kUInt_TypeKind: \
466 result = UFunc(builder, left, right, "binary"); \
467 break; \
468 case kFloat_TypeKind: \
469 result = FFunc(builder, left, right, "binary"); \
470 break; \
471 default: \
472 ABORT("unsupported typeKind"); \
473 } \
474 lvalue->store(builder, result); \
475 return result; \
476 }
477 #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
478 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
479 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
480 this->vectorize(builder, b, &left, &right); \
481 switch (this->typeKind(b.fLeft->fType)) { \
482 case kInt_TypeKind: \
483 return SFunc(builder, SOp, left, right, "binary"); \
484 case kUInt_TypeKind: \
485 return UFunc(builder, UOp, left, right, "binary"); \
486 case kFloat_TypeKind: \
487 return FFunc(builder, FOp, left, right, "binary"); \
488 default: \
489 ABORT("unsupported typeKind"); \
490 } \
491 }
492 switch (b.fOperator) {
493 case Token::EQ: {
494 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
495 LLVMValueRef result = this->compileExpression(builder, *b.fRight);
496 lvalue->store(builder, result);
497 return result;
498 }
499 case Token::PLUS:
500 BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
501 case Token::MINUS:
502 BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
503 case Token::STAR:
504 BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
505 case Token::SLASH:
506 BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
507 case Token::PERCENT:
508 BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
509 case Token::BITWISEAND:
510 BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
511 case Token::BITWISEOR:
512 BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
513 case Token::PLUSEQ:
514 COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
515 case Token::MINUSEQ:
516 COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
517 case Token::STAREQ:
518 COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
519 case Token::SLASHEQ:
520 COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
521 case Token::BITWISEANDEQ:
522 COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
523 case Token::BITWISEOREQ:
524 COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
525 case Token::EQEQ:
526 COMPARE(LLVMBuildICmp, LLVMIntEQ,
527 LLVMBuildICmp, LLVMIntEQ,
528 LLVMBuildFCmp, LLVMRealOEQ);
529 case Token::NEQ:
530 COMPARE(LLVMBuildICmp, LLVMIntNE,
531 LLVMBuildICmp, LLVMIntNE,
532 LLVMBuildFCmp, LLVMRealONE);
533 case Token::LT:
534 COMPARE(LLVMBuildICmp, LLVMIntSLT,
535 LLVMBuildICmp, LLVMIntULT,
536 LLVMBuildFCmp, LLVMRealOLT);
537 case Token::LTEQ:
538 COMPARE(LLVMBuildICmp, LLVMIntSLE,
539 LLVMBuildICmp, LLVMIntULE,
540 LLVMBuildFCmp, LLVMRealOLE);
541 case Token::GT:
542 COMPARE(LLVMBuildICmp, LLVMIntSGT,
543 LLVMBuildICmp, LLVMIntUGT,
544 LLVMBuildFCmp, LLVMRealOGT);
545 case Token::GTEQ:
546 COMPARE(LLVMBuildICmp, LLVMIntSGE,
547 LLVMBuildICmp, LLVMIntUGE,
548 LLVMBuildFCmp, LLVMRealOGE);
549 case Token::LOGICALAND: {
550 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
551 LLVMBasicBlockRef ifFalse = fCurrentBlock;
552 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
553 "true && ...");
554 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
555 "&& merge");
556 LLVMBuildCondBr(builder, left, ifTrue, merge);
557 this->setBlock(builder, ifTrue);
558 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
559 LLVMBuildBr(builder, merge);
560 this->setBlock(builder, merge);
561 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
562 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
563 LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
564 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
565 return phi;
566 }
567 case Token::LOGICALOR: {
568 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
569 LLVMBasicBlockRef ifTrue = fCurrentBlock;
570 LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
571 "false || ...");
572 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
573 "|| merge");
574 LLVMBuildCondBr(builder, left, merge, ifFalse);
575 this->setBlock(builder, ifFalse);
576 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
577 LLVMBuildBr(builder, merge);
578 this->setBlock(builder, merge);
579 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
580 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
581 LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
582 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
583 return phi;
584 }
585 default:
586 ABORT("unsupported binary operator");
587 }
588}
589
590LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
591 LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
592 LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
593 LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
594 return LLVMBuildLoad(builder, ptr, "index load");
595}
596
597LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
598 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
599 LLVMValueRef result = lvalue->load(builder);
600 LLVMValueRef mod;
601 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
602 switch (p.fOperator) {
603 case Token::PLUSPLUS:
604 switch (this->typeKind(p.fType)) {
605 case kInt_TypeKind: // fall through
606 case kUInt_TypeKind:
607 mod = LLVMBuildAdd(builder, result, one, "++");
608 break;
609 case kFloat_TypeKind:
610 mod = LLVMBuildFAdd(builder, result, one, "++");
611 break;
612 default:
613 ABORT("unsupported typeKind");
614 }
615 break;
616 case Token::MINUSMINUS:
617 switch (this->typeKind(p.fType)) {
618 case kInt_TypeKind: // fall through
619 case kUInt_TypeKind:
620 mod = LLVMBuildSub(builder, result, one, "--");
621 break;
622 case kFloat_TypeKind:
623 mod = LLVMBuildFSub(builder, result, one, "--");
624 break;
625 default:
626 ABORT("unsupported typeKind");
627 }
628 break;
629 default:
630 ABORT("unsupported postfix op");
631 }
632 lvalue->store(builder, mod);
633 return result;
634}
635
636LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
637 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
638 if (Token::LOGICALNOT == p.fOperator) {
639 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
640 return LLVMBuildXor(builder, base, one, "!");
641 }
642 if (Token::MINUS == p.fOperator) {
643 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
644 return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
645 }
646 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
647 LLVMValueRef raw = lvalue->load(builder);
648 LLVMValueRef result;
649 switch (p.fOperator) {
650 case Token::PLUSPLUS:
651 switch (this->typeKind(p.fType)) {
652 case kInt_TypeKind: // fall through
653 case kUInt_TypeKind:
654 result = LLVMBuildAdd(builder, raw, one, "++");
655 break;
656 case kFloat_TypeKind:
657 result = LLVMBuildFAdd(builder, raw, one, "++");
658 break;
659 default:
660 ABORT("unsupported typeKind");
661 }
662 break;
663 case Token::MINUSMINUS:
664 switch (this->typeKind(p.fType)) {
665 case kInt_TypeKind: // fall through
666 case kUInt_TypeKind:
667 result = LLVMBuildSub(builder, raw, one, "--");
668 break;
669 case kFloat_TypeKind:
670 result = LLVMBuildFSub(builder, raw, one, "--");
671 break;
672 default:
673 ABORT("unsupported typeKind");
674 }
675 break;
676 default:
677 ABORT("unsupported prefix op");
678 }
679 lvalue->store(builder, result);
680 return result;
681}
682
683LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
684 const Variable& var = v.fVariable;
685 if (Variable::kParameter_Storage == var.fStorage &&
686 !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
687 fPromotedParameters.find(&var) == fPromotedParameters.end()) {
688 return fVariables[&var];
689 }
690 return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
691}
692
693void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400694 SkASSERT(a.fArguments.size() >= 1);
695 SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400696 LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
697 LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
698 switch (a.fStage) {
699 case SkRasterPipeline::callback: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400700 SkASSERT(a.fArguments.size() == 2);
701 SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400702 const FunctionDeclaration& functionDecl =
703 *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
704 bool found = false;
705 for (const auto& pe : fProgram->fElements) {
706 if (ProgramElement::kFunction_Kind == pe->fKind) {
707 const FunctionDefinition& def = (const FunctionDefinition&) *pe;
708 if (&def.fDeclaration == &functionDecl) {
709 LLVMValueRef fn = this->compileStageFunction(def);
710 LLVMValueRef args[2] = {
711 pipeline,
712 LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
713 };
714 LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
715 found = true;
716 break;
717 }
718 }
719 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400720 SkASSERT(found);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400721 break;
722 }
723 default: {
724 LLVMValueRef ctx;
725 if (a.fArguments.size() == 2) {
726 ctx = this->compileExpression(builder, *a.fArguments[1]);
727 ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
728 } else {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400729 SkASSERT(a.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400730 ctx = LLVMConstNull(fInt8PtrType);
731 }
732 LLVMValueRef args[3] = {
733 pipeline,
734 stage,
735 ctx
736 };
737 LLVMBuildCall(builder, fAppendFunc, args, 3, "");
738 break;
739 }
740 }
741}
742
743LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
744 switch (c.fType.kind()) {
745 case Type::kScalar_Kind: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400746 SkASSERT(c.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400747 TypeKind from = this->typeKind(c.fArguments[0]->fType);
748 TypeKind to = this->typeKind(c.fType);
749 LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
750 if (kFloat_TypeKind == to) {
751 if (kInt_TypeKind == from) {
752 return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
753 }
754 if (kUInt_TypeKind == from) {
755 return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
756 }
757 }
758 if (kInt_TypeKind == to) {
759 if (kFloat_TypeKind == from) {
760 return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
761 }
762 if (kUInt_TypeKind == from) {
763 return base;
764 }
765 }
766 if (kUInt_TypeKind == to) {
767 if (kFloat_TypeKind == from) {
768 return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
769 }
770 if (kInt_TypeKind == from) {
771 return base;
772 }
773 }
774 ABORT("unsupported constructor");
775 }
776 case Type::kVector_Kind: {
777 LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
778 if (c.fArguments.size() == 1) {
779 LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
780 for (int i = 0; i < c.fType.columns(); ++i) {
781 vec = LLVMBuildInsertElement(builder, vec, value,
782 LLVMConstInt(fInt32Type, i, false),
783 "vec build");
784 }
785 } else {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400786 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400787 for (int i = 0; i < c.fType.columns(); ++i) {
788 vec = LLVMBuildInsertElement(builder, vec,
789 this->compileExpression(builder,
790 *c.fArguments[i]),
791 LLVMConstInt(fInt32Type, i, false),
792 "vec build");
793 }
794 }
795 return vec;
796 }
797 default:
798 break;
799 }
800 ABORT("unsupported constructor");
801}
802
803LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
804 LLVMValueRef base = this->compileExpression(builder, *s.fBase);
805 if (s.fComponents.size() > 1) {
806 LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
807 for (size_t i = 0; i < s.fComponents.size(); ++i) {
808 LLVMValueRef element = LLVMBuildExtractElement(
809 builder,
810 base,
811 LLVMConstInt(fInt32Type,
812 s.fComponents[i],
813 false),
814 "swizzle extract");
815 result = LLVMBuildInsertElement(builder, result, element,
816 LLVMConstInt(fInt32Type, i, false),
817 "swizzle insert");
818 }
819 return result;
820 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400821 SkASSERT(s.fComponents.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400822 return LLVMBuildExtractElement(builder, base,
823 LLVMConstInt(fInt32Type,
824 s.fComponents[0],
825 false),
826 "swizzle extract");
827}
828
829LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
830 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
831 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
832 "if true");
833 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
834 "if merge");
835 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
836 "if false");
837 LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
838 this->setBlock(builder, trueBlock);
839 LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
840 trueBlock = fCurrentBlock;
841 LLVMBuildBr(builder, merge);
842 this->setBlock(builder, falseBlock);
843 LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
844 falseBlock = fCurrentBlock;
845 LLVMBuildBr(builder, merge);
846 this->setBlock(builder, merge);
847 LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
848 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
849 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
850 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
851 return phi;
852}
853
854LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
855 switch (expr.fKind) {
856 case Expression::kAppendStage_Kind: {
857 this->appendStage(builder, (const AppendStage&) expr);
858 return LLVMValueRef();
859 }
860 case Expression::kBinary_Kind:
861 return this->compileBinary(builder, (BinaryExpression&) expr);
862 case Expression::kBoolLiteral_Kind:
863 return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
864 case Expression::kConstructor_Kind:
865 return this->compileConstructor(builder, (Constructor&) expr);
866 case Expression::kIntLiteral_Kind:
867 return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
868 case Expression::kFieldAccess_Kind:
869 abort();
870 case Expression::kFloatLiteral_Kind:
871 return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
872 case Expression::kFunctionCall_Kind:
873 return this->compileFunctionCall(builder, (FunctionCall&) expr);
874 case Expression::kIndex_Kind:
875 return this->compileIndex(builder, (IndexExpression&) expr);
876 case Expression::kPrefix_Kind:
877 return this->compilePrefix(builder, (PrefixExpression&) expr);
878 case Expression::kPostfix_Kind:
879 return this->compilePostfix(builder, (PostfixExpression&) expr);
880 case Expression::kSetting_Kind:
881 abort();
882 case Expression::kSwizzle_Kind:
883 return this->compileSwizzle(builder, (Swizzle&) expr);
884 case Expression::kVariableReference_Kind:
885 return this->compileVariableReference(builder, (VariableReference&) expr);
886 case Expression::kTernary_Kind:
887 return this->compileTernary(builder, (TernaryExpression&) expr);
888 case Expression::kTypeReference_Kind:
889 abort();
890 default:
891 abort();
892 }
893 ABORT("unsupported expression: %s\n", expr.description().c_str());
894}
895
896void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
897 for (const auto& stmt : block.fStatements) {
898 this->compileStatement(builder, *stmt);
899 }
900}
901
902void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
903 for (const auto& declStatement : decls.fDeclaration->fVars) {
904 const VarDeclaration& decl = (VarDeclaration&) *declStatement;
905 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
906 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
907 String(decl.fVar->fName).c_str());
908 fVariables[decl.fVar] = alloca;
909 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
910 if (decl.fValue) {
911 LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
912 LLVMBuildStore(builder, result, alloca);
913 }
914 }
915}
916
917void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
918 LLVMValueRef test = this->compileExpression(builder, *i.fTest);
919 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
920 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
921 "if merge");
922 LLVMBasicBlockRef ifFalse;
923 if (i.fIfFalse) {
924 ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
925 } else {
926 ifFalse = merge;
927 }
928 LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
929 this->setBlock(builder, ifTrue);
930 this->compileStatement(builder, *i.fIfTrue);
931 if (!ends_with_branch(*i.fIfTrue)) {
932 LLVMBuildBr(builder, merge);
933 }
934 if (i.fIfFalse) {
935 this->setBlock(builder, ifFalse);
936 this->compileStatement(builder, *i.fIfFalse);
937 if (!ends_with_branch(*i.fIfFalse)) {
938 LLVMBuildBr(builder, merge);
939 }
940 }
941 this->setBlock(builder, merge);
942}
943
944void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
945 if (f.fInitializer) {
946 this->compileStatement(builder, *f.fInitializer);
947 }
948 LLVMBasicBlockRef start;
949 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
950 LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
951 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
952 if (f.fTest) {
953 start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
954 LLVMBuildBr(builder, start);
955 this->setBlock(builder, start);
956 LLVMValueRef test = this->compileExpression(builder, *f.fTest);
957 LLVMBuildCondBr(builder, test, body, end);
958 } else {
959 start = body;
960 LLVMBuildBr(builder, body);
961 }
962 this->setBlock(builder, body);
963 fBreakTarget.push_back(end);
964 fContinueTarget.push_back(next);
965 this->compileStatement(builder, *f.fStatement);
966 fBreakTarget.pop_back();
967 fContinueTarget.pop_back();
968 if (!ends_with_branch(*f.fStatement)) {
969 LLVMBuildBr(builder, next);
970 }
971 this->setBlock(builder, next);
972 if (f.fNext) {
973 this->compileExpression(builder, *f.fNext);
974 }
975 LLVMBuildBr(builder, start);
976 this->setBlock(builder, end);
977}
978
979void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
980 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
981 "do test");
982 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
983 "do body");
984 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
985 "do end");
986 LLVMBuildBr(builder, body);
987 this->setBlock(builder, testBlock);
988 LLVMValueRef test = this->compileExpression(builder, *d.fTest);
989 LLVMBuildCondBr(builder, test, body, end);
990 this->setBlock(builder, body);
991 fBreakTarget.push_back(end);
992 fContinueTarget.push_back(body);
993 this->compileStatement(builder, *d.fStatement);
994 fBreakTarget.pop_back();
995 fContinueTarget.pop_back();
996 if (!ends_with_branch(*d.fStatement)) {
997 LLVMBuildBr(builder, testBlock);
998 }
999 this->setBlock(builder, end);
1000}
1001
1002void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
1003 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1004 "while test");
1005 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1006 "while body");
1007 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1008 "while end");
1009 LLVMBuildBr(builder, testBlock);
1010 this->setBlock(builder, testBlock);
1011 LLVMValueRef test = this->compileExpression(builder, *w.fTest);
1012 LLVMBuildCondBr(builder, test, body, end);
1013 this->setBlock(builder, body);
1014 fBreakTarget.push_back(end);
1015 fContinueTarget.push_back(testBlock);
1016 this->compileStatement(builder, *w.fStatement);
1017 fBreakTarget.pop_back();
1018 fContinueTarget.pop_back();
1019 if (!ends_with_branch(*w.fStatement)) {
1020 LLVMBuildBr(builder, testBlock);
1021 }
1022 this->setBlock(builder, end);
1023}
1024
1025void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
1026 LLVMBuildBr(builder, fBreakTarget.back());
1027}
1028
1029void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
1030 LLVMBuildBr(builder, fContinueTarget.back());
1031}
1032
1033void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
1034 if (r.fExpression) {
1035 LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
1036 } else {
1037 LLVMBuildRetVoid(builder);
1038 }
1039}
1040
1041void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
1042 switch (stmt.fKind) {
1043 case Statement::kBlock_Kind:
1044 this->compileBlock(builder, (Block&) stmt);
1045 break;
1046 case Statement::kBreak_Kind:
1047 this->compileBreak(builder, (BreakStatement&) stmt);
1048 break;
1049 case Statement::kContinue_Kind:
1050 this->compileContinue(builder, (ContinueStatement&) stmt);
1051 break;
1052 case Statement::kDiscard_Kind:
1053 abort();
1054 case Statement::kDo_Kind:
1055 this->compileDo(builder, (DoStatement&) stmt);
1056 break;
1057 case Statement::kExpression_Kind:
1058 this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
1059 break;
1060 case Statement::kFor_Kind:
1061 this->compileFor(builder, (ForStatement&) stmt);
1062 break;
1063 case Statement::kGroup_Kind:
1064 abort();
1065 case Statement::kIf_Kind:
1066 this->compileIf(builder, (IfStatement&) stmt);
1067 break;
1068 case Statement::kNop_Kind:
1069 break;
1070 case Statement::kReturn_Kind:
1071 this->compileReturn(builder, (ReturnStatement&) stmt);
1072 break;
1073 case Statement::kSwitch_Kind:
1074 abort();
1075 case Statement::kVarDeclarations_Kind:
1076 this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
1077 break;
1078 case Statement::kWhile_Kind:
1079 this->compileWhile(builder, (WhileStatement&) stmt);
1080 break;
1081 default:
1082 abort();
1083 }
1084}
1085
// Fallback compilation of a stage function into newFunc. It is used when
// compileStageFunctionVector cannot vectorize the body. The emitted code runs the SkSL body once
// per pixel, for i = 0 .. fVectorCount - 1:
//   alloca: spill the incoming r/g/b/a channel vectors (params 4-7) to stack slots, create the
//           float4 "color" slot and the counter i (= 0); a branch to "start" is appended last.
//   start:  x = trunc(params[2]) + i; leave the loop once i >= fVectorCount.
//   body:   gather lane i of each incoming channel into "color", run the SkSL body, scatter
//           "color" back into lane i of the channel slots, increment i, and branch to "start".
//   end:    load the next stage from *program, advance program by PTR_SIZE bytes, call the next
//           stage with the updated channels (other params forwarded unchanged), and return void.
// The stage's SkSL parameters are bound in fVariables. Parameter 0 gets the per-lane x value and
// parameter 1 gets trunc(params[3]) (both int32). Parameter 2 gets the "color" slot.
// fCurrentFunction, fAllocaBlock and fCurrentBlock are restored before returning.
// NOTE(review): params[0] is only forwarded; it does not bound the loop, so all fVectorCount
// lanes are always processed. Confirm that this is intended for partial (tail) runs.
// NOTE(review): a `return` inside f.fBody would terminate the body block before the write-back
// code below is appended. Confirm that stage bodies cannot contain one.
void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
    // loop over fVectorCount pixels, running the body of the stage function for each of them
    LLVMValueRef oldFunction = fCurrentFunction;
    fCurrentFunction = newFunc;
    std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
    LLVMGetParams(fCurrentFunction, params.get());
    // params[1] is the program pointer; *program holds the next stage's function pointer
    LLVMValueRef programParam = params.get()[1];
    LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
    LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
    LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
    fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
    this->setBlock(builder, fAllocaBlock);
    // temporaries to store the color channel vectors
    LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
    LLVMBuildStore(builder, params.get()[4], rVec);
    LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
    LLVMBuildStore(builder, params.get()[5], gVec);
    LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
    LLVMBuildStore(builder, params.get()[6], bVec);
    LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
    LLVMBuildStore(builder, params.get()[7], aVec);
    // per-pixel float4 that the SkSL body sees as its third (color) parameter
    LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
    // y is the same for every lane, so it is computed once in the alloca block
    fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
                                                               "y->Int32");
    fVariables[f.fDeclaration.fParameters[2]] = color;
    // loop counter i, starting at 0
    LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
    LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
    // loop header: compute x for this lane and test i < fVectorCount
    LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
    this->setBlock(builder, start);
    LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
    fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
                                                             LLVMBuildTrunc(builder,
                                                                            params.get()[2],
                                                                            fInt32Type,
                                                                            "x->Int32"),
                                                             iload,
                                                             "x");
    LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
    LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
    LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
    LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
    LLVMBuildCondBr(builder, test, loopBody, loopEnd);
    this->setBlock(builder, loopBody);
    LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
    // extract the r, g, b, and a values from the color channel vectors and store them into "color"
    // (lane i of the incoming params; only lane i of the slots is modified during iteration i)
    for (int i = 0; i < 4; ++i) {
        vec = LLVMBuildInsertElement(builder, vec,
                                     LLVMBuildExtractElement(builder,
                                                             params.get()[4 + i],
                                                             iload, "initial"),
                                     LLVMConstInt(fInt32Type, i, false),
                                     "vec build");
    }
    LLVMBuildStore(builder, vec, color);
    // write actual loop body
    this->compileStatement(builder, *f.fBody);
    // extract the r, g, b, and a values from "color" and stick them back into the color channel
    // vectors
    LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 0,
                                                                               false),
                                                                  "rExtract"),
                                          iload, "rInsert"),
                   rVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 1,
                                                                               false),
                                                                  "gExtract"),
                                          iload, "gInsert"),
                   gVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 2,
                                                                               false),
                                                                  "bExtract"),
                                          iload, "bInsert"),
                   bVec);
    LLVMBuildStore(builder,
                   LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
                                          LLVMBuildExtractElement(builder, colorLoad,
                                                                  LLVMConstInt(fInt32Type, 3,
                                                                               false),
                                                                  "aExtract"),
                                          iload, "aInsert"),
                   aVec);
    // ++i and loop back to the header
    LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
    LLVMBuildStore(builder, inc, ivar);
    LLVMBuildBr(builder, start);
    this->setBlock(builder, loopEnd);
    // increment program pointer, call the next stage
    LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
    LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
    LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
    // program + PTR_SIZE bytes, computed through a 64-bit integer
    LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
                                             LLVMBuildAdd(builder,
                                                          LLVMBuildPtrToInt(builder,
                                                                            programParam,
                                                                            fInt64Type,
                                                                            "cast 1"),
                                                          LLVMConstInt(fInt64Type, PTR_SIZE, false),
                                                          "add"),
                                             LLVMPointerType(fInt8PtrType, 0), "cast 2");
    // same argument list as received, except the advanced program pointer and updated r/g/b/a
    LLVMValueRef args[STAGE_PARAM_COUNT] = {
        params.get()[0],
        nextInc,
        params.get()[2],
        params.get()[3],
        LLVMBuildLoad(builder, rVec, "rVec"),
        LLVMBuildLoad(builder, gVec, "gVec"),
        LLVMBuildLoad(builder, bVec, "bVec"),
        LLVMBuildLoad(builder, aVec, "aVec"),
        params.get()[8],
        params.get()[9],
        params.get()[10],
        params.get()[11]
    };
    LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
    LLVMBuildRetVoid(builder);
    // finish: terminate the alloca block by falling through into the loop header
    LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
    LLVMBuildBr(builder, start);
    LLVMDisposeBuilder(builder);
    if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
        ABORT("verify failed\n");
    }
    fAllocaBlock = oldAllocaBlock;
    fCurrentBlock = oldCurrentBlock;
    fCurrentFunction = oldFunction;
}
1221
// FIXME: consider pluggable code generators. The scalar (per-pixel) codegen and the vector
// codegen should be separated, and this class broken up into multiple classes.
1225
1226bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
1227 LLVMValueRef out[CHANNELS]) {
1228 switch (e.fKind) {
1229 case Expression::kVariableReference_Kind:
1230 if (fColorParam == &((VariableReference&) e).fVariable) {
1231 memcpy(out, fChannels, sizeof(fChannels));
1232 return true;
1233 }
1234 return false;
1235 case Expression::kSwizzle_Kind: {
1236 const Swizzle& s = (const Swizzle&) e;
1237 LLVMValueRef base[CHANNELS];
1238 if (!this->getVectorLValue(builder, *s.fBase, base)) {
1239 return false;
1240 }
1241 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1242 out[i] = base[s.fComponents[i]];
1243 }
1244 return true;
1245 }
1246 default:
1247 return false;
1248 }
1249}
1250
1251bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
1252 LLVMValueRef outLeft[CHANNELS], const Expression& right,
1253 LLVMValueRef outRight[CHANNELS]) {
1254 if (!this->compileVectorExpression(builder, left, outLeft)) {
1255 return false;
1256 }
1257 int leftColumns = left.fType.columns();
1258 int rightColumns = right.fType.columns();
1259 if (leftColumns == 1 && rightColumns > 1) {
1260 for (int i = 1; i < rightColumns; ++i) {
1261 outLeft[i] = outLeft[0];
1262 }
1263 }
1264 if (!this->compileVectorExpression(builder, right, outRight)) {
1265 return false;
1266 }
1267 if (rightColumns == 1 && leftColumns > 1) {
1268 for (int i = 1; i < leftColumns; ++i) {
1269 outRight[i] = outRight[0];
1270 }
1271 }
1272 return true;
1273}
1274
1275bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
1276 LLVMValueRef out[CHANNELS]) {
1277 LLVMValueRef left[CHANNELS];
1278 LLVMValueRef right[CHANNELS];
1279 #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \
1280 if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
1281 return false; \
1282 } \
1283 for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \
1284 switch (this->typeKind(b.fLeft->fType)) { \
1285 case kInt_TypeKind: \
1286 out[i] = signedOp(builder, left[i], right[i], "binary"); \
1287 break; \
1288 case kUInt_TypeKind: \
1289 out[i] = unsignedOp(builder, left[i], right[i], "binary"); \
1290 break; \
1291 case kFloat_TypeKind: \
1292 out[i] = floatOp(builder, left[i], right[i], "binary"); \
1293 break; \
1294 case kBool_TypeKind: \
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001295 SkASSERT(false); \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001296 break; \
1297 } \
1298 } \
1299 return true; \
1300 }
1301 switch (b.fOperator) {
1302 case Token::EQ: {
1303 if (!this->getVectorLValue(builder, *b.fLeft, left)) {
1304 return false;
1305 }
1306 if (!this->compileVectorExpression(builder, *b.fRight, right)) {
1307 return false;
1308 }
1309 int columns = b.fRight->fType.columns();
1310 for (int i = 0; i < columns; ++i) {
1311 LLVMBuildStore(builder, right[i], left[i]);
1312 }
1313 return true;
1314 }
1315 case Token::PLUS:
1316 VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
1317 case Token::MINUS:
1318 VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
1319 case Token::STAR:
1320 VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
1321 case Token::SLASH:
1322 VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
1323 case Token::PERCENT:
1324 VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
1325 case Token::BITWISEAND:
1326 VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
1327 case Token::BITWISEOR:
1328 VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
1329 default:
1330 printf("unsupported operator: %s\n", b.description().c_str());
1331 return false;
1332 }
1333}
1334
1335bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
1336 LLVMValueRef out[CHANNELS]) {
1337 switch (c.fType.kind()) {
1338 case Type::kScalar_Kind: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001339 SkASSERT(c.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001340 TypeKind from = this->typeKind(c.fArguments[0]->fType);
1341 TypeKind to = this->typeKind(c.fType);
1342 LLVMValueRef base[CHANNELS];
1343 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1344 return false;
1345 }
1346 #define CONSTRUCT(fn) \
1347 out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
1348 for (int i = 0; i < fVectorCount; ++i) { \
1349 LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \
1350 LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \
1351 "construct extract"); \
1352 out[0] = LLVMBuildInsertElement(builder, out[0], \
1353 fn(builder, baseVal, this->getType(c.fType), \
1354 "cast"), \
1355 index, "construct insert"); \
1356 } \
1357 return true;
1358 if (kFloat_TypeKind == to) {
1359 if (kInt_TypeKind == from) {
1360 CONSTRUCT(LLVMBuildSIToFP);
1361 }
1362 if (kUInt_TypeKind == from) {
1363 CONSTRUCT(LLVMBuildUIToFP);
1364 }
1365 }
1366 if (kInt_TypeKind == to) {
1367 if (kFloat_TypeKind == from) {
1368 CONSTRUCT(LLVMBuildFPToSI);
1369 }
1370 if (kUInt_TypeKind == from) {
1371 return true;
1372 }
1373 }
1374 if (kUInt_TypeKind == to) {
1375 if (kFloat_TypeKind == from) {
1376 CONSTRUCT(LLVMBuildFPToUI);
1377 }
1378 if (kInt_TypeKind == from) {
1379 return base;
1380 }
1381 }
1382 printf("%s\n", c.description().c_str());
1383 ABORT("unsupported constructor");
1384 }
1385 case Type::kVector_Kind: {
1386 if (c.fArguments.size() == 1) {
1387 LLVMValueRef base[CHANNELS];
1388 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1389 return false;
1390 }
1391 for (int i = 0; i < c.fType.columns(); ++i) {
1392 out[i] = base[0];
1393 }
1394 } else {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001395 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001396 for (int i = 0; i < c.fType.columns(); ++i) {
1397 LLVMValueRef base[CHANNELS];
1398 if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
1399 return false;
1400 }
1401 out[i] = base[0];
1402 }
1403 }
1404 return true;
1405 }
1406 default:
1407 break;
1408 }
1409 ABORT("unsupported constructor");
1410}
1411
1412bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
1413 const FloatLiteral& f,
1414 LLVMValueRef out[CHANNELS]) {
1415 LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
1416 LLVMValueRef values[MAX_VECTOR_COUNT];
1417 for (int i = 0; i < fVectorCount; ++i) {
1418 values[i] = value;
1419 }
1420 out[0] = LLVMConstVector(values, fVectorCount);
1421 return true;
1422}
1423
1424
1425bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
1426 LLVMValueRef out[CHANNELS]) {
1427 LLVMValueRef all[CHANNELS];
1428 if (!this->compileVectorExpression(builder, *s.fBase, all)) {
1429 return false;
1430 }
1431 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1432 out[i] = all[s.fComponents[i]];
1433 }
1434 return true;
1435}
1436
1437bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
1438 LLVMValueRef out[CHANNELS]) {
1439 if (&v.fVariable == fColorParam) {
1440 for (int i = 0; i < CHANNELS; ++i) {
1441 out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
1442 }
1443 return true;
1444 }
1445 return false;
1446}
1447
1448bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
1449 LLVMValueRef out[CHANNELS]) {
1450 switch (expr.fKind) {
1451 case Expression::kBinary_Kind:
1452 return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
1453 case Expression::kConstructor_Kind:
1454 return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
1455 case Expression::kFloatLiteral_Kind:
1456 return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
1457 case Expression::kSwizzle_Kind:
1458 return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
1459 case Expression::kVariableReference_Kind:
1460 return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
1461 out);
1462 default:
1463 printf("failed expression: %s\n", expr.description().c_str());
1464 return false;
1465 }
1466}
1467
1468bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
1469 switch (stmt.fKind) {
1470 case Statement::kBlock_Kind:
1471 for (const auto& s : ((const Block&) stmt).fStatements) {
1472 if (!this->compileVectorStatement(builder, *s)) {
1473 return false;
1474 }
1475 }
1476 return true;
1477 case Statement::kExpression_Kind:
1478 LLVMValueRef result;
1479 return this->compileVectorExpression(builder,
1480 *((const ExpressionStatement&) stmt).fExpression,
1481 &result);
1482 default:
1483 printf("failed statement: %s\n", stmt.description().c_str());
1484 return false;
1485 }
1486}
1487
1488bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
1489 LLVMValueRef oldFunction = fCurrentFunction;
1490 fCurrentFunction = newFunc;
1491 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1492 LLVMGetParams(fCurrentFunction, params.get());
1493 LLVMValueRef programParam = params.get()[1];
1494 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1495 LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1496 LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1497 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1498 this->setBlock(builder, fAllocaBlock);
1499 fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1500 LLVMBuildStore(builder, params.get()[4], fChannels[0]);
1501 fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1502 LLVMBuildStore(builder, params.get()[5], fChannels[1]);
1503 fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1504 LLVMBuildStore(builder, params.get()[6], fChannels[2]);
1505 fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1506 LLVMBuildStore(builder, params.get()[7], fChannels[3]);
1507 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1508 this->setBlock(builder, start);
1509 bool success = this->compileVectorStatement(builder, *f.fBody);
1510 if (success) {
1511 // increment program pointer, call next
1512 LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1513 LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1514 LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
1515 "cast next->func");
1516 LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1517 LLVMBuildAdd(builder,
1518 LLVMBuildPtrToInt(builder,
1519 programParam,
1520 fInt64Type,
1521 "cast 1"),
1522 LLVMConstInt(fInt64Type, PTR_SIZE,
1523 false),
1524 "add"),
1525 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1526 LLVMValueRef args[STAGE_PARAM_COUNT] = {
1527 params.get()[0],
1528 nextInc,
1529 params.get()[2],
1530 params.get()[3],
1531 LLVMBuildLoad(builder, fChannels[0], "rVec"),
1532 LLVMBuildLoad(builder, fChannels[1], "gVec"),
1533 LLVMBuildLoad(builder, fChannels[2], "bVec"),
1534 LLVMBuildLoad(builder, fChannels[3], "aVec"),
1535 params.get()[8],
1536 params.get()[9],
1537 params.get()[10],
1538 params.get()[11]
1539 };
1540 LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1541 LLVMBuildRetVoid(builder);
1542 // finish
1543 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1544 LLVMBuildBr(builder, start);
1545 LLVMDisposeBuilder(builder);
1546 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1547 ABORT("verify failed\n");
1548 }
1549 } else {
1550 LLVMDeleteBasicBlock(fAllocaBlock);
1551 LLVMDeleteBasicBlock(start);
1552 }
1553
1554 fAllocaBlock = oldAllocaBlock;
1555 fCurrentBlock = oldCurrentBlock;
1556 fCurrentFunction = oldFunction;
1557 return success;
1558}
1559
1560LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
1561 LLVMTypeRef returnType = fVoidType;
1562 LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
1563 fSizeTType, fFloat32VectorType, fFloat32VectorType,
1564 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
1565 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
1566 LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
1567 LLVMValueRef result = LLVMAddFunction(fModule,
1568 (String(f.fDeclaration.fName) + "$stage").c_str(),
1569 stageFuncType);
1570 fColorParam = f.fDeclaration.fParameters[2];
1571 if (!this->compileStageFunctionVector(f, result)) {
1572 // vectorization failed, fall back to looping over the pixels
1573 this->compileStageFunctionLoop(f, result);
1574 }
1575 return result;
1576}
1577
1578bool JIT::hasStageSignature(const FunctionDeclaration& f) {
1579 return f.fReturnType == *fProgram->fContext->fVoid_Type &&
1580 f.fParameters.size() == 3 &&
1581 f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
1582 f.fParameters[0]->fModifiers.fFlags == 0 &&
1583 f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
1584 f.fParameters[1]->fModifiers.fFlags == 0 &&
1585 f.fParameters[2]->fType == *fProgram->fContext->fFloat4_Type &&
1586 f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
1587}
1588
1589LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
1590 if (this->hasStageSignature(f.fDeclaration)) {
1591 this->compileStageFunction(f);
1592 // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
1593 // was to produce an SkJumper stage just because the signature matched or that the function
1594 // is not otherwise called. May need a better way to handle this.
1595 }
1596 LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
1597 std::vector<LLVMTypeRef> parameterTypes;
1598 for (const auto& p : f.fDeclaration.fParameters) {
1599 LLVMTypeRef type = this->getType(p->fType);
1600 if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
1601 type = LLVMPointerType(type, 0);
1602 }
1603 parameterTypes.push_back(type);
1604 }
1605 fCurrentFunction = LLVMAddFunction(fModule,
1606 String(f.fDeclaration.fName).c_str(),
1607 LLVMFunctionType(returnType, parameterTypes.data(),
1608 parameterTypes.size(), false));
1609 fFunctions[&f.fDeclaration] = fCurrentFunction;
1610
1611 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
1612 LLVMGetParams(fCurrentFunction, params.get());
1613 for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
1614 fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
1615 }
1616 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1617 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1618 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1619 fCurrentBlock = start;
1620 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1621 this->compileStatement(builder, *f.fBody);
1622 if (!ends_with_branch(*f.fBody)) {
1623 if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
1624 LLVMBuildRetVoid(builder);
1625 } else {
1626 LLVMBuildUnreachable(builder);
1627 }
1628 }
1629 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1630 LLVMBuildBr(builder, start);
1631 LLVMDisposeBuilder(builder);
1632 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1633 ABORT("verify failed\n");
1634 }
1635 return fCurrentFunction;
1636}
1637
1638void JIT::createModule() {
1639 fPromotedParameters.clear();
1640 fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
1641 this->loadBuiltinFunctions();
1642 // LLVM doesn't do void*, have to declare it as int8*
1643 LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
1644 fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
1645 appendParams,
1646 3,
1647 false));
1648 LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
1649 fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
1650 LLVMFunctionType(fVoidType, appendCallbackParams, 2,
1651 false));
1652
1653 LLVMTypeRef debugParams[3] = { fFloat32Type };
1654 fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
1655 debugParams,
1656 1,
1657 false));
1658
1659 for (const auto& e : fProgram->fElements) {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001660 SkASSERT(e->fKind == ProgramElement::kFunction_Kind);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001661 this->compileFunction((FunctionDefinition&) *e);
1662 }
1663}
1664
1665std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
1666 fProgram = std::move(program);
1667 this->createModule();
1668 this->optimize();
1669 return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
1670}
1671
1672void JIT::optimize() {
1673 LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
1674 LLVMPassManagerBuilderSetOptLevel(pmb, 3);
1675 LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
1676 LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
1677 LLVMPassManagerRef modulePM = LLVMCreatePassManager();
1678 LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
1679 LLVMInitializeFunctionPassManager(functionPM);
1680
1681 LLVMValueRef func = LLVMGetFirstFunction(fModule);
1682 for (;;) {
1683 if (!func) {
1684 break;
1685 }
1686 LLVMRunFunctionPassManager(functionPM, func);
1687 func = LLVMGetNextFunction(func);
1688 }
1689 LLVMRunPassManager(modulePM, fModule);
1690 LLVMDisposePassManager(functionPM);
1691 LLVMDisposePassManager(modulePM);
1692 LLVMPassManagerBuilderDispose(pmb);
1693
1694 std::string error_string;
1695 if (LLVMLoadLibraryPermanently(nullptr)) {
1696 ABORT("LLVMLoadLibraryPermanently failed");
1697 }
1698 char* defaultTriple = LLVMGetDefaultTargetTriple();
1699 char* error;
1700 LLVMTargetRef target;
1701 if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
1702 ABORT("LLVMGetTargetFromTriple failed");
1703 }
1704
1705 if (!LLVMTargetHasJIT(target)) {
1706 ABORT("!LLVMTargetHasJIT");
1707 }
1708 LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
1709 defaultTriple,
1710 fCPU,
1711 nullptr,
1712 LLVMCodeGenLevelDefault,
1713 LLVMRelocDefault,
1714 LLVMCodeModelJITDefault);
1715 LLVMDisposeMessage(defaultTriple);
1716 LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
1717 LLVMSetModuleDataLayout(fModule, dataLayout);
1718 LLVMDisposeTargetData(dataLayout);
1719
1720 fJITStack = LLVMOrcCreateInstance(targetMachine);
1721 fSharedModule = LLVMOrcMakeSharedModule(fModule);
1722 LLVMOrcModuleHandle orcModule;
1723 LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
1724 (LLVMOrcSymbolResolverFn) resolveSymbol, this);
1725 LLVMDisposeTargetMachine(targetMachine);
1726}
1727
1728void* JIT::Module::getSymbol(const char* name) {
1729 LLVMOrcTargetAddress result;
1730 if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
1731 ABORT("GetSymbolAddress error");
1732 }
1733 if (!result) {
1734 ABORT("symbol not found");
1735 }
1736 return (void*) result;
1737}
1738
1739void* JIT::Module::getJumperStage(const char* name) {
1740 return this->getSymbol((String(name) + "$stage").c_str());
1741}
1742
1743} // namespace
1744
1745#endif // SK_LLVM_AVAILABLE
1746
1747#endif // SKSL_STANDALONE