blob: 4120c4aa8c92902a4dea65ecb5c2bdd21a10b291 [file] [log] [blame]
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001/*
2 * Copyright 2018 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#ifndef SKSL_STANDALONE
9
10#ifdef SK_LLVM_AVAILABLE
11
12#include "SkSLJIT.h"
13
14#include "SkCpu.h"
15#include "SkRasterPipeline.h"
16#include "../jumper/SkJumper.h"
Ethan Nicholas00543112018-07-31 09:44:36 -040017#include "ir/SkSLAppendStage.h"
Ethan Nicholas26a9aad2018-03-27 14:10:52 -040018#include "ir/SkSLExpressionStatement.h"
19#include "ir/SkSLFunctionCall.h"
20#include "ir/SkSLFunctionReference.h"
21#include "ir/SkSLIndexExpression.h"
22#include "ir/SkSLProgram.h"
Ethan Nicholas00543112018-07-31 09:44:36 -040023#include "ir/SkSLUnresolvedFunction.h"
Ethan Nicholas26a9aad2018-03-27 14:10:52 -040024#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
25
// Upper bound on SIMD lane count. NOTE(review): not referenced in the visible part of this file;
// presumably it caps fVectorCount (which the constructor sets to 4 or 8) — confirm.
static constexpr int MAX_VECTOR_COUNT{16};
27
28extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
29 p->append((SkRasterPipeline::StockStage) stage, ctx);
30}
31
// Size in bytes of a host pointer.
// NOTE(review): not referenced in the visible part of this file — confirm it is still needed.
#define PTR_SIZE sizeof(void*)
33
34extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
35 p->append(fn, nullptr);
36}
37
// Debug hook bound to the SkSL builtin `print` (see JIT::loadBuiltinFunctions); writes the value
// to stdout.
extern "C" void sksl_debug_print(float f) {
    const double shown = f;  // same promotion the varargs call performs implicitly
    printf("Debug: %f\n", shown);
}
41
Ethan Nicholas00543112018-07-31 09:44:36 -040042extern "C" float sksl_clamp1(float f, float min, float max) {
43 return SkTPin(f, min, max);
44}
45
// Host vector-extension types used to receive and return the LLVM <N x float> values of the
// sksl_clampN helpers below (loadBuiltinFunctions binds them with fFloat32VectorNType).
// float3 deliberately occupies 16 bytes (4 lanes): vector_size only accepts power-of-two sizes,
// so the fourth lane is padding. NOTE(review): presumably this matches how the host ABI passes
// an LLVM <3 x float> — confirm.
using float2 = __attribute__((vector_size(8))) float;
using float3 = __attribute__((vector_size(16))) float;
using float4 = __attribute__((vector_size(16))) float;
49
50extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
51 return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
52}
53
54extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
55 return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
56}
57
58extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
59 return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
60 SkTPin(f[3], min, max) };
61}
62
Ethan Nicholas26a9aad2018-03-27 14:10:52 -040063namespace SkSL {
64
// NOTE(review): not referenced in the visible part of this file; presumably the number of
// parameters of the pipeline stage functions this JIT emits — confirm against compileStageFunction.
static constexpr int STAGE_PARAM_COUNT{12};
66
67static bool ends_with_branch(const Statement& stmt) {
68 switch (stmt.fKind) {
69 case Statement::kBlock_Kind: {
70 const Block& b = (const Block&) stmt;
71 if (b.fStatements.size()) {
72 return ends_with_branch(*b.fStatements.back());
73 }
74 return false;
75 }
76 case Statement::kBreak_Kind: // fall through
77 case Statement::kContinue_Kind: // fall through
78 case Statement::kReturn_Kind: // fall through
79 return true;
80 default:
81 return false;
82 }
83}
84
85JIT::JIT(Compiler* compiler)
86: fCompiler(*compiler) {
87 LLVMInitializeNativeTarget();
88 LLVMInitializeNativeAsmPrinter();
89 LLVMLinkInMCJIT();
Ethan Nicholasd9d33c32018-06-12 11:05:59 -040090 SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
Ethan Nicholas26a9aad2018-03-27 14:10:52 -040091 if (SkCpu::Supports(SkCpu::HSW)) {
92 fVectorCount = 8;
93 fCPU = "haswell";
94 } else if (SkCpu::Supports(SkCpu::AVX)) {
95 fVectorCount = 8;
96 fCPU = "ivybridge";
97 } else {
98 fVectorCount = 4;
99 fCPU = nullptr;
100 }
101 fContext = LLVMContextCreate();
102 fVoidType = LLVMVoidTypeInContext(fContext);
103 fInt1Type = LLVMInt1TypeInContext(fContext);
Ethan Nicholas00543112018-07-31 09:44:36 -0400104 fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
105 fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
106 fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
107 fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400108 fInt8Type = LLVMInt8TypeInContext(fContext);
109 fInt8PtrType = LLVMPointerType(fInt8Type, 0);
110 fInt32Type = LLVMInt32TypeInContext(fContext);
111 fInt64Type = LLVMInt64TypeInContext(fContext);
112 fSizeTType = LLVMInt64TypeInContext(fContext);
113 fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
114 fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
115 fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
116 fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
117 fFloat32Type = LLVMFloatTypeInContext(fContext);
118 fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
119 fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
120 fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
121 fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
122}
123
// Releases the ORC JIT stack and then the LLVM context. The order matters: the JIT stack holds
// objects created in fContext, so it must be disposed of first.
// NOTE(review): fJITStack is not assigned in the constructor; confirm it is always initialized
// (or null-safe for LLVMOrcDisposeInstance) by the time the JIT is destroyed.
JIT::~JIT() {
    LLVMOrcDisposeInstance(fJITStack);
    LLVMContextDispose(fContext);
}
128
129void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
130 std::vector<LLVMTypeRef> parameters) {
Ethan Nicholas00543112018-07-31 09:44:36 -0400131 bool found = false;
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400132 for (const auto& pair : *fProgram->fSymbols) {
133 if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
134 const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
135 if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
136 parameters.size() != f.fParameters.size()) {
137 continue;
138 }
139 for (size_t i = 0; i < parameters.size(); ++i) {
140 if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
141 goto next;
142 }
143 }
144 fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
145 parameters.data(),
146 parameters.size(),
147 false));
Ethan Nicholas00543112018-07-31 09:44:36 -0400148 found = true;
149 }
150 if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
151 // FIXME consolidate this with the code above
152 for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
153 if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
154 parameters.size() != f->fParameters.size()) {
155 continue;
156 }
157 for (size_t i = 0; i < parameters.size(); ++i) {
158 if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
159 goto next;
160 }
161 }
162 fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
163 returnType,
164 parameters.data(),
165 parameters.size(),
166 false));
167 found = true;
168 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400169 }
170 next:;
171 }
Ethan Nicholas00543112018-07-31 09:44:36 -0400172 SkASSERT(found);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400173}
174
175void JIT::loadBuiltinFunctions() {
176 this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
177 this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
178 this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
179 this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
180 this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
Ethan Nicholas00543112018-07-31 09:44:36 -0400181 this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
182 fFloat32Type,
183 fFloat32Type });
184 this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
185 fFloat32Type,
186 fFloat32Type });
187 this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
188 fFloat32Type,
189 fFloat32Type });
190 this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
191 fFloat32Type,
192 fFloat32Type });
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400193 this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
194}
195
196uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
197 LLVMOrcTargetAddress result;
198 if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
199 if (!strcmp(name, "_sksl_pipeline_append")) {
200 result = (uint64_t) &sksl_pipeline_append;
201 } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
202 result = (uint64_t) &sksl_pipeline_append_callback;
Ethan Nicholas00543112018-07-31 09:44:36 -0400203 } else if (!strcmp(name, "_sksl_clamp1")) {
204 result = (uint64_t) &sksl_clamp1;
205 } else if (!strcmp(name, "_sksl_clamp2")) {
206 result = (uint64_t) &sksl_clamp2;
207 } else if (!strcmp(name, "_sksl_clamp3")) {
208 result = (uint64_t) &sksl_clamp3;
209 } else if (!strcmp(name, "_sksl_clamp4")) {
210 result = (uint64_t) &sksl_clamp4;
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400211 } else if (!strcmp(name, "_sksl_debug_print")) {
212 result = (uint64_t) &sksl_debug_print;
213 } else {
214 result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
215 }
216 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400217 SkASSERT(result);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400218 return result;
219}
220
221LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
222 LLVMValueRef func = fFunctions[&fc.fFunction];
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400223 SkASSERT(func);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400224 std::vector<LLVMValueRef> parameters;
225 for (const auto& a : fc.fArguments) {
226 parameters.push_back(this->compileExpression(builder, *a));
227 }
228 return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
229}
230
231LLVMTypeRef JIT::getType(const Type& type) {
232 switch (type.kind()) {
233 case Type::kOther_Kind:
234 if (type.name() == "void") {
235 return fVoidType;
236 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400237 SkASSERT(type.name() == "SkRasterPipeline");
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400238 return fInt8PtrType;
239 case Type::kScalar_Kind:
240 if (type.isSigned() || type.isUnsigned()) {
241 return fInt32Type;
242 }
243 if (type.isUnsigned()) {
244 return fInt32Type;
245 }
246 if (type.isFloat()) {
247 return fFloat32Type;
248 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400249 SkASSERT(type.name() == "bool");
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400250 return fInt1Type;
251 case Type::kArray_Kind:
252 return LLVMPointerType(this->getType(type.componentType()), 0);
253 case Type::kVector_Kind:
254 if (type.name() == "float2" || type.name() == "half2") {
255 return fFloat32Vector2Type;
256 }
257 if (type.name() == "float3" || type.name() == "half3") {
258 return fFloat32Vector3Type;
259 }
260 if (type.name() == "float4" || type.name() == "half4") {
261 return fFloat32Vector4Type;
262 }
Ruiqi Maob609e6d2018-07-17 10:19:38 -0400263 if (type.name() == "int2" || type.name() == "short2" || type.name == "byte2") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400264 return fInt32Vector2Type;
265 }
Ruiqi Maob609e6d2018-07-17 10:19:38 -0400266 if (type.name() == "int3" || type.name() == "short3" || type.name == "byte3") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400267 return fInt32Vector3Type;
268 }
Ruiqi Maob609e6d2018-07-17 10:19:38 -0400269 if (type.name() == "int4" || type.name() == "short4" || type.name == "byte3") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400270 return fInt32Vector4Type;
271 }
272 // fall through
273 default:
274 ABORT("unsupported type");
275 }
276}
277
278void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
279 fCurrentBlock = block;
280 LLVMPositionBuilderAtEnd(builder, block);
281}
282
283std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
284 switch (expr.fKind) {
285 case Expression::kVariableReference_Kind: {
286 class PointerLValue : public LValue {
287 public:
288 PointerLValue(LLVMValueRef ptr)
289 : fPointer(ptr) {}
290
291 LLVMValueRef load(LLVMBuilderRef builder) override {
292 return LLVMBuildLoad(builder, fPointer, "lvalue load");
293 }
294
295 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
296 LLVMBuildStore(builder, value, fPointer);
297 }
298
299 private:
300 LLVMValueRef fPointer;
301 };
302 const Variable* var = &((VariableReference&) expr).fVariable;
303 if (var->fStorage == Variable::kParameter_Storage &&
304 !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
305 fPromotedParameters.find(var) == fPromotedParameters.end()) {
306 // promote parameter to variable
307 fPromotedParameters.insert(var);
308 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
309 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
310 String(var->fName).c_str());
311 LLVMBuildStore(builder, fVariables[var], alloca);
312 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
313 fVariables[var] = alloca;
314 }
315 LLVMValueRef ptr = fVariables[var];
316 return std::unique_ptr<LValue>(new PointerLValue(ptr));
317 }
318 case Expression::kTernary_Kind: {
319 class TernaryLValue : public LValue {
320 public:
321 TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
322 std::unique_ptr<LValue> ifFalse)
323 : fJIT(*jit)
324 , fTest(test)
325 , fIfTrue(std::move(ifTrue))
326 , fIfFalse(std::move(ifFalse)) {}
327
328 LLVMValueRef load(LLVMBuilderRef builder) override {
329 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
330 fJIT.fContext,
331 fJIT.fCurrentFunction,
332 "true ? ...");
333 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
334 fJIT.fContext,
335 fJIT.fCurrentFunction,
336 "false ? ...");
337 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
338 fJIT.fCurrentFunction,
339 "ternary merge");
340 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
341 fJIT.setBlock(builder, trueBlock);
342 LLVMValueRef ifTrue = fIfTrue->load(builder);
343 LLVMBuildBr(builder, merge);
344 fJIT.setBlock(builder, falseBlock);
345 LLVMValueRef ifFalse = fIfTrue->load(builder);
346 LLVMBuildBr(builder, merge);
347 fJIT.setBlock(builder, merge);
348 LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
349 LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
350 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
351 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
352 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
353 return phi;
354 }
355
356 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
357 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
358 fJIT.fContext,
359 fJIT.fCurrentFunction,
360 "true ? ...");
361 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
362 fJIT.fContext,
363 fJIT.fCurrentFunction,
364 "false ? ...");
365 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
366 fJIT.fCurrentFunction,
367 "ternary merge");
368 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
369 fJIT.setBlock(builder, trueBlock);
370 fIfTrue->store(builder, value);
371 LLVMBuildBr(builder, merge);
372 fJIT.setBlock(builder, falseBlock);
373 fIfTrue->store(builder, value);
374 LLVMBuildBr(builder, merge);
375 fJIT.setBlock(builder, merge);
376 }
377
378 private:
379 JIT& fJIT;
380 LLVMValueRef fTest;
381 std::unique_ptr<LValue> fIfTrue;
382 std::unique_ptr<LValue> fIfFalse;
383 };
384 const TernaryExpression& t = (const TernaryExpression&) expr;
385 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
386 return std::unique_ptr<LValue>(new TernaryLValue(this,
387 test,
388 this->getLValue(builder,
389 *t.fIfTrue),
390 this->getLValue(builder,
391 *t.fIfFalse)));
392 }
393 case Expression::kSwizzle_Kind: {
394 class SwizzleLValue : public LValue {
395 public:
396 SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
397 std::vector<int> components)
398 : fJIT(*jit)
399 , fType(type)
400 , fBase(std::move(base))
401 , fComponents(components) {}
402
403 LLVMValueRef load(LLVMBuilderRef builder) override {
404 LLVMValueRef base = fBase->load(builder);
405 if (fComponents.size() > 1) {
406 LLVMValueRef result = LLVMGetUndef(fType);
407 for (size_t i = 0; i < fComponents.size(); ++i) {
408 LLVMValueRef element = LLVMBuildExtractElement(
409 builder,
410 base,
411 LLVMConstInt(fJIT.fInt32Type,
412 fComponents[i],
413 false),
414 "swizzle extract");
415 result = LLVMBuildInsertElement(builder, result, element,
416 LLVMConstInt(fJIT.fInt32Type, i, false),
417 "swizzle insert");
418 }
419 return result;
420 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400421 SkASSERT(fComponents.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400422 return LLVMBuildExtractElement(builder, base,
423 LLVMConstInt(fJIT.fInt32Type,
424 fComponents[0],
425 false),
426 "swizzle extract");
427 }
428
429 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
430 LLVMValueRef result = fBase->load(builder);
431 if (fComponents.size() > 1) {
432 for (size_t i = 0; i < fComponents.size(); ++i) {
433 LLVMValueRef element = LLVMBuildExtractElement(builder, value,
434 LLVMConstInt(
435 fJIT.fInt32Type,
436 i,
437 false),
438 "swizzle extract");
439 result = LLVMBuildInsertElement(builder, result, element,
440 LLVMConstInt(fJIT.fInt32Type,
441 fComponents[i],
442 false),
443 "swizzle insert");
444 }
445 } else {
446 result = LLVMBuildInsertElement(builder, result, value,
447 LLVMConstInt(fJIT.fInt32Type,
448 fComponents[0],
449 false),
450 "swizzle insert");
451 }
452 fBase->store(builder, result);
453 }
454
455 private:
456 JIT& fJIT;
457 LLVMTypeRef fType;
458 std::unique_ptr<LValue> fBase;
459 std::vector<int> fComponents;
460 };
461 const Swizzle& s = (const Swizzle&) expr;
462 return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
463 this->getLValue(builder, *s.fBase),
464 s.fComponents));
465 }
466 default:
467 ABORT("unsupported lvalue");
468 }
469}
470
471JIT::TypeKind JIT::typeKind(const Type& type) {
472 if (type.kind() == Type::kVector_Kind) {
473 return this->typeKind(type.componentType());
474 }
Ruiqi Maob609e6d2018-07-17 10:19:38 -0400475 if (type.fName == "int" || type.fName == "short" || type.fName == "byte") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400476 return JIT::kInt_TypeKind;
Ruiqi Maob609e6d2018-07-17 10:19:38 -0400477 } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400478 return JIT::kUInt_TypeKind;
Ethan Nicholas00543112018-07-31 09:44:36 -0400479 } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400480 return JIT::kFloat_TypeKind;
481 }
482 ABORT("unsupported type: %s\n", type.description().c_str());
483}
484
485void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
486 LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
487 for (int i = 0; i < columns; ++i) {
488 result = LLVMBuildInsertElement(builder,
489 result,
490 *value,
491 LLVMConstInt(fInt32Type, i, false),
492 "vectorize");
493 }
494 *value = result;
495}
496
497void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
498 LLVMValueRef* right) {
499 if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
500 b.fRight->fType.kind() == Type::kVector_Kind) {
501 this->vectorize(builder, left, b.fRight->fType.columns());
502 } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
503 b.fRight->fType.kind() == Type::kScalar_Kind) {
504 this->vectorize(builder, right, b.fLeft->fType.columns());
505 }
506}
507
508
509LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
510 #define BINARY(SFunc, UFunc, FFunc) { \
511 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
512 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
513 this->vectorize(builder, b, &left, &right); \
Ethan Nicholas00543112018-07-31 09:44:36 -0400514 switch (this->typeKind(b.fLeft->fType)) { \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400515 case kInt_TypeKind: \
516 return SFunc(builder, left, right, "binary"); \
517 case kUInt_TypeKind: \
518 return UFunc(builder, left, right, "binary"); \
519 case kFloat_TypeKind: \
520 return FFunc(builder, left, right, "binary"); \
521 default: \
Ethan Nicholas00543112018-07-31 09:44:36 -0400522 ABORT("unsupported typeKind"); \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400523 } \
524 }
525 #define COMPOUND(SFunc, UFunc, FFunc) { \
526 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
527 LLVMValueRef left = lvalue->load(builder); \
528 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
529 this->vectorize(builder, b, &left, &right); \
530 LLVMValueRef result; \
Ethan Nicholas00543112018-07-31 09:44:36 -0400531 switch (this->typeKind(b.fLeft->fType)) { \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400532 case kInt_TypeKind: \
533 result = SFunc(builder, left, right, "binary"); \
534 break; \
535 case kUInt_TypeKind: \
536 result = UFunc(builder, left, right, "binary"); \
537 break; \
538 case kFloat_TypeKind: \
539 result = FFunc(builder, left, right, "binary"); \
540 break; \
541 default: \
Ethan Nicholas00543112018-07-31 09:44:36 -0400542 ABORT("unsupported typeKind"); \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400543 } \
544 lvalue->store(builder, result); \
545 return result; \
546 }
547 #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
548 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
549 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
550 this->vectorize(builder, b, &left, &right); \
551 switch (this->typeKind(b.fLeft->fType)) { \
552 case kInt_TypeKind: \
553 return SFunc(builder, SOp, left, right, "binary"); \
554 case kUInt_TypeKind: \
555 return UFunc(builder, UOp, left, right, "binary"); \
556 case kFloat_TypeKind: \
557 return FFunc(builder, FOp, left, right, "binary"); \
558 default: \
559 ABORT("unsupported typeKind"); \
560 } \
561 }
562 switch (b.fOperator) {
563 case Token::EQ: {
564 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
565 LLVMValueRef result = this->compileExpression(builder, *b.fRight);
566 lvalue->store(builder, result);
567 return result;
568 }
569 case Token::PLUS:
570 BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
571 case Token::MINUS:
572 BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
573 case Token::STAR:
574 BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
575 case Token::SLASH:
576 BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
577 case Token::PERCENT:
578 BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
579 case Token::BITWISEAND:
580 BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
581 case Token::BITWISEOR:
582 BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
Ethan Nicholas00543112018-07-31 09:44:36 -0400583 case Token::SHL:
584 BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
585 case Token::SHR:
586 BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400587 case Token::PLUSEQ:
588 COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
589 case Token::MINUSEQ:
590 COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
591 case Token::STAREQ:
592 COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
593 case Token::SLASHEQ:
594 COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
595 case Token::BITWISEANDEQ:
596 COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
597 case Token::BITWISEOREQ:
598 COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
599 case Token::EQEQ:
Ethan Nicholas00543112018-07-31 09:44:36 -0400600 switch (b.fLeft->fType.kind()) {
601 case Type::kScalar_Kind:
602 COMPARE(LLVMBuildICmp, LLVMIntEQ,
603 LLVMBuildICmp, LLVMIntEQ,
604 LLVMBuildFCmp, LLVMRealOEQ);
605 case Type::kVector_Kind: {
606 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
607 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
608 this->vectorize(builder, b, &left, &right);
609 LLVMValueRef value;
610 switch (this->typeKind(b.fLeft->fType)) {
611 case kInt_TypeKind:
612 value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
613 break;
614 case kUInt_TypeKind:
615 value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
616 break;
617 case kFloat_TypeKind:
618 value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
619 break;
620 default:
621 ABORT("unsupported typeKind");
622 }
623 LLVMValueRef args[1] = { value };
624 LLVMValueRef func;
625 switch (b.fLeft->fType.columns()) {
626 case 2: func = fFoldAnd2Func; break;
627 case 3: func = fFoldAnd3Func; break;
628 case 4: func = fFoldAnd4Func; break;
629 default:
630 SkASSERT(false);
631 func = fFoldAnd2Func;
632 }
633 return LLVMBuildCall(builder, func, args, 1, "all");
634 }
635 default:
636 SkASSERT(false);
637 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400638 case Token::NEQ:
Ethan Nicholas00543112018-07-31 09:44:36 -0400639 switch (b.fLeft->fType.kind()) {
640 case Type::kScalar_Kind:
641 COMPARE(LLVMBuildICmp, LLVMIntNE,
642 LLVMBuildICmp, LLVMIntNE,
643 LLVMBuildFCmp, LLVMRealONE);
644 case Type::kVector_Kind: {
645 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
646 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
647 this->vectorize(builder, b, &left, &right);
648 LLVMValueRef value;
649 switch (this->typeKind(b.fLeft->fType)) {
650 case kInt_TypeKind:
651 value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
652 break;
653 case kUInt_TypeKind:
654 value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
655 break;
656 case kFloat_TypeKind:
657 value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
658 break;
659 default:
660 ABORT("unsupported typeKind");
661 }
662 LLVMValueRef args[1] = { value };
663 LLVMValueRef func;
664 switch (b.fLeft->fType.columns()) {
665 case 2: func = fFoldOr2Func; break;
666 case 3: func = fFoldOr3Func; break;
667 case 4: func = fFoldOr4Func; break;
668 default:
669 SkASSERT(false);
670 func = fFoldOr2Func;
671 }
672 return LLVMBuildCall(builder, func, args, 1, "all");
673 }
674 default:
675 SkASSERT(false);
676 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400677 case Token::LT:
678 COMPARE(LLVMBuildICmp, LLVMIntSLT,
679 LLVMBuildICmp, LLVMIntULT,
680 LLVMBuildFCmp, LLVMRealOLT);
681 case Token::LTEQ:
682 COMPARE(LLVMBuildICmp, LLVMIntSLE,
683 LLVMBuildICmp, LLVMIntULE,
684 LLVMBuildFCmp, LLVMRealOLE);
685 case Token::GT:
686 COMPARE(LLVMBuildICmp, LLVMIntSGT,
687 LLVMBuildICmp, LLVMIntUGT,
688 LLVMBuildFCmp, LLVMRealOGT);
689 case Token::GTEQ:
690 COMPARE(LLVMBuildICmp, LLVMIntSGE,
691 LLVMBuildICmp, LLVMIntUGE,
692 LLVMBuildFCmp, LLVMRealOGE);
693 case Token::LOGICALAND: {
694 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
695 LLVMBasicBlockRef ifFalse = fCurrentBlock;
696 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
697 "true && ...");
698 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
699 "&& merge");
700 LLVMBuildCondBr(builder, left, ifTrue, merge);
701 this->setBlock(builder, ifTrue);
702 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
703 LLVMBuildBr(builder, merge);
704 this->setBlock(builder, merge);
705 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
706 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
707 LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
708 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
709 return phi;
710 }
711 case Token::LOGICALOR: {
712 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
713 LLVMBasicBlockRef ifTrue = fCurrentBlock;
714 LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
715 "false || ...");
716 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
717 "|| merge");
718 LLVMBuildCondBr(builder, left, merge, ifFalse);
719 this->setBlock(builder, ifFalse);
720 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
721 LLVMBuildBr(builder, merge);
722 this->setBlock(builder, merge);
723 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
724 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
725 LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
726 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
727 return phi;
728 }
729 default:
Ethan Nicholas00543112018-07-31 09:44:36 -0400730 printf("%s\n", b.description().c_str());
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400731 ABORT("unsupported binary operator");
732 }
733}
734
735LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
736 LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
737 LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
738 LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
739 return LLVMBuildLoad(builder, ptr, "index load");
740}
741
742LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
743 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
744 LLVMValueRef result = lvalue->load(builder);
745 LLVMValueRef mod;
746 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
747 switch (p.fOperator) {
748 case Token::PLUSPLUS:
749 switch (this->typeKind(p.fType)) {
750 case kInt_TypeKind: // fall through
751 case kUInt_TypeKind:
752 mod = LLVMBuildAdd(builder, result, one, "++");
753 break;
754 case kFloat_TypeKind:
755 mod = LLVMBuildFAdd(builder, result, one, "++");
756 break;
757 default:
758 ABORT("unsupported typeKind");
759 }
760 break;
761 case Token::MINUSMINUS:
762 switch (this->typeKind(p.fType)) {
763 case kInt_TypeKind: // fall through
764 case kUInt_TypeKind:
765 mod = LLVMBuildSub(builder, result, one, "--");
766 break;
767 case kFloat_TypeKind:
768 mod = LLVMBuildFSub(builder, result, one, "--");
769 break;
770 default:
771 ABORT("unsupported typeKind");
772 }
773 break;
774 default:
775 ABORT("unsupported postfix op");
776 }
777 lvalue->store(builder, mod);
778 return result;
779}
780
781LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
782 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
783 if (Token::LOGICALNOT == p.fOperator) {
784 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
785 return LLVMBuildXor(builder, base, one, "!");
786 }
787 if (Token::MINUS == p.fOperator) {
788 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
789 return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
790 }
791 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
792 LLVMValueRef raw = lvalue->load(builder);
793 LLVMValueRef result;
794 switch (p.fOperator) {
795 case Token::PLUSPLUS:
796 switch (this->typeKind(p.fType)) {
797 case kInt_TypeKind: // fall through
798 case kUInt_TypeKind:
799 result = LLVMBuildAdd(builder, raw, one, "++");
800 break;
801 case kFloat_TypeKind:
802 result = LLVMBuildFAdd(builder, raw, one, "++");
803 break;
804 default:
805 ABORT("unsupported typeKind");
806 }
807 break;
808 case Token::MINUSMINUS:
809 switch (this->typeKind(p.fType)) {
810 case kInt_TypeKind: // fall through
811 case kUInt_TypeKind:
812 result = LLVMBuildSub(builder, raw, one, "--");
813 break;
814 case kFloat_TypeKind:
815 result = LLVMBuildFSub(builder, raw, one, "--");
816 break;
817 default:
818 ABORT("unsupported typeKind");
819 }
820 break;
821 default:
822 ABORT("unsupported prefix op");
823 }
824 lvalue->store(builder, result);
825 return result;
826}
827
828LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
829 const Variable& var = v.fVariable;
830 if (Variable::kParameter_Storage == var.fStorage &&
831 !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
832 fPromotedParameters.find(&var) == fPromotedParameters.end()) {
833 return fVariables[&var];
834 }
835 return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
836}
837
838void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400839 SkASSERT(a.fArguments.size() >= 1);
840 SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400841 LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
842 LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
843 switch (a.fStage) {
844 case SkRasterPipeline::callback: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400845 SkASSERT(a.fArguments.size() == 2);
846 SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400847 const FunctionDeclaration& functionDecl =
848 *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
849 bool found = false;
Ethan Nicholas00543112018-07-31 09:44:36 -0400850 for (const auto& pe : *fProgram) {
851 if (ProgramElement::kFunction_Kind == pe.fKind) {
852 const FunctionDefinition& def = (const FunctionDefinition&) pe;
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400853 if (&def.fDeclaration == &functionDecl) {
854 LLVMValueRef fn = this->compileStageFunction(def);
855 LLVMValueRef args[2] = {
856 pipeline,
857 LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
858 };
859 LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
860 found = true;
861 break;
862 }
863 }
864 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400865 SkASSERT(found);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400866 break;
867 }
868 default: {
869 LLVMValueRef ctx;
870 if (a.fArguments.size() == 2) {
871 ctx = this->compileExpression(builder, *a.fArguments[1]);
872 ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
873 } else {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400874 SkASSERT(a.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400875 ctx = LLVMConstNull(fInt8PtrType);
876 }
877 LLVMValueRef args[3] = {
878 pipeline,
879 stage,
880 ctx
881 };
882 LLVMBuildCall(builder, fAppendFunc, args, 3, "");
883 break;
884 }
885 }
886}
887
888LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
889 switch (c.fType.kind()) {
890 case Type::kScalar_Kind: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400891 SkASSERT(c.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400892 TypeKind from = this->typeKind(c.fArguments[0]->fType);
893 TypeKind to = this->typeKind(c.fType);
894 LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
Ethan Nicholas00543112018-07-31 09:44:36 -0400895 switch (to) {
896 case kFloat_TypeKind:
897 switch (from) {
898 case kInt_TypeKind:
899 return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
900 case kUInt_TypeKind:
901 return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
902 case kFloat_TypeKind:
903 return base;
904 case kBool_TypeKind:
905 SkASSERT(false);
906 }
907 case kInt_TypeKind:
908 switch (from) {
909 case kInt_TypeKind:
910 return base;
911 case kUInt_TypeKind:
912 return base;
913 case kFloat_TypeKind:
914 return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
915 case kBool_TypeKind:
916 SkASSERT(false);
917 }
918 case kUInt_TypeKind:
919 switch (from) {
920 case kInt_TypeKind:
921 return base;
922 case kUInt_TypeKind:
923 return base;
924 case kFloat_TypeKind:
925 return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
926 case kBool_TypeKind:
927 SkASSERT(false);
928 }
929 case kBool_TypeKind:
930 SkASSERT(false);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400931 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400932 }
933 case Type::kVector_Kind: {
934 LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
Ethan Nicholas00543112018-07-31 09:44:36 -0400935 if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400936 LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
937 for (int i = 0; i < c.fType.columns(); ++i) {
938 vec = LLVMBuildInsertElement(builder, vec, value,
939 LLVMConstInt(fInt32Type, i, false),
Ethan Nicholas00543112018-07-31 09:44:36 -0400940 "vec build 1");
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400941 }
942 } else {
Ethan Nicholas00543112018-07-31 09:44:36 -0400943 int index = 0;
944 for (const auto& arg : c.fArguments) {
945 LLVMValueRef value = this->compileExpression(builder, *arg);
946 if (arg->fType.kind() == Type::kVector_Kind) {
947 for (int i = 0; i < arg->fType.columns(); ++i) {
948 LLVMValueRef column = LLVMBuildExtractElement(builder,
949 vec,
950 LLVMConstInt(fInt32Type,
951 i,
952 false),
953 "construct extract");
954 vec = LLVMBuildInsertElement(builder, vec, column,
955 LLVMConstInt(fInt32Type, index++, false),
956 "vec build 2");
957 }
958 } else {
959 vec = LLVMBuildInsertElement(builder, vec, value,
960 LLVMConstInt(fInt32Type, index++, false),
961 "vec build 3");
962 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400963 }
964 }
965 return vec;
966 }
967 default:
968 break;
969 }
970 ABORT("unsupported constructor");
971}
972
973LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
974 LLVMValueRef base = this->compileExpression(builder, *s.fBase);
975 if (s.fComponents.size() > 1) {
976 LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
977 for (size_t i = 0; i < s.fComponents.size(); ++i) {
978 LLVMValueRef element = LLVMBuildExtractElement(
979 builder,
980 base,
981 LLVMConstInt(fInt32Type,
982 s.fComponents[i],
983 false),
984 "swizzle extract");
985 result = LLVMBuildInsertElement(builder, result, element,
986 LLVMConstInt(fInt32Type, i, false),
987 "swizzle insert");
988 }
989 return result;
990 }
Ethan Nicholasd9d33c32018-06-12 11:05:59 -0400991 SkASSERT(s.fComponents.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -0400992 return LLVMBuildExtractElement(builder, base,
993 LLVMConstInt(fInt32Type,
994 s.fComponents[0],
995 false),
996 "swizzle extract");
997}
998
999LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
1000 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
1001 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1002 "if true");
1003 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1004 "if merge");
1005 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1006 "if false");
1007 LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
1008 this->setBlock(builder, trueBlock);
1009 LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
1010 trueBlock = fCurrentBlock;
1011 LLVMBuildBr(builder, merge);
1012 this->setBlock(builder, falseBlock);
1013 LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
1014 falseBlock = fCurrentBlock;
1015 LLVMBuildBr(builder, merge);
1016 this->setBlock(builder, merge);
1017 LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
1018 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
1019 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
1020 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
1021 return phi;
1022}
1023
1024LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
1025 switch (expr.fKind) {
1026 case Expression::kAppendStage_Kind: {
1027 this->appendStage(builder, (const AppendStage&) expr);
1028 return LLVMValueRef();
1029 }
1030 case Expression::kBinary_Kind:
1031 return this->compileBinary(builder, (BinaryExpression&) expr);
1032 case Expression::kBoolLiteral_Kind:
1033 return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
1034 case Expression::kConstructor_Kind:
1035 return this->compileConstructor(builder, (Constructor&) expr);
1036 case Expression::kIntLiteral_Kind:
1037 return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
1038 case Expression::kFieldAccess_Kind:
1039 abort();
1040 case Expression::kFloatLiteral_Kind:
1041 return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
1042 case Expression::kFunctionCall_Kind:
1043 return this->compileFunctionCall(builder, (FunctionCall&) expr);
1044 case Expression::kIndex_Kind:
1045 return this->compileIndex(builder, (IndexExpression&) expr);
1046 case Expression::kPrefix_Kind:
1047 return this->compilePrefix(builder, (PrefixExpression&) expr);
1048 case Expression::kPostfix_Kind:
1049 return this->compilePostfix(builder, (PostfixExpression&) expr);
1050 case Expression::kSetting_Kind:
1051 abort();
1052 case Expression::kSwizzle_Kind:
1053 return this->compileSwizzle(builder, (Swizzle&) expr);
1054 case Expression::kVariableReference_Kind:
1055 return this->compileVariableReference(builder, (VariableReference&) expr);
1056 case Expression::kTernary_Kind:
1057 return this->compileTernary(builder, (TernaryExpression&) expr);
1058 case Expression::kTypeReference_Kind:
1059 abort();
1060 default:
1061 abort();
1062 }
1063 ABORT("unsupported expression: %s\n", expr.description().c_str());
1064}
1065
1066void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
1067 for (const auto& stmt : block.fStatements) {
1068 this->compileStatement(builder, *stmt);
1069 }
1070}
1071
1072void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
1073 for (const auto& declStatement : decls.fDeclaration->fVars) {
1074 const VarDeclaration& decl = (VarDeclaration&) *declStatement;
1075 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1076 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
1077 String(decl.fVar->fName).c_str());
1078 fVariables[decl.fVar] = alloca;
1079 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1080 if (decl.fValue) {
1081 LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
1082 LLVMBuildStore(builder, result, alloca);
1083 }
1084 }
1085}
1086
1087void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
1088 LLVMValueRef test = this->compileExpression(builder, *i.fTest);
1089 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
1090 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1091 "if merge");
1092 LLVMBasicBlockRef ifFalse;
1093 if (i.fIfFalse) {
1094 ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
1095 } else {
1096 ifFalse = merge;
1097 }
1098 LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
1099 this->setBlock(builder, ifTrue);
1100 this->compileStatement(builder, *i.fIfTrue);
1101 if (!ends_with_branch(*i.fIfTrue)) {
1102 LLVMBuildBr(builder, merge);
1103 }
1104 if (i.fIfFalse) {
1105 this->setBlock(builder, ifFalse);
1106 this->compileStatement(builder, *i.fIfFalse);
1107 if (!ends_with_branch(*i.fIfFalse)) {
1108 LLVMBuildBr(builder, merge);
1109 }
1110 }
1111 this->setBlock(builder, merge);
1112}
1113
1114void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
1115 if (f.fInitializer) {
1116 this->compileStatement(builder, *f.fInitializer);
1117 }
1118 LLVMBasicBlockRef start;
1119 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
1120 LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
1121 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
1122 if (f.fTest) {
1123 start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
1124 LLVMBuildBr(builder, start);
1125 this->setBlock(builder, start);
1126 LLVMValueRef test = this->compileExpression(builder, *f.fTest);
1127 LLVMBuildCondBr(builder, test, body, end);
1128 } else {
1129 start = body;
1130 LLVMBuildBr(builder, body);
1131 }
1132 this->setBlock(builder, body);
1133 fBreakTarget.push_back(end);
1134 fContinueTarget.push_back(next);
1135 this->compileStatement(builder, *f.fStatement);
1136 fBreakTarget.pop_back();
1137 fContinueTarget.pop_back();
1138 if (!ends_with_branch(*f.fStatement)) {
1139 LLVMBuildBr(builder, next);
1140 }
1141 this->setBlock(builder, next);
1142 if (f.fNext) {
1143 this->compileExpression(builder, *f.fNext);
1144 }
1145 LLVMBuildBr(builder, start);
1146 this->setBlock(builder, end);
1147}
1148
1149void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
1150 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1151 "do test");
1152 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1153 "do body");
1154 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1155 "do end");
1156 LLVMBuildBr(builder, body);
1157 this->setBlock(builder, testBlock);
1158 LLVMValueRef test = this->compileExpression(builder, *d.fTest);
1159 LLVMBuildCondBr(builder, test, body, end);
1160 this->setBlock(builder, body);
1161 fBreakTarget.push_back(end);
1162 fContinueTarget.push_back(body);
1163 this->compileStatement(builder, *d.fStatement);
1164 fBreakTarget.pop_back();
1165 fContinueTarget.pop_back();
1166 if (!ends_with_branch(*d.fStatement)) {
1167 LLVMBuildBr(builder, testBlock);
1168 }
1169 this->setBlock(builder, end);
1170}
1171
1172void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
1173 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1174 "while test");
1175 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1176 "while body");
1177 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1178 "while end");
1179 LLVMBuildBr(builder, testBlock);
1180 this->setBlock(builder, testBlock);
1181 LLVMValueRef test = this->compileExpression(builder, *w.fTest);
1182 LLVMBuildCondBr(builder, test, body, end);
1183 this->setBlock(builder, body);
1184 fBreakTarget.push_back(end);
1185 fContinueTarget.push_back(testBlock);
1186 this->compileStatement(builder, *w.fStatement);
1187 fBreakTarget.pop_back();
1188 fContinueTarget.pop_back();
1189 if (!ends_with_branch(*w.fStatement)) {
1190 LLVMBuildBr(builder, testBlock);
1191 }
1192 this->setBlock(builder, end);
1193}
1194
1195void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
1196 LLVMBuildBr(builder, fBreakTarget.back());
1197}
1198
1199void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
1200 LLVMBuildBr(builder, fContinueTarget.back());
1201}
1202
1203void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
1204 if (r.fExpression) {
1205 LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
1206 } else {
1207 LLVMBuildRetVoid(builder);
1208 }
1209}
1210
1211void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
1212 switch (stmt.fKind) {
1213 case Statement::kBlock_Kind:
1214 this->compileBlock(builder, (Block&) stmt);
1215 break;
1216 case Statement::kBreak_Kind:
1217 this->compileBreak(builder, (BreakStatement&) stmt);
1218 break;
1219 case Statement::kContinue_Kind:
1220 this->compileContinue(builder, (ContinueStatement&) stmt);
1221 break;
1222 case Statement::kDiscard_Kind:
1223 abort();
1224 case Statement::kDo_Kind:
1225 this->compileDo(builder, (DoStatement&) stmt);
1226 break;
1227 case Statement::kExpression_Kind:
1228 this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
1229 break;
1230 case Statement::kFor_Kind:
1231 this->compileFor(builder, (ForStatement&) stmt);
1232 break;
1233 case Statement::kGroup_Kind:
1234 abort();
1235 case Statement::kIf_Kind:
1236 this->compileIf(builder, (IfStatement&) stmt);
1237 break;
1238 case Statement::kNop_Kind:
1239 break;
1240 case Statement::kReturn_Kind:
1241 this->compileReturn(builder, (ReturnStatement&) stmt);
1242 break;
1243 case Statement::kSwitch_Kind:
1244 abort();
1245 case Statement::kVarDeclarations_Kind:
1246 this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
1247 break;
1248 case Statement::kWhile_Kind:
1249 this->compileWhile(builder, (WhileStatement&) stmt);
1250 break;
1251 default:
1252 abort();
1253 }
1254}
1255
1256void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
1257 // loop over fVectorCount pixels, running the body of the stage function for each of them
1258 LLVMValueRef oldFunction = fCurrentFunction;
1259 fCurrentFunction = newFunc;
1260 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1261 LLVMGetParams(fCurrentFunction, params.get());
1262 LLVMValueRef programParam = params.get()[1];
1263 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1264 LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1265 LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1266 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1267 this->setBlock(builder, fAllocaBlock);
1268 // temporaries to store the color channel vectors
1269 LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1270 LLVMBuildStore(builder, params.get()[4], rVec);
1271 LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1272 LLVMBuildStore(builder, params.get()[5], gVec);
1273 LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1274 LLVMBuildStore(builder, params.get()[6], bVec);
1275 LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1276 LLVMBuildStore(builder, params.get()[7], aVec);
1277 LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
1278 fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
1279 "y->Int32");
1280 fVariables[f.fDeclaration.fParameters[2]] = color;
1281 LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
1282 LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
1283 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1284 this->setBlock(builder, start);
1285 LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
1286 fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
1287 LLVMBuildTrunc(builder,
1288 params.get()[2],
1289 fInt32Type,
1290 "x->Int32"),
1291 iload,
1292 "x");
1293 LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
1294 LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
1295 LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
1296 LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
1297 LLVMBuildCondBr(builder, test, loopBody, loopEnd);
1298 this->setBlock(builder, loopBody);
1299 LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
1300 // extract the r, g, b, and a values from the color channel vectors and store them into "color"
1301 for (int i = 0; i < 4; ++i) {
1302 vec = LLVMBuildInsertElement(builder, vec,
1303 LLVMBuildExtractElement(builder,
1304 params.get()[4 + i],
1305 iload, "initial"),
1306 LLVMConstInt(fInt32Type, i, false),
1307 "vec build");
1308 }
1309 LLVMBuildStore(builder, vec, color);
1310 // write actual loop body
1311 this->compileStatement(builder, *f.fBody);
1312 // extract the r, g, b, and a values from "color" and stick them back into the color channel
1313 // vectors
1314 LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
1315 LLVMBuildStore(builder,
1316 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
1317 LLVMBuildExtractElement(builder, colorLoad,
1318 LLVMConstInt(fInt32Type, 0,
1319 false),
1320 "rExtract"),
1321 iload, "rInsert"),
1322 rVec);
1323 LLVMBuildStore(builder,
1324 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
1325 LLVMBuildExtractElement(builder, colorLoad,
1326 LLVMConstInt(fInt32Type, 1,
1327 false),
1328 "gExtract"),
1329 iload, "gInsert"),
1330 gVec);
1331 LLVMBuildStore(builder,
1332 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
1333 LLVMBuildExtractElement(builder, colorLoad,
1334 LLVMConstInt(fInt32Type, 2,
1335 false),
1336 "bExtract"),
1337 iload, "bInsert"),
1338 bVec);
1339 LLVMBuildStore(builder,
1340 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
1341 LLVMBuildExtractElement(builder, colorLoad,
1342 LLVMConstInt(fInt32Type, 3,
1343 false),
1344 "aExtract"),
1345 iload, "aInsert"),
1346 aVec);
1347 LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
1348 LLVMBuildStore(builder, inc, ivar);
1349 LLVMBuildBr(builder, start);
1350 this->setBlock(builder, loopEnd);
1351 // increment program pointer, call the next stage
1352 LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1353 LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1354 LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
1355 LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1356 LLVMBuildAdd(builder,
1357 LLVMBuildPtrToInt(builder,
1358 programParam,
1359 fInt64Type,
1360 "cast 1"),
1361 LLVMConstInt(fInt64Type, PTR_SIZE, false),
1362 "add"),
1363 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1364 LLVMValueRef args[STAGE_PARAM_COUNT] = {
1365 params.get()[0],
1366 nextInc,
1367 params.get()[2],
1368 params.get()[3],
1369 LLVMBuildLoad(builder, rVec, "rVec"),
1370 LLVMBuildLoad(builder, gVec, "gVec"),
1371 LLVMBuildLoad(builder, bVec, "bVec"),
1372 LLVMBuildLoad(builder, aVec, "aVec"),
1373 params.get()[8],
1374 params.get()[9],
1375 params.get()[10],
1376 params.get()[11]
1377 };
1378 LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1379 LLVMBuildRetVoid(builder);
1380 // finish
1381 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1382 LLVMBuildBr(builder, start);
1383 LLVMDisposeBuilder(builder);
1384 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1385 ABORT("verify failed\n");
1386 }
1387 fAllocaBlock = oldAllocaBlock;
1388 fCurrentBlock = oldCurrentBlock;
1389 fCurrentFunction = oldFunction;
1390}
1391
// FIXME: consider pluggable code generators. The normal (per-pixel loop) codegen above should be
// separated from the vector codegen below, and this class broken up into multiple classes.
1395
1396bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
1397 LLVMValueRef out[CHANNELS]) {
1398 switch (e.fKind) {
1399 case Expression::kVariableReference_Kind:
1400 if (fColorParam == &((VariableReference&) e).fVariable) {
1401 memcpy(out, fChannels, sizeof(fChannels));
1402 return true;
1403 }
1404 return false;
1405 case Expression::kSwizzle_Kind: {
1406 const Swizzle& s = (const Swizzle&) e;
1407 LLVMValueRef base[CHANNELS];
1408 if (!this->getVectorLValue(builder, *s.fBase, base)) {
1409 return false;
1410 }
1411 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1412 out[i] = base[s.fComponents[i]];
1413 }
1414 return true;
1415 }
1416 default:
1417 return false;
1418 }
1419}
1420
1421bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
1422 LLVMValueRef outLeft[CHANNELS], const Expression& right,
1423 LLVMValueRef outRight[CHANNELS]) {
1424 if (!this->compileVectorExpression(builder, left, outLeft)) {
1425 return false;
1426 }
1427 int leftColumns = left.fType.columns();
1428 int rightColumns = right.fType.columns();
1429 if (leftColumns == 1 && rightColumns > 1) {
1430 for (int i = 1; i < rightColumns; ++i) {
1431 outLeft[i] = outLeft[0];
1432 }
1433 }
1434 if (!this->compileVectorExpression(builder, right, outRight)) {
1435 return false;
1436 }
1437 if (rightColumns == 1 && leftColumns > 1) {
1438 for (int i = 1; i < leftColumns; ++i) {
1439 outRight[i] = outRight[0];
1440 }
1441 }
1442 return true;
1443}
1444
1445bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
1446 LLVMValueRef out[CHANNELS]) {
1447 LLVMValueRef left[CHANNELS];
1448 LLVMValueRef right[CHANNELS];
1449 #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \
1450 if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
1451 return false; \
1452 } \
1453 for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \
1454 switch (this->typeKind(b.fLeft->fType)) { \
1455 case kInt_TypeKind: \
1456 out[i] = signedOp(builder, left[i], right[i], "binary"); \
1457 break; \
1458 case kUInt_TypeKind: \
1459 out[i] = unsignedOp(builder, left[i], right[i], "binary"); \
1460 break; \
1461 case kFloat_TypeKind: \
1462 out[i] = floatOp(builder, left[i], right[i], "binary"); \
1463 break; \
1464 case kBool_TypeKind: \
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001465 SkASSERT(false); \
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001466 break; \
1467 } \
1468 } \
1469 return true; \
1470 }
1471 switch (b.fOperator) {
1472 case Token::EQ: {
1473 if (!this->getVectorLValue(builder, *b.fLeft, left)) {
1474 return false;
1475 }
1476 if (!this->compileVectorExpression(builder, *b.fRight, right)) {
1477 return false;
1478 }
1479 int columns = b.fRight->fType.columns();
1480 for (int i = 0; i < columns; ++i) {
1481 LLVMBuildStore(builder, right[i], left[i]);
1482 }
1483 return true;
1484 }
1485 case Token::PLUS:
1486 VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
1487 case Token::MINUS:
1488 VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
1489 case Token::STAR:
1490 VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
1491 case Token::SLASH:
1492 VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
1493 case Token::PERCENT:
1494 VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
1495 case Token::BITWISEAND:
1496 VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
1497 case Token::BITWISEOR:
1498 VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
1499 default:
1500 printf("unsupported operator: %s\n", b.description().c_str());
1501 return false;
1502 }
1503}
1504
1505bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
1506 LLVMValueRef out[CHANNELS]) {
1507 switch (c.fType.kind()) {
1508 case Type::kScalar_Kind: {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001509 SkASSERT(c.fArguments.size() == 1);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001510 TypeKind from = this->typeKind(c.fArguments[0]->fType);
1511 TypeKind to = this->typeKind(c.fType);
1512 LLVMValueRef base[CHANNELS];
1513 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1514 return false;
1515 }
1516 #define CONSTRUCT(fn) \
1517 out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
1518 for (int i = 0; i < fVectorCount; ++i) { \
1519 LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \
1520 LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \
1521 "construct extract"); \
1522 out[0] = LLVMBuildInsertElement(builder, out[0], \
1523 fn(builder, baseVal, this->getType(c.fType), \
1524 "cast"), \
1525 index, "construct insert"); \
1526 } \
1527 return true;
1528 if (kFloat_TypeKind == to) {
1529 if (kInt_TypeKind == from) {
1530 CONSTRUCT(LLVMBuildSIToFP);
1531 }
1532 if (kUInt_TypeKind == from) {
1533 CONSTRUCT(LLVMBuildUIToFP);
1534 }
1535 }
1536 if (kInt_TypeKind == to) {
1537 if (kFloat_TypeKind == from) {
1538 CONSTRUCT(LLVMBuildFPToSI);
1539 }
1540 if (kUInt_TypeKind == from) {
1541 return true;
1542 }
1543 }
1544 if (kUInt_TypeKind == to) {
1545 if (kFloat_TypeKind == from) {
1546 CONSTRUCT(LLVMBuildFPToUI);
1547 }
1548 if (kInt_TypeKind == from) {
1549 return base;
1550 }
1551 }
1552 printf("%s\n", c.description().c_str());
1553 ABORT("unsupported constructor");
1554 }
1555 case Type::kVector_Kind: {
1556 if (c.fArguments.size() == 1) {
1557 LLVMValueRef base[CHANNELS];
1558 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1559 return false;
1560 }
1561 for (int i = 0; i < c.fType.columns(); ++i) {
1562 out[i] = base[0];
1563 }
1564 } else {
Ethan Nicholasd9d33c32018-06-12 11:05:59 -04001565 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001566 for (int i = 0; i < c.fType.columns(); ++i) {
1567 LLVMValueRef base[CHANNELS];
1568 if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
1569 return false;
1570 }
1571 out[i] = base[0];
1572 }
1573 }
1574 return true;
1575 }
1576 default:
1577 break;
1578 }
1579 ABORT("unsupported constructor");
1580}
1581
1582bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
1583 const FloatLiteral& f,
1584 LLVMValueRef out[CHANNELS]) {
1585 LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
1586 LLVMValueRef values[MAX_VECTOR_COUNT];
1587 for (int i = 0; i < fVectorCount; ++i) {
1588 values[i] = value;
1589 }
1590 out[0] = LLVMConstVector(values, fVectorCount);
1591 return true;
1592}
1593
1594
1595bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
1596 LLVMValueRef out[CHANNELS]) {
1597 LLVMValueRef all[CHANNELS];
1598 if (!this->compileVectorExpression(builder, *s.fBase, all)) {
1599 return false;
1600 }
1601 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1602 out[i] = all[s.fComponents[i]];
1603 }
1604 return true;
1605}
1606
1607bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
1608 LLVMValueRef out[CHANNELS]) {
1609 if (&v.fVariable == fColorParam) {
1610 for (int i = 0; i < CHANNELS; ++i) {
1611 out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
1612 }
1613 return true;
1614 }
1615 return false;
1616}
1617
1618bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
1619 LLVMValueRef out[CHANNELS]) {
1620 switch (expr.fKind) {
1621 case Expression::kBinary_Kind:
1622 return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
1623 case Expression::kConstructor_Kind:
1624 return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
1625 case Expression::kFloatLiteral_Kind:
1626 return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
1627 case Expression::kSwizzle_Kind:
1628 return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
1629 case Expression::kVariableReference_Kind:
1630 return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
1631 out);
1632 default:
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001633 return false;
1634 }
1635}
1636
1637bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
1638 switch (stmt.fKind) {
1639 case Statement::kBlock_Kind:
1640 for (const auto& s : ((const Block&) stmt).fStatements) {
1641 if (!this->compileVectorStatement(builder, *s)) {
1642 return false;
1643 }
1644 }
1645 return true;
1646 case Statement::kExpression_Kind:
1647 LLVMValueRef result;
1648 return this->compileVectorExpression(builder,
1649 *((const ExpressionStatement&) stmt).fExpression,
1650 &result);
1651 default:
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001652 return false;
1653 }
1654}
1655
1656bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
1657 LLVMValueRef oldFunction = fCurrentFunction;
1658 fCurrentFunction = newFunc;
1659 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1660 LLVMGetParams(fCurrentFunction, params.get());
1661 LLVMValueRef programParam = params.get()[1];
1662 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1663 LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1664 LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1665 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1666 this->setBlock(builder, fAllocaBlock);
1667 fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1668 LLVMBuildStore(builder, params.get()[4], fChannels[0]);
1669 fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1670 LLVMBuildStore(builder, params.get()[5], fChannels[1]);
1671 fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1672 LLVMBuildStore(builder, params.get()[6], fChannels[2]);
1673 fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1674 LLVMBuildStore(builder, params.get()[7], fChannels[3]);
1675 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1676 this->setBlock(builder, start);
1677 bool success = this->compileVectorStatement(builder, *f.fBody);
1678 if (success) {
1679 // increment program pointer, call next
1680 LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1681 LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1682 LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
1683 "cast next->func");
1684 LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1685 LLVMBuildAdd(builder,
1686 LLVMBuildPtrToInt(builder,
1687 programParam,
1688 fInt64Type,
1689 "cast 1"),
1690 LLVMConstInt(fInt64Type, PTR_SIZE,
1691 false),
1692 "add"),
1693 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1694 LLVMValueRef args[STAGE_PARAM_COUNT] = {
1695 params.get()[0],
1696 nextInc,
1697 params.get()[2],
1698 params.get()[3],
1699 LLVMBuildLoad(builder, fChannels[0], "rVec"),
1700 LLVMBuildLoad(builder, fChannels[1], "gVec"),
1701 LLVMBuildLoad(builder, fChannels[2], "bVec"),
1702 LLVMBuildLoad(builder, fChannels[3], "aVec"),
1703 params.get()[8],
1704 params.get()[9],
1705 params.get()[10],
1706 params.get()[11]
1707 };
1708 LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1709 LLVMBuildRetVoid(builder);
1710 // finish
1711 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1712 LLVMBuildBr(builder, start);
1713 LLVMDisposeBuilder(builder);
1714 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1715 ABORT("verify failed\n");
1716 }
1717 } else {
1718 LLVMDeleteBasicBlock(fAllocaBlock);
1719 LLVMDeleteBasicBlock(start);
1720 }
1721
1722 fAllocaBlock = oldAllocaBlock;
1723 fCurrentBlock = oldCurrentBlock;
1724 fCurrentFunction = oldFunction;
1725 return success;
1726}
1727
1728LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
1729 LLVMTypeRef returnType = fVoidType;
1730 LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
1731 fSizeTType, fFloat32VectorType, fFloat32VectorType,
1732 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
1733 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
1734 LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
1735 LLVMValueRef result = LLVMAddFunction(fModule,
1736 (String(f.fDeclaration.fName) + "$stage").c_str(),
1737 stageFuncType);
1738 fColorParam = f.fDeclaration.fParameters[2];
1739 if (!this->compileStageFunctionVector(f, result)) {
1740 // vectorization failed, fall back to looping over the pixels
1741 this->compileStageFunctionLoop(f, result);
1742 }
1743 return result;
1744}
1745
1746bool JIT::hasStageSignature(const FunctionDeclaration& f) {
1747 return f.fReturnType == *fProgram->fContext->fVoid_Type &&
1748 f.fParameters.size() == 3 &&
1749 f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
1750 f.fParameters[0]->fModifiers.fFlags == 0 &&
1751 f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
1752 f.fParameters[1]->fModifiers.fFlags == 0 &&
Ethan Nicholas00543112018-07-31 09:44:36 -04001753 f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001754 f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
1755}
1756
1757LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
1758 if (this->hasStageSignature(f.fDeclaration)) {
1759 this->compileStageFunction(f);
1760 // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
1761 // was to produce an SkJumper stage just because the signature matched or that the function
1762 // is not otherwise called. May need a better way to handle this.
1763 }
1764 LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
1765 std::vector<LLVMTypeRef> parameterTypes;
1766 for (const auto& p : f.fDeclaration.fParameters) {
1767 LLVMTypeRef type = this->getType(p->fType);
1768 if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
1769 type = LLVMPointerType(type, 0);
1770 }
1771 parameterTypes.push_back(type);
1772 }
1773 fCurrentFunction = LLVMAddFunction(fModule,
1774 String(f.fDeclaration.fName).c_str(),
1775 LLVMFunctionType(returnType, parameterTypes.data(),
1776 parameterTypes.size(), false));
1777 fFunctions[&f.fDeclaration] = fCurrentFunction;
1778
1779 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
1780 LLVMGetParams(fCurrentFunction, params.get());
1781 for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
1782 fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
1783 }
1784 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1785 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1786 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1787 fCurrentBlock = start;
1788 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1789 this->compileStatement(builder, *f.fBody);
1790 if (!ends_with_branch(*f.fBody)) {
1791 if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
1792 LLVMBuildRetVoid(builder);
1793 } else {
1794 LLVMBuildUnreachable(builder);
1795 }
1796 }
1797 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1798 LLVMBuildBr(builder, start);
1799 LLVMDisposeBuilder(builder);
1800 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1801 ABORT("verify failed\n");
1802 }
1803 return fCurrentFunction;
1804}
1805
1806void JIT::createModule() {
1807 fPromotedParameters.clear();
1808 fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
1809 this->loadBuiltinFunctions();
Ethan Nicholas00543112018-07-31 09:44:36 -04001810 LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
1811 fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
1812 LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1813 fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
1814 LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1815 LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
1816 fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
1817 LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1818 fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
1819 LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1820 LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
1821 fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
1822 LLVMFunctionType(fInt1Type, fold4Params, 1, false));
1823 fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
1824 LLVMFunctionType(fInt1Type, fold4Params, 1, false));
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001825 // LLVM doesn't do void*, have to declare it as int8*
1826 LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
1827 fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
1828 appendParams,
1829 3,
1830 false));
1831 LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
1832 fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
1833 LLVMFunctionType(fVoidType, appendCallbackParams, 2,
1834 false));
1835
1836 LLVMTypeRef debugParams[3] = { fFloat32Type };
1837 fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
1838 debugParams,
1839 1,
1840 false));
1841
Ethan Nicholas00543112018-07-31 09:44:36 -04001842 for (const auto& e : *fProgram) {
1843 if (e.fKind == ProgramElement::kFunction_Kind) {
1844 this->compileFunction((FunctionDefinition&) e);
1845 }
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001846 }
1847}
1848
1849std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
Ethan Nicholas00543112018-07-31 09:44:36 -04001850 fCompiler.optimize(*program);
Ethan Nicholas26a9aad2018-03-27 14:10:52 -04001851 fProgram = std::move(program);
1852 this->createModule();
1853 this->optimize();
1854 return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
1855}
1856
1857void JIT::optimize() {
1858 LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
1859 LLVMPassManagerBuilderSetOptLevel(pmb, 3);
1860 LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
1861 LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
1862 LLVMPassManagerRef modulePM = LLVMCreatePassManager();
1863 LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
1864 LLVMInitializeFunctionPassManager(functionPM);
1865
1866 LLVMValueRef func = LLVMGetFirstFunction(fModule);
1867 for (;;) {
1868 if (!func) {
1869 break;
1870 }
1871 LLVMRunFunctionPassManager(functionPM, func);
1872 func = LLVMGetNextFunction(func);
1873 }
1874 LLVMRunPassManager(modulePM, fModule);
1875 LLVMDisposePassManager(functionPM);
1876 LLVMDisposePassManager(modulePM);
1877 LLVMPassManagerBuilderDispose(pmb);
1878
1879 std::string error_string;
1880 if (LLVMLoadLibraryPermanently(nullptr)) {
1881 ABORT("LLVMLoadLibraryPermanently failed");
1882 }
1883 char* defaultTriple = LLVMGetDefaultTargetTriple();
1884 char* error;
1885 LLVMTargetRef target;
1886 if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
1887 ABORT("LLVMGetTargetFromTriple failed");
1888 }
1889
1890 if (!LLVMTargetHasJIT(target)) {
1891 ABORT("!LLVMTargetHasJIT");
1892 }
1893 LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
1894 defaultTriple,
1895 fCPU,
1896 nullptr,
1897 LLVMCodeGenLevelDefault,
1898 LLVMRelocDefault,
1899 LLVMCodeModelJITDefault);
1900 LLVMDisposeMessage(defaultTriple);
1901 LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
1902 LLVMSetModuleDataLayout(fModule, dataLayout);
1903 LLVMDisposeTargetData(dataLayout);
1904
1905 fJITStack = LLVMOrcCreateInstance(targetMachine);
1906 fSharedModule = LLVMOrcMakeSharedModule(fModule);
1907 LLVMOrcModuleHandle orcModule;
1908 LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
1909 (LLVMOrcSymbolResolverFn) resolveSymbol, this);
1910 LLVMDisposeTargetMachine(targetMachine);
1911}
1912
1913void* JIT::Module::getSymbol(const char* name) {
1914 LLVMOrcTargetAddress result;
1915 if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
1916 ABORT("GetSymbolAddress error");
1917 }
1918 if (!result) {
1919 ABORT("symbol not found");
1920 }
1921 return (void*) result;
1922}
1923
1924void* JIT::Module::getJumperStage(const char* name) {
1925 return this->getSymbol((String(name) + "$stage").c_str());
1926}
1927
1928} // namespace
1929
1930#endif // SK_LLVM_AVAILABLE
1931
1932#endif // SKSL_STANDALONE