Revert "Reland "Reland "Revert "Initial land of SkSL DSL.""""
This reverts commit 6b07e0eb497c3aa86d1ab6c238d9fa27d01b435c.
Change-Id: Ic01f31edf55b2d1a7533e0e8ed33b39b4846d937
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/343106
Reviewed-by: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: Ethan Nicholas <ethannicholas@google.com>
diff --git a/gn/sksl.gni b/gn/sksl.gni
index c376851..6c78c80 100644
--- a/gn/sksl.gni
+++ b/gn/sksl.gni
@@ -54,6 +54,13 @@
"$_src/sksl/SkSLStringStream.h",
"$_src/sksl/SkSLUtil.cpp",
"$_src/sksl/SkSLUtil.h",
+ "$_src/sksl/dsl/DSLBlock.cpp",
+ "$_src/sksl/dsl/DSLExpression.cpp",
+ "$_src/sksl/dsl/DSLStatement.cpp",
+ "$_src/sksl/dsl/DSLType.cpp",
+ "$_src/sksl/dsl/DSLVar.cpp",
+ "$_src/sksl/dsl/DSL_core.cpp",
+ "$_src/sksl/dsl/priv/DSLWriter.cpp",
"$_src/sksl/ir/SkSLBinaryExpression.h",
"$_src/sksl/ir/SkSLBlock.h",
"$_src/sksl/ir/SkSLBoolLiteral.h",
diff --git a/gn/tests.gni b/gn/tests.gni
index 46b1866..d87f2ba 100644
--- a/gn/tests.gni
+++ b/gn/tests.gni
@@ -264,6 +264,7 @@
"$_tests/SkResourceCacheTest.cpp",
"$_tests/SkRuntimeEffectTest.cpp",
"$_tests/SkSLCross.cpp",
+ "$_tests/SkSLDSLTest.cpp",
"$_tests/SkSLFPTestbed.cpp",
"$_tests/SkSLGLSLTestbed.cpp",
"$_tests/SkSLInterpreterTest.cpp",
diff --git a/src/gpu/glsl/GrGLSLShaderBuilder.cpp b/src/gpu/glsl/GrGLSLShaderBuilder.cpp
index 7517b3c..ede30a1 100644
--- a/src/gpu/glsl/GrGLSLShaderBuilder.cpp
+++ b/src/gpu/glsl/GrGLSLShaderBuilder.cpp
@@ -13,6 +13,8 @@
#include "src/gpu/glsl/GrGLSLBlend.h"
#include "src/gpu/glsl/GrGLSLColorSpaceXformHelper.h"
#include "src/gpu/glsl/GrGLSLProgramBuilder.h"
+#include "src/sksl/dsl/DSL.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
GrGLSLShaderBuilder::GrGLSLShaderBuilder(GrGLSLProgramBuilder* program)
: fProgramBuilder(program)
@@ -22,6 +24,7 @@
, fCodeIndex(kCode)
, fFinalized(false)
, fTmpVariableCounter(0) {
+ SkSL::dsl::DSLWriter::Reset();
// We push back some dummy pointers which will later become our header
for (int i = 0; i <= kCode; i++) {
fShaderStrings.push_back();
@@ -81,6 +84,13 @@
this->functions().append(";\n");
}
+void GrGLSLShaderBuilder::codeAppend(SkSL::dsl::Statement stmt) {
+ std::unique_ptr<SkSL::Statement> skslStmt = stmt.release();
+ if (skslStmt) {
+ this->codeAppend(skslStmt->description().c_str());
+ }
+}
+
static inline void append_texture_swizzle(SkString* out, GrSwizzle swizzle) {
if (swizzle != GrSwizzle::RGBA()) {
out->appendf(".%s", swizzle.asString().c_str());
diff --git a/src/gpu/glsl/GrGLSLShaderBuilder.h b/src/gpu/glsl/GrGLSLShaderBuilder.h
index f95c03a..3bd431f 100644
--- a/src/gpu/glsl/GrGLSLShaderBuilder.h
+++ b/src/gpu/glsl/GrGLSLShaderBuilder.h
@@ -17,6 +17,14 @@
#include <stdarg.h>
+namespace SkSL {
+
+namespace dsl {
+ class DSLStatement;
+} // namespace dsl
+
+} // namespace SkSL
+
class GrGLSLColorSpaceXformHelper;
/**
@@ -109,6 +117,8 @@
void codeAppend(const char* str, size_t length) { this->code().append(str, length); }
+ void codeAppend(SkSL::dsl::DSLStatement stmt);
+
void codePrependf(const char format[], ...) SK_PRINTF_LIKE(2, 3) {
va_list args;
va_start(args, format);
diff --git a/src/sksl/SkSLCFGGenerator.cpp b/src/sksl/SkSLCFGGenerator.cpp
index 46c60e6..f3877b5 100644
--- a/src/sksl/SkSLCFGGenerator.cpp
+++ b/src/sksl/SkSLCFGGenerator.cpp
@@ -367,6 +367,9 @@
}
break;
}
+ case Expression::Kind::kCodeString:
+ SkDEBUGFAIL("shouldn't be able to receive kCodeString here");
+ break;
case Expression::Kind::kConstructor: {
Constructor& c = e->get()->as<Constructor>();
for (auto& arg : c.arguments()) {
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 2726614..d610b3b 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -1813,6 +1813,7 @@
bool Compiler::optimize(LoadedModule& module) {
SkASSERT(!fErrorCount);
+ const Program::Settings* oldSettings = fIRGenerator->fSettings;
Program::Settings settings;
fIRGenerator->fKind = module.fKind;
fIRGenerator->fSettings = &settings;
@@ -1837,6 +1838,7 @@
break;
}
}
+ fIRGenerator->fSettings = oldSettings;
return fErrorCount == 0;
}
@@ -2159,10 +2161,13 @@
fErrorText += "error: " + (pos.fLine >= 1 ? to_string(pos.fLine) + ": " : "") + msg + "\n";
}
-String Compiler::errorText() {
- this->writeErrorCount();
+String Compiler::errorText(bool showCount) {
+ if (showCount) {
+ this->writeErrorCount();
+ }
fErrorCount = 0;
String result = fErrorText;
+ fErrorText = "";
return result;
}
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index 9f3be47..4143adc 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -46,6 +46,11 @@
namespace SkSL {
+namespace dsl {
+ class DSL;
+ class DSLWriter;
+} // namespace dsl
+
class ByteCode;
class ExternalValue;
class IRGenerator;
@@ -179,7 +184,7 @@
void error(int offset, String msg) override;
- String errorText();
+ String errorText(bool showCount = true);
void writeErrorCount();
@@ -306,6 +311,8 @@
friend class AutoSource;
friend class ::SkSLCompileBench;
+ friend class dsl::DSL;
+ friend class dsl::DSLWriter;
};
#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
diff --git a/src/sksl/SkSLDefines.h b/src/sksl/SkSLDefines.h
index c1376c6..3075c5c 100644
--- a/src/sksl/SkSLDefines.h
+++ b/src/sksl/SkSLDefines.h
@@ -28,8 +28,11 @@
#define NORETURN __attribute__((__noreturn__))
#endif
-#if defined(SK_BUILD_FOR_IOS) && \
- (!defined(__IPHONE_9_0) || __IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_9_0)
+#if defined(SK_BUILD_FOR_IOS) && \
+ (!defined(__IPHONE_9_0) || __IPHONE_OS_VERSION_MIN_REQUIRED < __IPHONE_9_0) || \
+ /* thread_local actually works for modern Android, but Android doesn't provide a */ \
+ /* convenient version macro */ \
+ defined(__ANDROID__)
#define SKSL_USE_THREAD_LOCAL 0
#else
#define SKSL_USE_THREAD_LOCAL 1
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 4bbb16e..e8279b3 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -280,6 +280,9 @@
this->writeU8(b.value());
break;
}
+ case Expression::Kind::kCodeString:
+ SkDEBUGFAIL("shouldn't be able to receive kCodeString here");
+ break;
case Expression::Kind::kConstructor: {
const Constructor& c = e->as<Constructor>();
this->writeCommand(Rehydrator::kConstructor_Command);
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 2e7ae56..23e8f33 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1362,11 +1362,10 @@
}
}
-std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTNode& identifier) {
- SkASSERT(identifier.fKind == ASTNode::Kind::kIdentifier);
- const Symbol* result = (*fSymbolTable)[identifier.getString()];
+std::unique_ptr<Expression> IRGenerator::convertIdentifier(int offset, StringFragment name) {
+ const Symbol* result = (*fSymbolTable)[name];
if (!result) {
- fErrors.error(identifier.fOffset, "unknown identifier '" + identifier.getString() + "'");
+ fErrors.error(offset, "unknown identifier '" + name + "'");
return nullptr;
}
switch (result->kind()) {
@@ -1374,12 +1373,11 @@
std::vector<const FunctionDeclaration*> f = {
&result->as<FunctionDeclaration>()
};
- return std::make_unique<FunctionReference>(fContext, identifier.fOffset, f);
+ return std::make_unique<FunctionReference>(fContext, offset, f);
}
case Symbol::Kind::kUnresolvedFunction: {
const UnresolvedFunction* f = &result->as<UnresolvedFunction>();
- return std::make_unique<FunctionReference>(fContext, identifier.fOffset,
- f->functions());
+ return std::make_unique<FunctionReference>(fContext, offset, f->functions());
}
case Symbol::Kind::kVariable: {
const Variable* var = &result->as<Variable>();
@@ -1418,19 +1416,18 @@
}
}
if (!valid) {
- fErrors.error(identifier.fOffset, "'in' variable must be either 'uniform' or "
- "'layout(key)', or there must be a custom "
- "@setData function");
+ fErrors.error(offset, "'in' variable must be either 'uniform' or 'layout(key)',"
+ " or there must be a custom @setData function");
}
}
// default to kRead_RefKind; this will be corrected later if the variable is written to
- return std::make_unique<VariableReference>(identifier.fOffset,
+ return std::make_unique<VariableReference>(offset,
var,
VariableReference::RefKind::kRead);
}
case Symbol::Kind::kField: {
const Field* field = &result->as<Field>();
- auto base = std::make_unique<VariableReference>(identifier.fOffset, &field->owner(),
+ auto base = std::make_unique<VariableReference>(offset, &field->owner(),
VariableReference::RefKind::kRead);
return std::make_unique<FieldAccess>(std::move(base),
field->fieldIndex(),
@@ -1438,17 +1435,21 @@
}
case Symbol::Kind::kType: {
const Type* t = &result->as<Type>();
- return std::make_unique<TypeReference>(fContext, identifier.fOffset, t);
+ return std::make_unique<TypeReference>(fContext, offset, t);
}
case Symbol::Kind::kExternal: {
const ExternalValue* r = &result->as<ExternalValue>();
- return std::make_unique<ExternalValueReference>(identifier.fOffset, r);
+ return std::make_unique<ExternalValueReference>(offset, r);
}
default:
ABORT("unsupported symbol type %d\n", (int) result->kind());
}
}
+std::unique_ptr<Expression> IRGenerator::convertIdentifier(const ASTNode& identifier) {
+ return this->convertIdentifier(identifier.fOffset, identifier.getString());
+}
+
std::unique_ptr<Section> IRGenerator::convertSection(const ASTNode& s) {
if (fKind != Program::kFragmentProcessor_Kind) {
fErrors.error(s.fOffset, "syntax error");
@@ -1970,6 +1971,22 @@
if (!right) {
return nullptr;
}
+ return this->convertBinaryExpression(std::move(left), op, std::move(right));
+}
+
+std::unique_ptr<Expression> IRGenerator::convertBinaryExpression(
+ std::unique_ptr<Expression> left,
+ Token::Kind op,
+ std::unique_ptr<Expression> right) {
+ if (op == Token::Kind::TK_LOGICALAND || op == Token::Kind::TK_LOGICALOR ||
+ op == Token::Kind::TK_LOGICALXOR) {
+ left = this->coerce(std::move(left), *fContext.fBool_Type);
+ right = this->coerce(std::move(right), *fContext.fBool_Type);
+ }
+ if (!left || !right) {
+ return nullptr;
+ }
+ int offset = left->fOffset;
const Type* leftType;
const Type* rightType;
const Type* resultType;
@@ -1987,10 +2004,10 @@
}
if (!determine_binary_type(fContext, fSettings->fAllowNarrowingConversions, op,
*rawLeftType, *rawRightType, &leftType, &rightType, &resultType)) {
- fErrors.error(expression.fOffset, String("type mismatch: '") +
- Compiler::OperatorName(expression.getToken().fKind) +
- "' cannot operate on '" + left->type().displayName() +
- "', '" + right->type().displayName() + "'");
+ fErrors.error(offset, String("type mismatch: '") +
+ Compiler::OperatorName(op) + "' cannot operate on '" +
+ left->type().displayName() + "', '" +
+ right->type().displayName() + "'");
return nullptr;
}
if (Compiler::IsAssignment(op)) {
@@ -2007,28 +2024,21 @@
}
std::unique_ptr<Expression> result = this->constantFold(*left, op, *right);
if (!result) {
- result = std::make_unique<BinaryExpression>(expression.fOffset, std::move(left), op,
- std::move(right), resultType);
+ result = std::make_unique<BinaryExpression>(offset, std::move(left), op, std::move(right),
+ resultType);
}
return result;
}
-std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(const ASTNode& node) {
- SkASSERT(node.fKind == ASTNode::Kind::kTernary);
- auto iter = node.begin();
- std::unique_ptr<Expression> test = this->coerce(this->convertExpression(*(iter++)),
- *fContext.fBool_Type);
- if (!test) {
+std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(
+ std::unique_ptr<Expression> test,
+ std::unique_ptr<Expression> ifTrue,
+ std::unique_ptr<Expression> ifFalse) {
+ test = this->coerce(std::move(test), *fContext.fBool_Type);
+ if (!test || !ifTrue || !ifFalse) {
return nullptr;
}
- std::unique_ptr<Expression> ifTrue = this->convertExpression(*(iter++));
- if (!ifTrue) {
- return nullptr;
- }
- std::unique_ptr<Expression> ifFalse = this->convertExpression(*(iter++));
- if (!ifFalse) {
- return nullptr;
- }
+ int offset = test->fOffset;
const Type* trueType;
const Type* falseType;
const Type* resultType;
@@ -2036,13 +2046,13 @@
Token::Kind::TK_EQEQ, ifTrue->type(), ifFalse->type(),
&trueType, &falseType, &resultType) ||
trueType != falseType) {
- fErrors.error(node.fOffset, "ternary operator result mismatch: '" +
- ifTrue->type().displayName() + "', '" +
- ifFalse->type().displayName() + "'");
+ fErrors.error(offset, "ternary operator result mismatch: '" +
+ ifTrue->type().displayName() + "', '" +
+ ifFalse->type().displayName() + "'");
return nullptr;
}
if (trueType->nonnullable() == *fContext.fFragmentProcessor_Type) {
- fErrors.error(node.fOffset,
+ fErrors.error(offset,
"ternary expression of type '" + trueType->displayName() + "' not allowed");
return nullptr;
}
@@ -2062,12 +2072,30 @@
return ifFalse;
}
}
- return std::make_unique<TernaryExpression>(node.fOffset,
+ return std::make_unique<TernaryExpression>(offset,
std::move(test),
std::move(ifTrue),
std::move(ifFalse));
}
+std::unique_ptr<Expression> IRGenerator::convertTernaryExpression(const ASTNode& node) {
+ SkASSERT(node.fKind == ASTNode::Kind::kTernary);
+ auto iter = node.begin();
+ std::unique_ptr<Expression> test = this->convertExpression(*(iter++));
+ if (!test) {
+ return nullptr;
+ }
+ std::unique_ptr<Expression> ifTrue = this->convertExpression(*(iter++));
+ if (!ifTrue) {
+ return nullptr;
+ }
+ std::unique_ptr<Expression> ifFalse = this->convertExpression(*(iter++));
+ if (!ifFalse) {
+ return nullptr;
+ }
+ return this->convertTernaryExpression(std::move(test), std::move(ifTrue), std::move(ifFalse));
+}
+
void IRGenerator::copyIntrinsicIfNeeded(const FunctionDeclaration& function) {
if (const ProgramElement* found = fIntrinsics->findAndInclude(function.description())) {
const FunctionDefinition& original = found->as<FunctionDefinition>();
@@ -2379,12 +2407,17 @@
if (!base) {
return nullptr;
}
+ return this->convertPrefixExpression(expression.getToken().fKind, std::move(base));
+}
+
+std::unique_ptr<Expression> IRGenerator::convertPrefixExpression(Token::Kind op,
+ std::unique_ptr<Expression> base) {
const Type& baseType = base->type();
- switch (expression.getToken().fKind) {
+ switch (op) {
case Token::Kind::TK_PLUS:
if (!baseType.isNumber() && !baseType.isVector() &&
baseType != *fContext.fFloatLiteral_Type) {
- fErrors.error(expression.fOffset,
+ fErrors.error(base->fOffset,
"'+' cannot operate on '" + baseType.displayName() + "'");
return nullptr;
}
@@ -2401,7 +2434,7 @@
}
if (!baseType.isNumber() &&
!(baseType.isVector() && baseType.componentType().isNumber())) {
- fErrors.error(expression.fOffset,
+ fErrors.error(base->fOffset,
"'-' cannot operate on '" + baseType.displayName() + "'");
return nullptr;
}
@@ -2409,9 +2442,9 @@
case Token::Kind::TK_PLUSPLUS:
if (!baseType.isNumber()) {
- fErrors.error(expression.fOffset,
- String("'") + Compiler::OperatorName(expression.getToken().fKind) +
- "' cannot operate on '" + baseType.displayName() + "'");
+ fErrors.error(base->fOffset,
+ String("'") + Compiler::OperatorName(op) + "' cannot operate on '" +
+ baseType.displayName() + "'");
return nullptr;
}
if (!this->setRefKind(*base, VariableReference::RefKind::kReadWrite)) {
@@ -2420,9 +2453,9 @@
break;
case Token::Kind::TK_MINUSMINUS:
if (!baseType.isNumber()) {
- fErrors.error(expression.fOffset,
- String("'") + Compiler::OperatorName(expression.getToken().fKind) +
- "' cannot operate on '" + baseType.displayName() + "'");
+ fErrors.error(base->fOffset,
+ String("'") + Compiler::OperatorName(op) + "' cannot operate on '" +
+ baseType.displayName() + "'");
return nullptr;
}
if (!this->setRefKind(*base, VariableReference::RefKind::kReadWrite)) {
@@ -2431,9 +2464,9 @@
break;
case Token::Kind::TK_LOGICALNOT:
if (!baseType.isBoolean()) {
- fErrors.error(expression.fOffset,
- String("'") + Compiler::OperatorName(expression.getToken().fKind) +
- "' cannot operate on '" + baseType.displayName() + "'");
+ fErrors.error(base->fOffset,
+ String("'") + Compiler::OperatorName(op) + "' cannot operate on '" +
+ baseType.displayName() + "'");
return nullptr;
}
if (base->kind() == Expression::Kind::kBoolLiteral) {
@@ -2443,16 +2476,16 @@
break;
case Token::Kind::TK_BITWISENOT:
if (baseType != *fContext.fInt_Type && baseType != *fContext.fUInt_Type) {
- fErrors.error(expression.fOffset,
- String("'") + Compiler::OperatorName(expression.getToken().fKind) +
- "' cannot operate on '" + baseType.displayName() + "'");
+ fErrors.error(base->fOffset,
+ String("'") + Compiler::OperatorName(op) + "' cannot operate on '" +
+ baseType.displayName() + "'");
return nullptr;
}
break;
default:
ABORT("unsupported prefix operator\n");
}
- return std::make_unique<PrefixExpression>(expression.getToken().fKind, std::move(base));
+ return std::make_unique<PrefixExpression>(op, std::move(base));
}
std::unique_ptr<Expression> IRGenerator::convertField(std::unique_ptr<Expression> base,
@@ -2485,7 +2518,7 @@
// secondary swizzle to put them back into the right order, so in this case we end up with
// 'float4(base.xw, 1, 0).xzyw'.
std::unique_ptr<Expression> IRGenerator::convertSwizzle(std::unique_ptr<Expression> base,
- StringFragment fields) {
+ String fields) {
const int offset = base->fOffset;
const Type& baseType = base->type();
if (!baseType.isVector() && !baseType.isNumber()) {
@@ -2493,13 +2526,13 @@
return nullptr;
}
- if (fields.fLength > 4) {
+ if (fields.length() > 4) {
fErrors.error(offset, "too many components in swizzle mask '" + fields + "'");
return nullptr;
}
ComponentArray maskComponents;
- for (size_t i = 0; i < fields.fLength; i++) {
+ for (size_t i = 0; i < fields.length(); i++) {
switch (fields[i]) {
case '0':
case '1':
@@ -2565,7 +2598,7 @@
}
// If we have processed the entire swizzle, we're done.
- if (maskComponents.size() == fields.fLength) {
+ if (maskComponents.size() == fields.length()) {
return expr;
}
@@ -2593,7 +2626,7 @@
int constantFieldIdx = maskComponents.size();
int constantZeroIdx = -1, constantOneIdx = -1;
- for (size_t i = 0; i < fields.fLength; i++) {
+ for (size_t i = 0; i < fields.length(); i++) {
switch (fields[i]) {
case '0':
if (constantZeroIdx == -1) {
@@ -2841,21 +2874,27 @@
}
std::unique_ptr<Expression> IRGenerator::convertPostfixExpression(const ASTNode& expression) {
+ SkASSERT(expression.fKind == ASTNode::Kind::kPostfix);
std::unique_ptr<Expression> base = this->convertExpression(*expression.begin());
if (!base) {
return nullptr;
}
+ return this->convertPostfixExpression(std::move(base), expression.getToken().fKind);
+}
+
+std::unique_ptr<Expression> IRGenerator::convertPostfixExpression(std::unique_ptr<Expression> base,
+ Token::Kind op) {
const Type& baseType = base->type();
if (!baseType.isNumber()) {
- fErrors.error(expression.fOffset,
- "'" + String(Compiler::OperatorName(expression.getToken().fKind)) +
- "' cannot operate on '" + baseType.displayName() + "'");
+ fErrors.error(base->fOffset,
+ "'" + String(Compiler::OperatorName(op)) + "' cannot operate on '" +
+ baseType.displayName() + "'");
return nullptr;
}
if (!this->setRefKind(*base, VariableReference::RefKind::kReadWrite)) {
return nullptr;
}
- return std::make_unique<PostfixExpression>(std::move(base), expression.getToken().fKind);
+ return std::make_unique<PostfixExpression>(std::move(base), op);
}
void IRGenerator::checkValid(const Expression& expr) {
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 46834b2..a483b72 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -38,6 +38,10 @@
struct ParsedModule;
struct Swizzle;
+namespace dsl {
+ class DSLWriter;
+} // namespace dsl
+
/**
* Intrinsics are passed between the Compiler and the IRGenerator using IRIntrinsicMaps.
*/
@@ -146,6 +150,30 @@
void pushSymbolTable();
void popSymbolTable();
+ std::unique_ptr<Expression> call(int offset,
+ std::unique_ptr<Expression> function,
+ ExpressionArray arguments);
+
+ std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
+
+ std::unique_ptr<Expression> convertBinaryExpression(std::unique_ptr<Expression> left,
+ Token::Kind op,
+ std::unique_ptr<Expression> right);
+
+ std::unique_ptr<Expression> convertIdentifier(int offset, StringFragment identifier);
+
+ std::unique_ptr<Expression> convertPostfixExpression(std::unique_ptr<Expression> base,
+ Token::Kind op);
+
+ std::unique_ptr<Expression> convertPrefixExpression(Token::Kind op,
+ std::unique_ptr<Expression> base);
+
+ std::unique_ptr<Expression> convertSwizzle(std::unique_ptr<Expression> base, String fields);
+
+ std::unique_ptr<Expression> convertTernaryExpression(std::unique_ptr<Expression> test,
+ std::unique_ptr<Expression> ifTrue,
+ std::unique_ptr<Expression> ifFalse);
+
const Context& fContext;
private:
@@ -168,11 +196,7 @@
ExpressionArray arguments);
CoercionCost callCost(const FunctionDeclaration& function,
const ExpressionArray& arguments);
- std::unique_ptr<Expression> call(int offset,
- std::unique_ptr<Expression> function,
- ExpressionArray arguments);
CoercionCost coercionCost(const Expression& expr, const Type& type);
- std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
template <typename T>
std::unique_ptr<Expression> constantFoldVector(const Expression& left,
Token::Kind op,
@@ -216,8 +240,6 @@
StringFragment field);
std::unique_ptr<Expression> convertField(std::unique_ptr<Expression> base,
StringFragment field);
- std::unique_ptr<Expression> convertSwizzle(std::unique_ptr<Expression> base,
- StringFragment fields);
std::unique_ptr<Expression> convertTernaryExpression(const ASTNode& expression);
std::unique_ptr<Statement> convertVarDeclarationStatement(const ASTNode& s);
std::unique_ptr<Statement> convertWhile(const ASTNode& w);
@@ -268,6 +290,7 @@
friend class AutoSwitchLevel;
friend class AutoDisableInline;
friend class Compiler;
+ friend class dsl::DSLWriter;
};
} // namespace SkSL
diff --git a/src/sksl/SkSLModifiersPool.h b/src/sksl/SkSLModifiersPool.h
index 57fecbe..1d0ae5d 100644
--- a/src/sksl/SkSLModifiersPool.h
+++ b/src/sksl/SkSLModifiersPool.h
@@ -8,12 +8,12 @@
#ifndef SKSL_MODIFIERSPOOL
#define SKSL_MODIFIERSPOOL
+#include "src/sksl/ir/SkSLModifiers.h"
+
#include <unordered_set>
namespace SkSL {
-struct Modifiers;
-
/**
* Deduplicates Modifiers objects and stores them in a shared pool. Modifiers are fairly heavy, and
* tend to be reused a lot, so deduplication can be a significant win.
diff --git a/src/sksl/dsl/DSL.h b/src/sksl/dsl/DSL.h
new file mode 100644
index 0000000..98cd391
--- /dev/null
+++ b/src/sksl/dsl/DSL.h
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL
+#define SKSL_DSL
+
+#include "src/sksl/dsl/DSL_core.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+using Block = DSLBlock;
+using Expression = DSLExpression;
+using Statement = DSLStatement;
+using Type = DSLType;
+using Var = DSLVar;
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLBlock.cpp b/src/sksl/dsl/DSLBlock.cpp
new file mode 100644
index 0000000..e99be70
--- /dev/null
+++ b/src/sksl/dsl/DSLBlock.cpp
@@ -0,0 +1,27 @@
+/*
+ * Copyright 2020 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSLBlock.h"
+
+#include "src/sksl/dsl/DSLStatement.h"
+#include "src/sksl/ir/SkSLBlock.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+std::unique_ptr<SkSL::Statement> DSLBlock::release() {
+ return std::make_unique<SkSL::Block>(/*offset=*/-1, std::move(fStatements));
+}
+
+void DSLBlock::append(DSLStatement stmt) {
+ fStatements.push_back(stmt.release());
+}
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSLBlock.h b/src/sksl/dsl/DSLBlock.h
new file mode 100644
index 0000000..d5c8414
--- /dev/null
+++ b/src/sksl/dsl/DSLBlock.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2020 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_BLOCK
+#define SKSL_DSL_BLOCK
+
+#include "include/private/SkTArray.h"
+#include "src/sksl/dsl/DSLExpression.h"
+#include "src/sksl/dsl/DSLStatement.h"
+#include "src/sksl/ir/SkSLIRNode.h"
+
+#include <memory>
+
+namespace SkSL {
+
+class Statement;
+
+namespace dsl {
+
+class DSLBlock {
+public:
+ template<class... Statements>
+ DSLBlock(Statements... statements) {
+ fStatements.reserve_back(sizeof...(statements));
+ (fStatements.push_back(DSLStatement(std::move(statements)).release()), ...);
+ }
+
+ DSLBlock(SkSL::StatementArray statements)
+ : fStatements(std::move(statements)) {}
+
+ void append(DSLStatement stmt);
+
+private:
+ std::unique_ptr<SkSL::Statement> release();
+
+ SkSL::StatementArray fStatements;
+
+ friend class DSLStatement;
+ friend class DSLFunction;
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLExpression.cpp b/src/sksl/dsl/DSLExpression.cpp
new file mode 100644
index 0000000..ea89591
--- /dev/null
+++ b/src/sksl/dsl/DSLExpression.cpp
@@ -0,0 +1,213 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSLExpression.h"
+
+#include "src/sksl/SkSLCompiler.h"
+#include "src/sksl/SkSLIRGenerator.h"
+#include "src/sksl/dsl/DSLVar.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
+#include "src/sksl/ir/SkSLBinaryExpression.h"
+#include "src/sksl/ir/SkSLBoolLiteral.h"
+#include "src/sksl/ir/SkSLFloatLiteral.h"
+#include "src/sksl/ir/SkSLIntLiteral.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+static std::unique_ptr<SkSL::Expression> check(std::unique_ptr<SkSL::Expression> expr) {
+ if (expr == nullptr) {
+ if (DSLWriter::Compiler().errorCount()) {
+ DSLWriter::ReportError(DSLWriter::Compiler().errorText(/*showCount=*/false).c_str());
+ }
+ }
+ return expr;
+}
+
+DSLExpression::DSLExpression() {}
+
+DSLExpression::DSLExpression(std::unique_ptr<SkSL::Expression> expression)
+ : fExpression(check(std::move(expression))) {}
+
+DSLExpression::DSLExpression(float value)
+ : fExpression(std::make_unique<SkSL::FloatLiteral>(DSLWriter::Context(),
+ /*offset=*/-1,
+ value)) {}
+
+DSLExpression::DSLExpression(int value)
+ : fExpression(std::make_unique<SkSL::IntLiteral>(DSLWriter::Context(),
+ /*offset=*/-1,
+ value)) {}
+
+DSLExpression::DSLExpression(bool value)
+ : fExpression(std::make_unique<SkSL::BoolLiteral>(DSLWriter::Context(),
+ /*offset=*/-1,
+ value)) {}
+
+DSLExpression::DSLExpression(const DSLVar& var)
+ : fExpression(std::make_unique<SkSL::VariableReference>(
+ /*offset=*/-1,
+ var.var(),
+ SkSL::VariableReference::RefKind::kRead)) {}
+
+DSLExpression::~DSLExpression() {
+ SkASSERTF(fExpression == nullptr,
+ "Expression destroyed without being incorporated into output tree");
+}
+
+std::unique_ptr<SkSL::Expression> DSLExpression::release() {
+ return std::move(fExpression);
+}
+
+DSLExpression DSLExpression::operator=(DSLExpression&& right) {
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
+ return DSLExpression(check(ir.convertBinaryExpression(this->release(), SkSL::Token::Kind::TK_EQ,
+ right.release())));
+}
+
+#define OP(op, token) \
+DSLExpression operator op(DSLExpression left, DSLExpression right) { \
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator(); \
+ return DSLExpression(check(ir.convertBinaryExpression(left.release(), SkSL::Token::Kind::token,\
+ right.release()))); \
+}
+
+#define RWOP(op, token) \
+OP(op, token) \
+DSLExpression operator op(DSLVar& left, DSLExpression right) { \
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator(); \
+ return DSLExpression(check(ir.convertBinaryExpression( \
+ std::make_unique<SkSL::VariableReference>(/*offset=*/-1, \
+ left.var(), \
+ SkSL::VariableReference::RefKind::kReadWrite), \
+ SkSL::Token::Kind::token, right.release()))); \
+}
+
+#define PREFIXOP(op, token) \
+DSLExpression operator op(DSLExpression expr) { \
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator(); \
+ return DSLExpression(check(ir.convertPrefixExpression(SkSL::Token::Kind::token, \
+ expr.release()))); \
+}
+
+#define POSTFIXOP(op, token) \
+DSLExpression operator op(DSLExpression expr, int) { \
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator(); \
+ return DSLExpression(check(ir.convertPostfixExpression(expr.release(), \
+ SkSL::Token::Kind::token))); \
+}
+
+OP(+, TK_PLUS)
+RWOP(+=, TK_PLUSEQ)
+OP(-, TK_MINUS)
+RWOP(-=, TK_MINUSEQ)
+OP(*, TK_STAR)
+RWOP(*=, TK_STAREQ)
+OP(/, TK_SLASH)
+RWOP(/=, TK_SLASHEQ)
+OP(%, TK_PERCENT)
+RWOP(%=, TK_PERCENTEQ)
+OP(<<, TK_SHL)
+RWOP(<<=, TK_SHLEQ)
+OP(>>, TK_SHR)
+RWOP(>>=, TK_SHREQ)
+OP(&&, TK_LOGICALAND)
+OP(||, TK_LOGICALOR)
+OP(&, TK_BITWISEAND)
+RWOP(&=, TK_BITWISEANDEQ)
+OP(|, TK_BITWISEOR)
+RWOP(|=, TK_BITWISEOREQ)
+OP(^, TK_BITWISEXOR)
+RWOP(^=, TK_BITWISEXOREQ)
+OP(==, TK_EQEQ)
+OP(!=, TK_NEQ)
+OP(>, TK_GT)
+OP(<, TK_LT)
+OP(>=, TK_GTEQ)
+OP(<=, TK_LTEQ)
+
+PREFIXOP(!, TK_LOGICALNOT)
+PREFIXOP(~, TK_BITWISENOT)
+PREFIXOP(++, TK_PLUSPLUS)
+POSTFIXOP(++, TK_PLUSPLUS)
+PREFIXOP(--, TK_MINUSMINUS)
+POSTFIXOP(--, TK_MINUSMINUS)
+
+DSLExpression operator,(DSLExpression left, DSLExpression right) {
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
+ return DSLExpression(check(ir.convertBinaryExpression(left.release(),
+ SkSL::Token::Kind::TK_COMMA,
+ right.release())));
+}
+
+std::unique_ptr<SkSL::Expression> DSLExpression::coerceAndRelease(const SkSL::Type& type) {
+ // tripping this assert means we had an error occur somewhere else in DSL construction that
+ // wasn't caught where it should have been
+ SkASSERTF(!DSLWriter::Compiler().errorCount(), "Unexpected SkSL DSL error: %s",
+ DSLWriter::Compiler().errorText().c_str());
+ return check(DSLWriter::IRGenerator().coerce(this->release(), type));
+}
+
+static SkSL::String swizzle_component(SwizzleComponent c) {
+ switch (c) {
+ case R:
+ return "r";
+ case G:
+ return "g";
+ case B:
+ return "b";
+ case A:
+ return "a";
+ case X:
+ return "x";
+ case Y:
+ return "y";
+ case Z:
+ return "z";
+ case W:
+ return "w";
+ case ZERO:
+ return "0";
+ case ONE:
+ return "1";
+ default:
+ SkUNREACHABLE;
+ }
+}
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a) {
+ return DSLExpression(check(DSLWriter::IRGenerator().convertSwizzle(base.release(),
+ swizzle_component(a))));
+}
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b) {
+ return DSLExpression(check(DSLWriter::IRGenerator().convertSwizzle(base.release(),
+ swizzle_component(a) +
+ swizzle_component(b))));
+}
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c) {
+ return DSLExpression(check(DSLWriter::IRGenerator().convertSwizzle(base.release(),
+ swizzle_component(a) +
+ swizzle_component(b) +
+ swizzle_component(c))));
+}
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c, SwizzleComponent d) {
+ return DSLExpression(check(DSLWriter::IRGenerator().convertSwizzle(base.release(),
+ swizzle_component(a) +
+ swizzle_component(b) +
+ swizzle_component(c) +
+ swizzle_component(d))));
+}
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSLExpression.h b/src/sksl/dsl/DSLExpression.h
new file mode 100644
index 0000000..fd13aad
--- /dev/null
+++ b/src/sksl/dsl/DSLExpression.h
@@ -0,0 +1,238 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_EXPRESSION
+#define SKSL_DSL_EXPRESSION
+
+#include "include/core/SkTypes.h"
+#include "src/sksl/ir/SkSLIRNode.h"
+
+#include <cstdint>
+#include <memory>
+
+namespace SkSL {
+
+class Expression;
+class Statement;
+class Type;
+
+namespace dsl {
+
+class DSLExpression;
+class DSLStatement;
+class DSLType;
+class DSLVar;
+
+enum SwizzleComponent : int8_t {
+ R,
+ G,
+ B,
+ A,
+ X,
+ Y,
+ Z,
+ W,
+ ZERO,
+ ONE
+};
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a);
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b);
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c);
+
+DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c, SwizzleComponent d);
+
+/**
+ * Represents an expression such as 'cos(x)' or 'a + b'.
+ */
+class DSLExpression {
+public:
+ DSLExpression(const DSLExpression&) = delete;
+
+ DSLExpression(DSLExpression&&) = default;
+
+ DSLExpression();
+
+ /**
+ * Creates an expression representing a literal float.
+ */
+ DSLExpression(float value);
+
+ /**
+ * Creates an expression representing a literal float.
+ */
+ DSLExpression(double value)
+ : DSLExpression((float) value) {}
+
+ /**
+ * Creates an expression representing a literal int.
+ */
+ DSLExpression(int value);
+
+ /**
+ * Creates an expression representing a literal bool.
+ */
+ DSLExpression(bool value);
+
+ /**
+ * Creates an expression representing a variable reference.
+ */
+ DSLExpression(const DSLVar& var);
+
+ ~DSLExpression();
+
+ /**
+ * Overloads the '=' operator to create an SkSL assignment statement.
+ */
+ DSLExpression operator=(DSLExpression&& other);
+
+ /**
+ * Creates an SkSL array index expression.
+ */
+ DSLExpression operator[](DSLExpression&& index);
+
+ /**
+ * Invalidates this object and returns the SkSL expression it represents.
+ */
+ std::unique_ptr<SkSL::Expression> release();
+
+private:
+ DSLExpression(std::unique_ptr<SkSL::Expression> expression);
+
+ /**
+ * Invalidates this object and returns the SkSL expression it represents coerced to the
+ * specified type. If the expression cannot be coerced, reports an error and returns null.
+ */
+ std::unique_ptr<SkSL::Expression> coerceAndRelease(const SkSL::Type& type);
+
+ std::unique_ptr<SkSL::Expression> fExpression;
+
+ template <typename... Args>
+ friend DSLExpression dsl_function(const char* name, Args... args);
+
+ friend DSLExpression dsl_construct(const SkSL::Type& type, std::vector<DSLExpression> rawArgs);
+ friend DSLStatement Declare(DSLVar& var, DSLExpression initialValue);
+ friend DSLStatement Do(DSLStatement stmt, DSLExpression test);
+ friend DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression next,
+ DSLStatement stmt);
+ friend DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse);
+ friend DSLExpression Swizzle(DSLExpression base, SwizzleComponent a);
+ friend DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b);
+ friend DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c);
+ friend DSLExpression Swizzle(DSLExpression base, SwizzleComponent a, SwizzleComponent b,
+ SwizzleComponent c, SwizzleComponent d);
+ friend DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse);
+ friend DSLStatement While(DSLExpression test, DSLStatement stmt);
+ friend DSLExpression sampleChild(int index, DSLExpression coordinates);
+
+ friend class DSLBlock;
+ friend class DSLStatement;
+ friend class DSLVar;
+
+ friend DSLExpression operator+(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator+=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator+=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator-(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator-=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator-=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator*(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator*=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator*=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator/(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator/=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator/=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator%(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator%=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator%=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator<<(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator<<=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator<<=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator>>(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator>>=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator>>=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator&&(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator||(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator&(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator&=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator&=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator|(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator|=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator|=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator^(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator^=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator^=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator,(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator==(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator!=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator>(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator<(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator>=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator<=(DSLExpression left, DSLExpression right);
+ friend DSLExpression operator!(DSLExpression expr);
+ friend DSLExpression operator~(DSLExpression expr);
+ friend DSLExpression operator++(DSLExpression expr);
+ friend DSLExpression operator++(DSLExpression expr, int);
+ friend DSLExpression operator--(DSLExpression expr);
+ friend DSLExpression operator--(DSLExpression expr, int);
+};
+
+DSLExpression operator+(DSLExpression left, DSLExpression right);
+DSLExpression operator+=(DSLExpression left, DSLExpression right);
+DSLExpression operator+=(DSLVar& left, DSLExpression right);
+DSLExpression operator-(DSLExpression left, DSLExpression right);
+DSLExpression operator-=(DSLExpression left, DSLExpression right);
+DSLExpression operator-=(DSLVar& left, DSLExpression right);
+DSLExpression operator*(DSLExpression left, DSLExpression right);
+DSLExpression operator*=(DSLExpression left, DSLExpression right);
+DSLExpression operator*=(DSLVar& left, DSLExpression right);
+DSLExpression operator/(DSLExpression left, DSLExpression right);
+DSLExpression operator/=(DSLExpression left, DSLExpression right);
+DSLExpression operator/=(DSLVar& left, DSLExpression right);
+DSLExpression operator%(DSLExpression left, DSLExpression right);
+DSLExpression operator%=(DSLExpression left, DSLExpression right);
+DSLExpression operator%=(DSLVar& left, DSLExpression right);
+DSLExpression operator<<(DSLExpression left, DSLExpression right);
+DSLExpression operator<<=(DSLExpression left, DSLExpression right);
+DSLExpression operator<<=(DSLVar& left, DSLExpression right);
+DSLExpression operator>>(DSLExpression left, DSLExpression right);
+DSLExpression operator>>=(DSLExpression left, DSLExpression right);
+DSLExpression operator>>=(DSLVar& left, DSLExpression right);
+DSLExpression operator&&(DSLExpression left, DSLExpression right);
+DSLExpression operator||(DSLExpression left, DSLExpression right);
+DSLExpression operator&(DSLExpression left, DSLExpression right);
+DSLExpression operator&=(DSLExpression left, DSLExpression right);
+DSLExpression operator&=(DSLVar& left, DSLExpression right);
+DSLExpression operator|(DSLExpression left, DSLExpression right);
+DSLExpression operator|=(DSLExpression left, DSLExpression right);
+DSLExpression operator|=(DSLVar& left, DSLExpression right);
+DSLExpression operator^(DSLExpression left, DSLExpression right);
+DSLExpression operator^=(DSLExpression left, DSLExpression right);
+DSLExpression operator^=(DSLVar& left, DSLExpression right);
+DSLExpression operator,(DSLExpression left, DSLExpression right);
+DSLExpression operator==(DSLExpression left, DSLExpression right);
+DSLExpression operator!=(DSLExpression left, DSLExpression right);
+DSLExpression operator>(DSLExpression left, DSLExpression right);
+DSLExpression operator<(DSLExpression left, DSLExpression right);
+DSLExpression operator>=(DSLExpression left, DSLExpression right);
+DSLExpression operator<=(DSLExpression left, DSLExpression right);
+DSLExpression operator!(DSLExpression expr);
+DSLExpression operator~(DSLExpression expr);
+DSLExpression operator++(DSLExpression expr);
+DSLExpression operator++(DSLExpression expr, int);
+DSLExpression operator--(DSLExpression expr);
+DSLExpression operator--(DSLExpression expr, int);
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLFunction.h b/src/sksl/dsl/DSLFunction.h
new file mode 100644
index 0000000..2adb5c4
--- /dev/null
+++ b/src/sksl/dsl/DSLFunction.h
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_FUNCTION
+#define SKSL_DSL_FUNCTION
+
+#include "src/sksl/SkSLString.h"
+#include "src/sksl/dsl/DSLBlock.h"
+#include "src/sksl/dsl/DSLType.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
+#include "src/sksl/ir/SkSLBlock.h"
+#include "src/sksl/ir/SkSLFunctionDefinition.h"
+
+namespace SkSL {
+
+class Block;
+class Variable;
+
+namespace dsl {
+
+class DSLType;
+
+class DSLFunction {
+public:
+ template<class... Parameters>
+ DSLFunction(const DSLType& returnType, const char* name, Parameters&... parameters)
+ : fReturnType(returnType.skslType()) {
+ std::vector<const Variable*> parameterArray;
+ parameterArray.reserve(sizeof...(parameters));
+ (parameterArray.push_back(parameters.var()), ...);
+ SkSL::SymbolTable& symbols = *DSLWriter::SymbolTable();
+ fDecl = symbols.add(std::make_unique<SkSL::FunctionDeclaration>(
+ /*offset=*/-1,
+ DSLWriter::Modifiers(SkSL::Modifiers()),
+ DSLWriter::Name(name),
+ std::move(parameterArray), &fReturnType,
+ /*builtin=*/false));
+ }
+
+ virtual ~DSLFunction() = default;
+
+ template<class... Stmt>
+ void define(Stmt... stmts) {
+ DSLBlock block = DSLBlock(DSLStatement(std::move(stmts))...);
+ this->define(std::move(block));
+ }
+
+ void define(DSLBlock block) {
+ DSLWriter::ProgramElements().emplace_back(new SkSL::FunctionDefinition(/*offset=*/-1,
+ fDecl,
+ /*builtin=*/false,
+ block.release()));
+ }
+
+protected:
+ const SkSL::Type& fReturnType;
+ const SkSL::FunctionDeclaration* fDecl;
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLModifiers.h b/src/sksl/dsl/DSLModifiers.h
new file mode 100644
index 0000000..70ebd39
--- /dev/null
+++ b/src/sksl/dsl/DSLModifiers.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_MODIFIERS
+#define SKSL_DSL_MODIFIERS
+
+#include "src/sksl/ir/SkSLModifiers.h"
+namespace SkSL {
+
+namespace dsl {
+
+class DSLModifiers {
+public:
+ enum Flag {
+ kNo_Flag = 0,
+ kConst_Flag = 1 << 0,
+ kIn_Flag = 1 << 1,
+ kOut_Flag = 1 << 2,
+ kUniform_Flag = 1 << 3,
+ kFlat_Flag = 1 << 4,
+ kNoPerspective_Flag = 1 << 5,
+ };
+
+ DSLModifiers() {}
+
+ DSLModifiers(Flag flags)
+ : fModifiers(SkSL::Layout(), flags) {}
+
+private:
+ SkSL::Modifiers fModifiers;
+
+ friend class DSLVar;
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLStatement.cpp b/src/sksl/dsl/DSLStatement.cpp
new file mode 100644
index 0000000..32d10e7
--- /dev/null
+++ b/src/sksl/dsl/DSLStatement.cpp
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSLStatement.h"
+
+#include "src/sksl/dsl/DSLBlock.h"
+#include "src/sksl/dsl/DSLExpression.h"
+#include "src/sksl/ir/SkSLExpressionStatement.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+DSLStatement::DSLStatement(DSLBlock block)
+ : fStatement(block.release()) {}
+
+DSLStatement::DSLStatement(DSLExpression expr) {
+ std::unique_ptr<SkSL::Expression> skslExpr = expr.release();
+ if (skslExpr) {
+ fStatement = std::make_unique<SkSL::ExpressionStatement>(std::move(skslExpr));
+ }
+}
+
+DSLStatement::DSLStatement(std::unique_ptr<SkSL::Expression> expr)
+ : fStatement(std::make_unique<SkSL::ExpressionStatement>(std::move(expr))) {}
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSLStatement.h b/src/sksl/dsl/DSLStatement.h
new file mode 100644
index 0000000..7674b0e
--- /dev/null
+++ b/src/sksl/dsl/DSLStatement.h
@@ -0,0 +1,72 @@
+/*
+ * Copyright 2020 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_STATEMENT
+#define SKSL_DSL_STATEMENT
+
+#include "include/core/SkString.h"
+#include "include/core/SkTypes.h"
+#include "src/sksl/ir/SkSLIRNode.h"
+
+#include <memory>
+
+class GrGLSLShaderBuilder;
+
+namespace SkSL {
+
+class Statement;
+
+namespace dsl {
+
+class DSLBlock;
+class DSLExpression;
+class DSLVar;
+
+class DSLStatement {
+public:
+ DSLStatement() {}
+
+ DSLStatement(DSLExpression expr);
+
+ DSLStatement(DSLBlock block);
+
+ DSLStatement(DSLStatement&&) = default;
+
+ ~DSLStatement() {
+ SkASSERTF(!fStatement, "Statement destroyed without being incorporated into output tree");
+ }
+
+ std::unique_ptr<SkSL::Statement> release() {
+ return std::move(fStatement);
+ }
+
+private:
+ DSLStatement(std::unique_ptr<SkSL::Statement> stmt)
+ : fStatement(std::move(stmt)) {
+ SkASSERT(fStatement);
+ }
+
+ DSLStatement(std::unique_ptr<SkSL::Expression> expr);
+
+ std::unique_ptr<SkSL::Statement> fStatement;
+
+ friend DSLStatement Declare(DSLVar& var, DSLExpression initialValue);
+ friend DSLStatement Do(DSLStatement stmt, DSLExpression test);
+ friend DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression next,
+ DSLStatement stmt);
+ friend DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse);
+ friend DSLStatement While(DSLExpression test, DSLStatement stmt);
+
+ friend class DSLBlock;
+ friend class ::GrGLSLShaderBuilder;
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLType.cpp b/src/sksl/dsl/DSLType.cpp
new file mode 100644
index 0000000..fc4326c
--- /dev/null
+++ b/src/sksl/dsl/DSLType.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSLType.h"
+
+#include "src/sksl/dsl/priv/DSLWriter.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+const SkSL::Type& DSLType::skslType() const {
+ if (fSkSLType) {
+ return *fSkSLType;
+ }
+ const SkSL::Context& context = DSLWriter::Context();
+ switch (fTypeConstant) {
+ case kBool:
+ return *context.fBool_Type;
+ case kBool2:
+ return *context.fBool2_Type;
+ case kBool3:
+ return *context.fBool3_Type;
+ case kBool4:
+ return *context.fBool4_Type;
+ case kHalf:
+ return *context.fHalf_Type;
+ case kHalf2:
+ return *context.fHalf2_Type;
+ case kHalf3:
+ return *context.fHalf3_Type;
+ case kHalf4:
+ return *context.fHalf4_Type;
+ case kFloat:
+ return *context.fFloat_Type;
+ case kFloat2:
+ return *context.fFloat2_Type;
+ case kFloat3:
+ return *context.fFloat3_Type;
+ case kFloat4:
+ return *context.fFloat4_Type;
+ case kInt:
+ return *context.fInt_Type;
+ case kInt2:
+ return *context.fInt2_Type;
+ case kInt3:
+ return *context.fInt3_Type;
+ case kInt4:
+ return *context.fInt4_Type;
+ case kShort:
+ return *context.fShort_Type;
+ case kShort2:
+ return *context.fShort2_Type;
+ case kShort3:
+ return *context.fShort3_Type;
+ case kShort4:
+ return *context.fShort4_Type;
+ case kVoid:
+ return *context.fVoid_Type;
+ default:
+ SkUNREACHABLE;
+ }
+}
+
+DSLExpression dsl_construct(const SkSL::Type& type, std::vector<DSLExpression> rawArgs) {
+ SkSL::ExpressionArray args;
+ for (DSLExpression& arg : rawArgs) {
+ args.push_back(arg.release());
+ }
+ return DSLExpression(DSLWriter::IRGenerator().call(
+ /*offset=*/-1,
+ std::make_unique<SkSL::TypeReference>(DSLWriter::Context(),
+ /*offset=*/-1,
+ &type),
+ std::move(args)));
+}
+
+static DSLExpression construct1(const SkSL::Type& type, DSLExpression a) {
+ std::vector<DSLExpression> args;
+ args.push_back(std::move(a));
+ return dsl_construct(type, std::move(args));
+}
+
+static DSLExpression construct2(const SkSL::Type& type, DSLExpression a,
+ DSLExpression b) {
+ std::vector<DSLExpression> args;
+ args.push_back(std::move(a));
+ args.push_back(std::move(b));
+ return dsl_construct(type, std::move(args));
+}
+
+static DSLExpression construct3(const SkSL::Type& type, DSLExpression a,
+ DSLExpression b,
+ DSLExpression c) {
+ std::vector<DSLExpression> args;
+ args.push_back(std::move(a));
+ args.push_back(std::move(b));
+ args.push_back(std::move(c));
+ return dsl_construct(type, std::move(args));
+}
+
+static DSLExpression construct4(const SkSL::Type& type, DSLExpression a, DSLExpression b,
+ DSLExpression c, DSLExpression d) {
+ std::vector<DSLExpression> args;
+ args.push_back(std::move(a));
+ args.push_back(std::move(b));
+ args.push_back(std::move(c));
+ args.push_back(std::move(d));
+ return dsl_construct(type, std::move(args));
+}
+
+#define TYPE(T) \
+DSLExpression T(DSLExpression a) { \
+ return construct1(*DSLWriter::Context().f ## T ## _Type, std::move(a)); \
+} \
+DSLExpression T ## 2(DSLExpression a) { \
+ return construct1(*DSLWriter::Context().f ## T ## 2_Type, std::move(a)); \
+} \
+DSLExpression T ## 2(DSLExpression a, DSLExpression b) { \
+ return construct2(*DSLWriter::Context().f ## T ## 2_Type, std::move(a), \
+ std::move(b)); \
+} \
+DSLExpression T ## 3(DSLExpression a) { \
+ return construct1(*DSLWriter::Context().f ## T ## 3_Type, std::move(a)); \
+} \
+DSLExpression T ## 3(DSLExpression a, DSLExpression b) { \
+ return construct2(*DSLWriter::Context().f ## T ## 3_Type, std::move(a), \
+ std::move(b)); \
+} \
+DSLExpression T ## 3(DSLExpression a, DSLExpression b, DSLExpression c) { \
+ return construct3(*DSLWriter::Context().f ## T ## 3_Type, std::move(a), \
+ std::move(b), std::move(c)); \
+} \
+DSLExpression T ## 4(DSLExpression a) { \
+ return construct1(*DSLWriter::Context().f ## T ## 4_Type, std::move(a)); \
+} \
+DSLExpression T ## 4(DSLExpression a, DSLExpression b) { \
+ return construct2(*DSLWriter::Context().f ## T ## 4_Type, std::move(a), \
+ std::move(b)); \
+} \
+DSLExpression T ## 4(DSLExpression a, DSLExpression b, DSLExpression c) { \
+ return construct3(*DSLWriter::Context().f ## T ## 4_Type, std::move(a), std::move(b), \
+ std::move(c)); \
+} \
+DSLExpression T ## 4(DSLExpression a, DSLExpression b, DSLExpression c, DSLExpression d) { \
+ return construct4(*DSLWriter::Context().f ## T ## 4_Type, std::move(a), std::move(b), \
+ std::move(c), std::move(d)); \
+}
+
+TYPE(Bool)
+TYPE(Float)
+TYPE(Half)
+TYPE(Int)
+TYPE(Short)
+
+#undef TYPE
+
+DSLType Array(const DSLType& base, int count) {
+ SkSL::String name = base.skslType().name() + "[" + SkSL::to_string(count) + "]";
+ return DSLType(DSLWriter::SymbolTable()->takeOwnershipOfSymbol(
+ std::make_unique<SkSL::Type>(name, SkSL::Type::TypeKind::kArray,
+ base.skslType(), count)));
+}
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSLType.h b/src/sksl/dsl/DSLType.h
new file mode 100644
index 0000000..7e0e225
--- /dev/null
+++ b/src/sksl/dsl/DSLType.h
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_TYPE
+#define SKSL_DSL_TYPE
+
+#include "src/sksl/ir/SkSLIRNode.h"
+
+#include <cstdint>
+#include <memory>
+
+namespace SkSL {
+
+class Statement;
+class Type;
+
+namespace dsl {
+
+class DSLExpression;
+class DSLStatement;
+
+enum TypeConstant : uint8_t {
+ kBool,
+ kBool2,
+ kBool3,
+ kBool4,
+ kHalf,
+ kHalf2,
+ kHalf3,
+ kHalf4,
+ kFloat,
+ kFloat2,
+ kFloat3,
+ kFloat4,
+ kInt,
+ kInt2,
+ kInt3,
+ kInt4,
+ kShort,
+ kShort2,
+ kShort3,
+ kShort4,
+ kVoid,
+};
+
+class DSLType {
+public:
+ DSLType(TypeConstant tc)
+ : fTypeConstant(tc) {}
+
+ DSLType(const SkSL::Type* type)
+ : fSkSLType(type) {}
+
+private:
+ const SkSL::Type& skslType() const;
+
+ const SkSL::Type* fSkSLType = nullptr;
+
+ TypeConstant fTypeConstant;
+
+ friend DSLExpression dsl_construct(const SkSL::Type& type, std::vector<DSLExpression> rawArgs);
+ friend DSLType Array(const DSLType& base, int count);
+
+ friend class DSLFunction;
+ friend class DSLVar;
+};
+
+#define TYPE(T) \
+ DSLExpression T(DSLExpression expr); \
+ DSLExpression T##2(DSLExpression expr); \
+ DSLExpression T##2(DSLExpression x, DSLExpression y); \
+ DSLExpression T##3(DSLExpression expr); \
+ DSLExpression T##3(DSLExpression x, DSLExpression y); \
+ DSLExpression T##3(DSLExpression x, DSLExpression y, DSLExpression z); \
+ DSLExpression T##4(DSLExpression expr); \
+ DSLExpression T##4(DSLExpression x, DSLExpression y); \
+ DSLExpression T##4(DSLExpression x, DSLExpression y, DSLExpression z); \
+ DSLExpression T##4(DSLExpression x, DSLExpression y, DSLExpression z, DSLExpression w);
+
+TYPE(Bool)
+TYPE(Float)
+TYPE(Half)
+TYPE(Int)
+TYPE(Short)
+
+#undef TYPE
+#undef TYPE_FRIEND
+
+DSLType Array(const DSLType& base, int count);
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/DSLVar.cpp b/src/sksl/dsl/DSLVar.cpp
new file mode 100644
index 0000000..0e453d6
--- /dev/null
+++ b/src/sksl/dsl/DSLVar.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSLVar.h"
+
+#include "src/sksl/SkSLUtil.h"
+#include "src/sksl/dsl/DSLModifiers.h"
+#include "src/sksl/dsl/DSLType.h"
+#include "src/sksl/dsl/priv/DSLWriter.h"
+#include "src/sksl/ir/SkSLBinaryExpression.h"
+#include "src/sksl/ir/SkSLSymbolTable.h"
+#include "src/sksl/ir/SkSLVariable.h"
+#include "src/sksl/ir/SkSLVariableReference.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+DSLVar::DSLVar(const char* name)
+ : fName(name) {}
+
+DSLVar::DSLVar(DSLType type, const char* name)
+ : DSLVar(DSLModifiers(), type, name) {}
+
+DSLVar::DSLVar(DSLModifiers modifiers, DSLType type, const char* name)
+ : fName(DSLWriter::Name(name)) {
+#if SK_SUPPORT_GPU && !defined(SKSL_STANDALONE)
+ if (modifiers.fModifiers.fFlags & Modifiers::kUniform_Flag) {
+ const SkSL::Type& skslType = type.skslType();
+ GrSLType grslType;
+ int count;
+ if (skslType.typeKind() == SkSL::Type::TypeKind::kArray) {
+ SkAssertResult(SkSL::type_to_grsltype(DSLWriter::Context(),
+ skslType.componentType(),
+ &grslType));
+ count = skslType.columns();
+ SkASSERT(count > 0);
+ } else {
+ SkAssertResult(SkSL::type_to_grsltype(DSLWriter::Context(), skslType,
+ &grslType));
+ count = 0;
+ }
+ const char* name;
+ SkASSERT(DSLWriter::CurrentEmitArgs());
+ fUniformHandle = DSLWriter::CurrentEmitArgs()->fUniformHandler->addUniformArray(
+ &DSLWriter::CurrentEmitArgs()->fFp,
+ kFragment_GrShaderFlag,
+ grslType,
+ this->name().c_str(),
+ count,
+ &name);
+ fName = name;
+ }
+#endif // SK_SUPPORT_GPU && !defined(SKSL_STANDALONE)
+ fOwnedVar = std::make_unique<SkSL::Variable>(/*offset=*/-1,
+ DSLWriter::Modifiers(modifiers.fModifiers),
+ SkSL::StringFragment(fName.c_str()),
+ &type.skslType(),
+ /*builtin=*/false,
+ SkSL::Variable::Storage::kLocal);
+ fVar = fOwnedVar.get();
+}
+
+const SkSL::Variable* DSLVar::var() const {
+ if (fVar) {
+ return fVar;
+ }
+ SkSL::StringFragment name(fName.c_str());
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+ if (name == "sk_SampleCoord") {
+ name = DSLWriter::CurrentEmitArgs()->fSampleCoord;
+ } else if (name == "sk_InColor") {
+ name = DSLWriter::CurrentEmitArgs()->fInputColor;
+ } else if (name == "sk_OutColor") {
+ name = DSLWriter::CurrentEmitArgs()->fOutputColor;
+ }
+#endif
+ const SkSL::Symbol* result = (*DSLWriter::SymbolTable())[name];
+ SkASSERTF(result, "could not find '%s' in symbol table", fName.c_str());
+ return &result->as<SkSL::Variable>();
+}
+
+GrGLSLUniformHandler::UniformHandle DSLVar::uniformHandle() {
+ SkASSERT(fVar->modifiers().fFlags & SkSL::Modifiers::kUniform_Flag);
+ return fUniformHandle;
+}
+
+DSLExpression DSLVar::operator[](DSLExpression&& index) {
+ return DSLExpression(std::make_unique<SkSL::IndexExpression>(
+ DSLWriter::Context(),
+ DSLExpression(*this).release(),
+ index.coerceAndRelease(*DSLWriter::Context().fInt_Type)));
+}
+
+DSLExpression DSLVar::operator=(DSLExpression&& expr) {
+ const SkSL::Variable* var = this->var();
+ return DSLExpression(std::make_unique<SkSL::BinaryExpression>(
+ /*offset=*/-1,
+ std::make_unique<SkSL::VariableReference>(/*offset=*/-1,
+ var,
+ SkSL::VariableReference::RefKind::kWrite),
+ SkSL::Token::Kind::TK_EQ,
+ expr.coerceAndRelease(var->type()),
+ &var->type()));
+}
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSLVar.h b/src/sksl/dsl/DSLVar.h
new file mode 100644
index 0000000..a24cea7
--- /dev/null
+++ b/src/sksl/dsl/DSLVar.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_VAR
+#define SKSL_DSL_VAR
+
+#include "src/gpu/glsl/GrGLSLUniformHandler.h"
+#include "src/sksl/SkSLString.h"
+#include "src/sksl/dsl/DSLExpression.h"
+
+namespace SkSL {
+
+class Statement;
+class Variable;
+
+namespace dsl {
+
+class DSLExpression;
+class DSLModifiers;
+class DSLType;
+
+class DSLVar {
+public:
+ DSLVar(const char* name);
+
+ DSLVar(DSLType type, const char* name = "var");
+
+ DSLVar(DSLModifiers modifiers, DSLType type, const char* name = "var");
+
+ DSLVar(DSLVar&&) = delete;
+
+ DSLExpression operator=(const DSLVar& var) {
+ return this->operator=(DSLExpression(var));
+ }
+
+ DSLExpression operator=(DSLExpression&& expr);
+
+ DSLExpression operator=(int expr) {
+ return this->operator=(DSLExpression(expr));
+ }
+
+ DSLExpression operator=(float expr) {
+ return this->operator=(DSLExpression(expr));
+ }
+
+ DSLExpression operator[](DSLExpression&& index);
+
+ DSLExpression operator++() {
+ return ++DSLExpression(*this);
+ }
+
+ DSLExpression operator++(int) {
+ return DSLExpression(*this)++;
+ }
+
+private:
+ const SkSL::Variable* var() const;
+
+ const SkSL::String& name() const {
+ return fName;
+ }
+
+ GrGLSLUniformHandler::UniformHandle uniformHandle();
+
+ // this object owns the var until it is added to a symboltable
+ std::unique_ptr<SkSL::Variable> fOwnedVar;
+ const SkSL::Variable* fVar = nullptr;
+ SkSL::String fName;
+ DSLExpression fInitialValue;
+ GrGLSLUniformHandler::UniformHandle fUniformHandle;
+
+ friend class DSLExpression;
+ friend class DSLFunction;
+ friend class DSLWriter;
+ friend DSLStatement Declare(DSLVar& var, DSLExpression initialValue);
+
+ friend DSLExpression operator+=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator-=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator*=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator/=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator%=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator<<=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator>>=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator&=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator|=(DSLVar& left, DSLExpression right);
+ friend DSLExpression operator^=(DSLVar& left, DSLExpression right);
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+
+#endif
diff --git a/src/sksl/dsl/DSL_core.cpp b/src/sksl/dsl/DSL_core.cpp
new file mode 100644
index 0000000..f02de6c
--- /dev/null
+++ b/src/sksl/dsl/DSL_core.cpp
@@ -0,0 +1,169 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/DSL_core.h"
+
+#include "src/sksl/SkSLCompiler.h"
+#include "src/sksl/SkSLIRGenerator.h"
+#include "src/sksl/ir/SkSLCodeStringExpression.h"
+#include "src/sksl/ir/SkSLDoStatement.h"
+#include "src/sksl/ir/SkSLForStatement.h"
+#include "src/sksl/ir/SkSLIfStatement.h"
+#include "src/sksl/ir/SkSLWhileStatement.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+DSLVar sk_FragColor() { return DSLVar("sk_FragColor"); }
+
+DSLVar sk_FragCoord() { return DSLVar("sk_FragCoord"); }
+
+DSLVar sk_SampleCoord() { return DSLVar("sk_SampleCoord"); }
+
+DSLVar sk_OutColor() { return DSLVar("sk_OutColor"); }
+
+DSLVar sk_InColor() { return DSLVar("sk_InColor"); }
+
+void SetErrorHandler(ErrorHandler* errorHandler) {
+ DSLWriter::SetErrorHandler(errorHandler);
+}
+
+// normally we would use std::make_unique to create the nodes below, but explicitly creating
+// std::unique_ptr<SkSL::Statement> avoids issues with ambiguous constructor invocations
+
+DSLStatement Declare(DSLVar& var, DSLExpression initialValue) {
+ DSLWriter::SymbolTable()->add(std::move(var.fOwnedVar));
+ return std::unique_ptr<SkSL::Statement>(new SkSL::VarDeclaration(
+ var.var(),
+ &var.var()->type(),
+ /*arraySize=*/0,
+ initialValue.coerceAndRelease(var.var()->type())));
+}
+
+DSLStatement Do(DSLStatement stmt, DSLExpression test) {
+ const SkSL::Type& boolType = *DSLWriter::Context().fBool_Type;
+ return std::unique_ptr<SkSL::Statement>(new SkSL::DoStatement(/*offset=*/-1,
+ stmt.release(),
+ test.coerceAndRelease(boolType)));
+}
+
+DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression next,
+ DSLStatement stmt) {
+ const SkSL::Type& boolType = *DSLWriter::Context().fBool_Type;
+ return std::unique_ptr<SkSL::Statement>(new SkSL::ForStatement(/*offset=*/-1,
+ initializer.release(),
+ test.coerceAndRelease(boolType),
+ next.release(),
+ stmt.release(),
+ nullptr));
+}
+
+DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse) {
+ const SkSL::Type& boolType = *DSLWriter::Context().fBool_Type;
+ return std::unique_ptr<SkSL::Statement>(new SkSL::IfStatement(/*offset=*/-1,
+ /*isStatic=*/false,
+ test.coerceAndRelease(boolType),
+ ifTrue.release(),
+ ifFalse.release()));
+}
+
+DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse) {
+ return DSLExpression(DSLWriter::IRGenerator().convertTernaryExpression(test.release(),
+ ifTrue.release(),
+ ifFalse.release()));
+}
+
+DSLStatement While(DSLExpression test, DSLStatement stmt) {
+ const SkSL::Type& boolType = *DSLWriter::Context().fBool_Type;
+ return std::unique_ptr<SkSL::Statement>(new SkSL::WhileStatement(
+ /*offset=*/-1,
+ test.coerceAndRelease(boolType),
+ stmt.release()));
+}
+
+static void ignore(std::unique_ptr<SkSL::Expression>&) {}
+
+template <typename... Args>
+DSLExpression dsl_function(const char* name, Args... args) {
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
+ SkSL::ExpressionArray argArray;
+ argArray.reserve_back(sizeof...(args));
+
+ // in C++17, we could just do:
+ // (argArray.push_back(args.release()), ...);
+ int unused[] = {0, (ignore(argArray.push_back(args.release())), 0)...};
+ static_cast<void>(unused);
+
+ return ir.call(/*offset=*/-1, ir.convertIdentifier(-1, name), std::move(argArray));
+}
+DSLExpression ceil(DSLExpression x) {
+ return dsl_function("ceil", std::move(x));
+}
+
+DSLExpression clamp(DSLExpression x, DSLExpression min, DSLExpression max) {
+ return dsl_function("clamp", std::move(x), std::move(min), std::move(max));
+}
+
+DSLExpression dot(DSLExpression x, DSLExpression y) {
+ return dsl_function("dot", std::move(x), std::move(y));
+}
+
+DSLExpression floor(DSLExpression x) {
+ return dsl_function("floor", std::move(x));
+}
+
+DSLExpression saturate(DSLExpression x) {
+ return dsl_function("saturate", std::move(x));
+}
+
+DSLExpression unpremul(DSLExpression x) {
+ return dsl_function("unpremul", std::move(x));
+}
+
+#if SK_SUPPORT_GPU && !defined(SKSL_STANDALONE)
+DSLExpression sampleChild(int index, DSLExpression coords) {
+ std::unique_ptr<SkSL::Expression> coordsExpr = coords.release();
+ SkString code = DSLWriter::CurrentProcessor()->invokeChild(index, *DSLWriter::CurrentEmitArgs(),
+ coordsExpr ? coordsExpr->description()
+ : "");
+ return DSLExpression(std::make_unique<SkSL::CodeStringExpression>(code.c_str(),
+ DSLWriter::Context().fHalf4_Type.get()));
+}
+
+void Start(GrGLSLFragmentProcessor* currentProcessor,
+ GrGLSLFragmentProcessor::EmitArgs* args) {
+ DSLWriter::Push(currentProcessor, args);
+ SkSL::IRGenerator& ir = DSLWriter::IRGenerator();
+ ir.symbolTable()->add(std::make_unique<SkSL::Variable>(/*offset=*/-1,
+ DSLWriter::Modifiers(SkSL::Modifiers()),
+ args->fSampleCoord,
+ DSLWriter::Context().fFloat2_Type.get(),
+ /*builtin=*/false,
+ SkSL::Variable::Storage::kLocal));
+ ir.symbolTable()->add(std::make_unique<SkSL::Variable>(/*offset=*/-1,
+ DSLWriter::Modifiers(SkSL::Modifiers()),
+ args->fInputColor,
+ DSLWriter::Context().fHalf4_Type.get(),
+ /*builtin=*/false,
+ SkSL::Variable::Storage::kLocal));
+ ir.symbolTable()->add(std::make_unique<SkSL::Variable>(/*offset=*/-1,
+ DSLWriter::Modifiers(SkSL::Modifiers()),
+ args->fOutputColor,
+ DSLWriter::Context().fHalf4_Type.get(),
+ /*builtin=*/false,
+ SkSL::Variable::Storage::kLocal));
+}
+
+void End() {
+ DSLWriter::Pop();
+}
+#endif // SK_SUPPORT_GPU && !defined(SKSL_STANDALONE)
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/DSL_core.h b/src/sksl/dsl/DSL_core.h
new file mode 100644
index 0000000..9aa459b
--- /dev/null
+++ b/src/sksl/dsl/DSL_core.h
@@ -0,0 +1,127 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSL_CORE
+#define SKSL_DSL_CORE
+
+#include "src/sksl/dsl/DSLExpression.h"
+#include "src/sksl/dsl/DSLFunction.h"
+#include "src/sksl/dsl/DSLStatement.h"
+#include "src/sksl/dsl/DSLType.h"
+#include "src/sksl/dsl/DSLVar.h"
+
+namespace SkSL {
+
+namespace dsl {
+
+/**
+ * Represents the fragment color, equivalent to gl_FragColor.
+ */
+DSLVar sk_FragColor();
+
+/**
+ * Represents the fragment coordinates, equivalent to gl_FragCoord.
+ */
+DSLVar sk_FragCoord();
+
+/**
+ * (Fragment processors only) Represents args.fSampleCoords.
+ */
+DSLVar sk_SampleCoord();
+
+/**
+ * (Fragment processors only) Represents args.fOutputColor.
+ */
+DSLVar sk_OutColor();
+
+/**
+ * (Fragment processors only) Represents args.fInputColor.
+ */
+DSLVar sk_InColor();
+
+/**
+ * Class which is notified in the event of an error.
+ */
+class ErrorHandler {
+public:
+ virtual ~ErrorHandler() {}
+
+ virtual void handleError(const char* msg) = 0;
+};
+
+/**
+ * Installs an ErrorHandler which will be notified of any errors that occur during DSL calls. If no
+ * ErrorHandler is installed, any errors will be fatal.
+ */
+void SetErrorHandler(ErrorHandler* errorHandler);
+
+/**
+ * Creates a variable declaration statement with an initial value.
+ */
+DSLStatement Declare(DSLVar& var, DSLExpression initialValue = DSLExpression());
+
+/**
+ * do stmt; while (test);
+ */
+DSLStatement Do(DSLStatement stmt, DSLExpression test);
+
+/**
+ * for (initializer; test; next) stmt;
+ */
+DSLStatement For(DSLStatement initializer, DSLExpression test, DSLExpression next,
+ DSLStatement stmt);
+
+/**
+ * if (test) ifTrue; [else ifFalse;]
+ */
+DSLStatement If(DSLExpression test, DSLStatement ifTrue, DSLStatement ifFalse = DSLStatement());
+
+/**
+ * test ? ifTrue : ifFalse
+ */
+DSLExpression Ternary(DSLExpression test, DSLExpression ifTrue, DSLExpression ifFalse);
+
+/**
+ * while (test) stmt;
+ */
+DSLStatement While(DSLExpression test, DSLStatement stmt);
+
+/**
+ * Returns x rounded towards positive infinity.
+ */
+DSLExpression ceil(DSLExpression x);
+
+/**
+ * Returns x clamped to between min and max.
+ */
+DSLExpression clamp(DSLExpression x, DSLExpression min, DSLExpression max);
+
+/**
+ * Returns the dot product of x and y.
+ */
+DSLExpression dot(DSLExpression x, DSLExpression y);
+
+/**
+ * Returns x rounded towards negative infinity.
+ */
+DSLExpression floor(DSLExpression x);
+
+/**
+ * Returns x clamped to the range [0, 1].
+ */
+DSLExpression saturate(DSLExpression x);
+
+/**
+ * Returns x converted from premultipled to unpremultiplied alpha.
+ */
+DSLExpression unpremul(DSLExpression x);
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/dsl/priv/DSLWriter.cpp b/src/sksl/dsl/priv/DSLWriter.cpp
new file mode 100644
index 0000000..714324d
--- /dev/null
+++ b/src/sksl/dsl/priv/DSLWriter.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/dsl/priv/DSLWriter.h"
+
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+#include "src/gpu/mock/GrMockCaps.h"
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+#include "src/sksl/SkSLCompiler.h"
+#include "src/sksl/SkSLIRGenerator.h"
+#include "src/sksl/dsl/DSL_core.h"
+
+#if !SKSL_USE_THREAD_LOCAL
+#include <pthread.h>
+#endif // !SKSL_USE_THREAD_LOCAL
+
+namespace SkSL {
+
+namespace dsl {
+
+DSLWriter::DSLWriter(std::unique_ptr<SkSL::Compiler> compiler)
+ : fCompiler(std::move(compiler)) {
+ SkSL::ParsedModule module = fCompiler->moduleForProgramKind(SkSL::Program::kFragment_Kind);
+ SkSL::IRGenerator& ir = *fCompiler->fIRGenerator;
+ ir.fSymbolTable = module.fSymbols;
+ ir.pushSymbolTable();
+}
+
+SkSL::IRGenerator& DSLWriter::IRGenerator() {
+ return *Compiler().fIRGenerator;
+}
+
+const SkSL::Context& DSLWriter::Context() {
+ return IRGenerator().fContext;
+}
+
+const std::shared_ptr<SkSL::SymbolTable>& DSLWriter::SymbolTable() {
+ return IRGenerator().fSymbolTable;
+}
+
+const SkSL::Modifiers* DSLWriter::Modifiers(SkSL::Modifiers modifiers) {
+ return IRGenerator().fModifiers->addToPool(modifiers);
+}
+
+SkSL::StringFragment DSLWriter::Name(const char* name) {
+ if (ManglingEnabled()) {
+ const SkSL::String* s = SymbolTable()->takeOwnershipOfString(std::make_unique<SkSL::String>(
+ name +
+ SkSL::String("_") +
+ SkSL::to_string(++Instance().fNameCount)));
+ return SkSL::StringFragment(s->c_str(), s->length());
+ }
+ return SkSL::StringFragment(name);
+}
+
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+void DSLWriter::Push(GrGLSLFragmentProcessor* processor,
+ GrGLSLFragmentProcessor::EmitArgs* emitArgs) {
+ Instance().fStack.push({processor, emitArgs});
+ IRGenerator().pushSymbolTable();
+}
+
+void DSLWriter::Pop() {
+ DSLWriter& instance = Instance();
+ SkASSERT(!instance.fStack.empty());
+ instance.fStack.pop();
+ IRGenerator().popSymbolTable();
+}
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+
+void DSLWriter::ReportError(const char* msg) {
+ if (Instance().fErrorHandler) {
+ Instance().fErrorHandler->handleError(msg);
+ } else {
+ SK_ABORT("%sNo SkSL DSL error handler configured, treating this as a fatal error\n", msg);
+ }
+}
+
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+
+std::unique_ptr<DSLWriter> DSLWriter::CreateInstance() {
+ static SkSL::Program::Settings settings;
+ static GrMockCaps caps((GrContextOptions()), GrMockOptions());
+ auto compiler = std::make_unique<SkSL::Compiler>(caps.shaderCaps());
+ compiler->fInliner.reset(compiler->fIRGenerator->fModifiers.get(), &settings);
+ compiler->fIRGenerator->fKind = SkSL::Program::kFragment_Kind;
+ compiler->fIRGenerator->fFile = std::make_unique<SkSL::ASTFile>();
+ compiler->fIRGenerator->fSettings = &settings;
+ return std::unique_ptr<DSLWriter>(new DSLWriter(std::move(compiler)));
+}
+
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+
+#if !SK_SUPPORT_GPU || defined(SKSL_STANDALONE)
+
+DSLWriter& DSLWriter::Instance() {
+ SkUNREACHABLE;
+}
+
+#elif SKSL_USE_THREAD_LOCAL
+
+DSLWriter& DSLWriter::Instance() {
+ thread_local static std::unique_ptr<DSLWriter> instance;
+ if (!instance) {
+ instance = CreateInstance();
+ }
+ return *instance;
+}
+
+#else
+
+static void destroy_dslwriter(void* dslWriter) {
+ delete static_cast<DSLWriter*>(dslWriter);
+}
+
+static pthread_key_t get_pthread_key() {
+ static pthread_key_t sKey = []{
+ pthread_key_t key;
+ int result = pthread_key_create(&key, destroy_dslwriter);
+ if (result != 0) {
+ SK_ABORT("pthread_key_create failure: %d", result);
+ }
+ return key;
+ }();
+ return sKey;
+}
+
+DSLWriter& DSLWriter::Instance() {
+ DSLWriter* instance = static_cast<DSLWriter*>(pthread_getspecific(get_pthread_key()));
+ if (!instance) {
+ instance = CreateInstance().release();
+ pthread_setspecific(get_pthread_key(), instance);
+ }
+ return *instance;
+}
+
+#endif
+
+} // namespace dsl
+
+} // namespace SkSL
diff --git a/src/sksl/dsl/priv/DSLWriter.h b/src/sksl/dsl/priv/DSLWriter.h
new file mode 100644
index 0000000..c6888bd
--- /dev/null
+++ b/src/sksl/dsl/priv/DSLWriter.h
@@ -0,0 +1,173 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_DSLWRITER
+#define SKSL_DSLWRITER
+
+#include "src/sksl/SkSLModifiersPool.h"
+#include "src/sksl/dsl/DSLExpression.h"
+#include "src/sksl/ir/SkSLExpressionStatement.h"
+#include "src/sksl/ir/SkSLProgram.h"
+#include "src/sksl/ir/SkSLStatement.h"
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+#include "src/gpu/glsl/GrGLSLFragmentProcessor.h"
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+
+#include <stack>
+
+namespace SkSL {
+
+class Compiler;
+class Context;
+class IRGenerator;
+class SymbolTable;
+class Type;
+
+namespace dsl {
+
+class ErrorHandler;
+
+/**
+ * Thread-safe class that tracks per-thread state associated with DSL output. This class is for
+ * internal use only.
+ */
+class DSLWriter {
+public:
+ /**
+ * Returns the Compiler used by DSL operations in the current thread.
+ */
+ static SkSL::Compiler& Compiler() {
+ return *Instance().fCompiler;
+ }
+
+ /**
+ * Returns the IRGenerator used by DSL operations in the current thread.
+ */
+ static SkSL::IRGenerator& IRGenerator();
+
+ /**
+ * Returns the Context used by DSL operations in the current thread.
+ */
+ static const SkSL::Context& Context();
+
+ /**
+ * Returns the SymbolTable of the current thread's IRGenerator.
+ */
+ static const std::shared_ptr<SkSL::SymbolTable>& SymbolTable();
+
+ /**
+ * Returns the Compiler used by DSL operations in the current thread.
+ */
+ static const SkSL::Modifiers* Modifiers(SkSL::Modifiers modifiers);
+
+ /**
+ * Returns the (possibly mangled) final name that should be used for an entity with the given
+ * raw name.
+ */
+ static SkSL::StringFragment Name(const char* name);
+
+ /**
+ * Returns the collection to which DSL program elements in this thread should be appended.
+ */
+ static std::vector<std::unique_ptr<SkSL::ProgramElement>>& ProgramElements() {
+ return Instance().fProgramElements;
+ }
+
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+ /**
+ * Returns the fragment processor for which DSL output is being generated for the current
+ * thread.
+ */
+ static GrGLSLFragmentProcessor* CurrentProcessor() {
+ SkASSERTF(!Instance().fStack.empty(), "This feature requires a FragmentProcessor");
+ return Instance().fStack.top().fProcessor;
+ }
+
+ /**
+ * Returns the EmitArgs for fragment processor output in the current thread.
+ */
+ static GrGLSLFragmentProcessor::EmitArgs* CurrentEmitArgs() {
+ SkASSERTF(!Instance().fStack.empty(), "This feature requires a FragmentProcessor");
+ return Instance().fStack.top().fEmitArgs;
+ }
+
+ /**
+ * Pushes a new processor / emitArgs pair for the current thread.
+ */
+ static void Push(GrGLSLFragmentProcessor* processor,
+ GrGLSLFragmentProcessor::EmitArgs* emitArgs);
+
+ /**
+ * Pops the processor / emitArgs pair associated with the current thread.
+ */
+ static void Pop();
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+
+ /**
+ * Sets the ErrorHandler associated with the current thread. This object will be notified when
+ * any DSL errors occur. With a null ErrorHandler (the default), any errors will be dumped to
+ * stderr and a fatal exception will be generated.
+ */
+ static void SetErrorHandler(ErrorHandler* errorHandler) {
+ Instance().fErrorHandler = errorHandler;
+ }
+
+ /**
+ * Notifies the current ErrorHandler that a DSL error has occurred. With a null ErrorHandler
+ * (the default), any errors will be dumped to stderr and a fatal exception will be generated.
+ */
+ static void ReportError(const char* msg);
+
+ /**
+ * Readies the DSLWriter to begin outputting a new top-level FragmentProcessor.
+ */
+ static void Reset() {
+ Instance().fNameCount = 0;
+ }
+
+ /**
+ * Returns whether name mangling is enabled.
+ */
+ static bool ManglingEnabled() {
+ return Instance().fMangle;
+ }
+
+ /**
+ * Enables or disables name mangling. Mangling should always be enabling except for tests which
+ * need to guarantee consistent output.
+ */
+ static void SetManglingEnabled(bool mangle) {
+ Instance().fMangle = mangle;
+ }
+
+ static DSLWriter& Instance();
+
+private:
+ DSLWriter(std::unique_ptr<SkSL::Compiler> compiler);
+
+ static std::unique_ptr<DSLWriter> CreateInstance();
+
+ SkSL::Program::Settings fSettings;
+ std::unique_ptr<SkSL::Compiler> fCompiler;
+ int fNameCount = 0;
+ std::vector<std::unique_ptr<SkSL::ProgramElement>> fProgramElements;
+ ErrorHandler* fErrorHandler = nullptr;
+ bool fMangle = true;
+#if !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+ struct StackFrame {
+ GrGLSLFragmentProcessor* fProcessor;
+ GrGLSLFragmentProcessor::EmitArgs* fEmitArgs;
+ };
+ std::stack<StackFrame> fStack;
+#endif // !defined(SKSL_STANDALONE) && SK_SUPPORT_GPU
+};
+
+} // namespace dsl
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLCodeStringExpression.h b/src/sksl/ir/SkSLCodeStringExpression.h
new file mode 100644
index 0000000..d977027
--- /dev/null
+++ b/src/sksl/ir/SkSLCodeStringExpression.h
@@ -0,0 +1,49 @@
+/*
+ * Copyright 2020 Google LLC.
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CODESTRINGEXPRESSION
+#define SKSL_CODESTRINGEXPRESSION
+
+#include "src/sksl/ir/SkSLExpression.h"
+
+namespace SkSL {
+
+/**
+ * Represents a literal string of SkSL code. This is only valid within SkSL DSL code, and is
+ * intended as a temporary measure to support a couple of spots within Skia that are currently
+ * generating raw strings of code. These will eventually transition to producing Expressions,
+ * allowing this class to be deleted.
+ */
+class CodeStringExpression final : public Expression {
+public:
+ static constexpr Kind kExpressionKind = Kind::kCodeString;
+
+ CodeStringExpression(String code, const Type* type)
+ : INHERITED(/*offset=*/-1, kExpressionKind, type)
+ , fCode(std::move(code)) {}
+
+ bool hasProperty(Property property) const override {
+ return false;
+ }
+
+ std::unique_ptr<Expression> clone() const override {
+ return std::make_unique<CodeStringExpression>(fCode, &this->type());
+ }
+
+ String description() const override {
+ return fCode;
+ }
+
+private:
+ String fCode;
+
+ using INHERITED = Expression;
+};
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 77b90e9..aaedc02 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -30,6 +30,7 @@
enum class Kind {
kBinary = (int) Statement::Kind::kLast + 1,
kBoolLiteral,
+ kCodeString,
kConstructor,
kDefined,
kExternalFunctionCall,
diff --git a/src/sksl/lex/LexUtil.h b/src/sksl/lex/LexUtil.h
index 338b864..a30652e 100644
--- a/src/sksl/lex/LexUtil.h
+++ b/src/sksl/lex/LexUtil.h
@@ -14,5 +14,6 @@
#define ABORT(...) (fprintf(stderr, __VA_ARGS__), abort())
#define SkASSERT(x) (void)((x) || (ABORT("failed SkASSERT(%s): %s:%d\n", #x, __FILE__, __LINE__), 0))
+#define SkASSERTF(x, msg, ...)
#endif
diff --git a/tests/SkSLDSLTest.cpp b/tests/SkSLDSLTest.cpp
new file mode 100644
index 0000000..930f24e
--- /dev/null
+++ b/tests/SkSLDSLTest.cpp
@@ -0,0 +1,898 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/SkSLIRGenerator.h"
+#include "src/sksl/dsl/DSL.h"
+
+#include "tests/Test.h"
+
+using namespace SkSL::dsl;
+
+class AutoDisableMangle {
+public:
+ AutoDisableMangle() {
+ fOldState = DSLWriter::ManglingEnabled();
+ DSLWriter::SetManglingEnabled(false);
+ DSLWriter::IRGenerator().pushSymbolTable();
+ }
+
+ ~AutoDisableMangle() {
+ DSLWriter::SetManglingEnabled(fOldState);
+ DSLWriter::IRGenerator().popSymbolTable();
+ }
+
+private:
+ bool fOldState;
+};
+
+class ExpectError : public ErrorHandler {
+public:
+ ExpectError(skiatest::Reporter* reporter, const char* msg)
+ : fMsg(msg)
+ , fReporter(reporter) {
+ SetErrorHandler(this);
+ }
+
+ ~ExpectError() override {
+ REPORTER_ASSERT(fReporter, !fMsg);
+ SetErrorHandler(nullptr);
+ }
+
+ void handleError(const char* msg) override {
+ REPORTER_ASSERT(fReporter, !strcmp(msg, fMsg),
+ "Error mismatch: expected:\n%sbut received:\n%s", fMsg, msg);
+ fMsg = nullptr;
+ }
+
+private:
+ const char* fMsg;
+ skiatest::Reporter* fReporter;
+};
+
+DEF_TEST(DSLPlus, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Expression e1 = a + b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a + b)");
+
+ Expression e2 = a + 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a + 1.0)");
+
+ Expression e3 = 0.5 + a + -99;
+ REPORTER_ASSERT(r, e3.release()->description() == "((0.5 + a) + -99.0)");
+
+ Expression e4 = a += b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a += (b + 1.0))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '+' cannot operate on 'bool2', 'float'\n");
+ (Bool2(true) + a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '+=' cannot operate on 'float', 'bool2'\n");
+ (a += Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLMinus, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a - b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a - b)");
+
+ Expression e2 = a - 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a - 1)");
+
+ Expression e3 = 2 - a - b;
+ REPORTER_ASSERT(r, e3.release()->description() == "((2 - a) - b)");
+
+ Expression e4 = a -= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a -= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '-' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) - a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '-=' cannot operate on 'int', 'bool2'\n");
+ (a -= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLMultiply, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Expression e1 = a * b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a * b)");
+
+ Expression e2 = a * 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a * 1.0)");
+
+ Expression e3 = 0.5 * a * -99;
+ REPORTER_ASSERT(r, e3.release()->description() == "((0.5 * a) * -99.0)");
+
+ Expression e4 = a *= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a *= (b + 1.0))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '*' cannot operate on 'bool2', 'float'\n");
+ (Bool2(true) * a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '*=' cannot operate on 'float', 'bool2'\n");
+ (a *= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLDivide, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Expression e1 = a / b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a / b)");
+
+ Expression e2 = a / 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a / 1.0)");
+
+ Expression e3 = 0.5 / a / -99;
+ REPORTER_ASSERT(r, e3.release()->description() == "((0.5 / a) / -99.0)");
+
+ Expression e4 = b / (a - 1);
+ REPORTER_ASSERT(r, e4.release()->description() == "(b / (a - 1.0))");
+
+ Expression e5 = a /= b + 1;
+ REPORTER_ASSERT(r, e5.release()->description() == "(a /= (b + 1.0))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '/' cannot operate on 'bool2', 'float'\n");
+ (Bool2(true) / a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '/=' cannot operate on 'float', 'bool2'\n");
+ (a /= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLMod, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a % b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a % b)");
+
+ Expression e2 = a % 2;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a % 2)");
+
+ Expression e3 = 10 % a % -99;
+ REPORTER_ASSERT(r, e3.release()->description() == "((10 % a) % -99)");
+
+ Expression e4 = a %= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a %= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '%' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) % a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '%=' cannot operate on 'int', 'bool2'\n");
+ (a %= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLShl, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a << b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a << b)");
+
+ Expression e2 = a << 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a << 1)");
+
+ Expression e3 = 1 << a << 2;
+ REPORTER_ASSERT(r, e3.release()->description() == "((1 << a) << 2)");
+
+ Expression e4 = a <<= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a <<= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '<<' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) << a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '<<=' cannot operate on 'int', 'bool2'\n");
+ (a <<= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLShr, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a >> b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a >> b)");
+
+ Expression e2 = a >> 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a >> 1)");
+
+ Expression e3 = 1 >> a >> 2;
+ REPORTER_ASSERT(r, e3.release()->description() == "((1 >> a) >> 2)");
+
+ Expression e4 = a >>= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a >>= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '>>' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) >> a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '>>=' cannot operate on 'int', 'bool2'\n");
+ (a >>= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLBitwiseAnd, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a & b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a & b)");
+
+ Expression e2 = a & 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a & 1)");
+
+ Expression e3 = 1 & a & 2;
+ REPORTER_ASSERT(r, e3.release()->description() == "((1 & a) & 2)");
+
+ Expression e4 = a &= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a &= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '&' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) & a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '&=' cannot operate on 'int', 'bool2'\n");
+ (a &= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLBitwiseOr, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a | b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a | b)");
+
+ Expression e2 = a | 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a | 1)");
+
+ Expression e3 = 1 | a | 2;
+ REPORTER_ASSERT(r, e3.release()->description() == "((1 | a) | 2)");
+
+ Expression e4 = a |= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a |= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '|' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) | a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '|=' cannot operate on 'int', 'bool2'\n");
+ (a |= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLBitwiseXor, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a ^ b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a ^ b)");
+
+ Expression e2 = a ^ 1;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a ^ 1)");
+
+ Expression e3 = 1 ^ a ^ 2;
+ REPORTER_ASSERT(r, e3.release()->description() == "((1 ^ a) ^ 2)");
+
+ Expression e4 = a ^= b + 1;
+ REPORTER_ASSERT(r, e4.release()->description() == "(a ^= (b + 1))");
+
+ {
+ ExpectError error(r, "error: type mismatch: '^' cannot operate on 'bool2', 'int'\n");
+ (Bool2(true) ^ a).release();
+ }
+
+ {
+ ExpectError error(r, "error: type mismatch: '^=' cannot operate on 'int', 'bool2'\n");
+ (a ^= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLLogicalAnd, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kBool, "a"), b(kBool, "b");
+ Expression e1 = a && b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a && b)");
+
+ Expression e2 = a && true && b;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a && b)");
+
+ Expression e3 = a && false && b;
+ REPORTER_ASSERT(r, e3.release()->description() == "false");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ (a && 5).release();
+ }
+}
+
+DEF_TEST(DSLLogicalOr, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kBool, "a"), b(kBool, "b");
+ Expression e1 = a || b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a || b)");
+
+ Expression e2 = a || true || b;
+ REPORTER_ASSERT(r, e2.release()->description() == "true");
+
+ Expression e3 = a || false || b;
+ REPORTER_ASSERT(r, e3.release()->description() == "(a || b)");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ (a || 5).release();
+ }
+}
+
+DEF_TEST(DSLComma, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = (a += b, b);
+ REPORTER_ASSERT(r, e1.release()->description() == "((a += b) , b)");
+
+ Expression e2 = (a += b, b += b, Int2(a));
+ REPORTER_ASSERT(r, e2.release()->description() == "(((a += b) , (b += b)) , int2(a))");
+}
+
+DEF_TEST(DSLEqual, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a == b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a == b)");
+
+ Expression e2 = a == 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a == 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '==' cannot operate on 'int', 'bool2'\n");
+ (a == Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLNotEqual, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a != b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a != b)");
+
+ Expression e2 = a != 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a != 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '!=' cannot operate on 'int', 'bool2'\n");
+ (a != Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLGreaterThan, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a > b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a > b)");
+
+ Expression e2 = a > 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a > 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '>' cannot operate on 'int', 'bool2'\n");
+ (a > Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLGreaterThanOrEqual, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a >= b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a >= b)");
+
+ Expression e2 = a >= 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a >= 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '>=' cannot operate on 'int', 'bool2'\n");
+ (a >= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLLessThan, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a < b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a < b)");
+
+ Expression e2 = a < 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a < 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '<' cannot operate on 'int', 'bool2'\n");
+ (a < Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLLessThanOrEqual, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = a <= b;
+ REPORTER_ASSERT(r, e1.release()->description() == "(a <= b)");
+
+ Expression e2 = a <= 5;
+ REPORTER_ASSERT(r, e2.release()->description() == "(a <= 5)");
+
+ {
+ ExpectError error(r, "error: type mismatch: '<=' cannot operate on 'int', 'bool2'\n");
+ (a <= Bool2(true)).release();
+ }
+}
+
+DEF_TEST(DSLLogicalNot, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kInt, "b");
+ Expression e1 = !(a <= b);
+ REPORTER_ASSERT(r, e1.release()->description() == "!(a <= b)");
+
+ {
+ ExpectError error(r, "error: '!' cannot operate on 'int'\n");
+ (!a).release();
+ }
+}
+
+DEF_TEST(DSLBitwiseNot, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kBool, "b");
+ Expression e1 = ~a;
+ REPORTER_ASSERT(r, e1.release()->description() == "~a");
+
+ {
+ ExpectError error(r, "error: '~' cannot operate on 'bool'\n");
+ (~b).release();
+ }
+}
+
+DEF_TEST(DSLIncrement, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kBool, "b");
+ Expression e1 = ++a;
+ REPORTER_ASSERT(r, e1.release()->description() == "++a");
+
+ Expression e2 = a++;
+ REPORTER_ASSERT(r, e2.release()->description() == "a++");
+
+ {
+ ExpectError error(r, "error: '++' cannot operate on 'bool'\n");
+ (++b).release();
+ }
+
+ {
+ ExpectError error(r, "error: '++' cannot operate on 'bool'\n");
+ (b++).release();
+ }
+
+ {
+ ExpectError error(r, "error: cannot assign to this expression\n");
+ (++(a + 1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: cannot assign to this expression\n");
+ ((a + 1)++).release();
+ }
+}
+
+DEF_TEST(DSLDecrement, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a"), b(kBool, "b");
+ Expression e1 = --a;
+ REPORTER_ASSERT(r, e1.release()->description() == "--a");
+
+ Expression e2 = a--;
+ REPORTER_ASSERT(r, e2.release()->description() == "a--");
+
+ {
+ ExpectError error(r, "error: '--' cannot operate on 'bool'\n");
+ (--b).release();
+ }
+
+ {
+ ExpectError error(r, "error: '--' cannot operate on 'bool'\n");
+ (b--).release();
+ }
+
+ {
+ ExpectError error(r, "error: cannot assign to this expression\n");
+ (--(a + 1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: cannot assign to this expression\n");
+ ((a + 1)--).release();
+ }
+}
+
+DEF_TEST(DSLFloat, r) {
+ AutoDisableMangle disableMangle;
+ Expression e1 = Float(0);
+ REPORTER_ASSERT(r, e1.release()->description() == "0.0");
+
+ Expression e2 = Float2(0);
+ REPORTER_ASSERT(r, e2.release()->description() == "float2(0.0)");
+
+ Expression e3 = Float2(0, 1);
+ REPORTER_ASSERT(r, e3.release()->description() == "float2(0.0, 1.0)");
+
+ Expression e4 = Float3(0);
+ REPORTER_ASSERT(r, e4.release()->description() == "float3(0.0)");
+
+ Expression e5 = Float3(Float2(0, 1), 2);
+ REPORTER_ASSERT(r, e5.release()->description() == "float3(float2(0.0, 1.0), 2.0)");
+
+ Expression e6 = Float3(0, 1, 2);
+ REPORTER_ASSERT(r, e6.release()->description() == "float3(0.0, 1.0, 2.0)");
+
+ Expression e7 = Float4(0);
+ REPORTER_ASSERT(r, e7.release()->description() == "float4(0.0)");
+
+ Expression e8 = Float4(Float2(0, 1), Float2(2, 3));
+ REPORTER_ASSERT(r, e8.release()->description() == "float4(float2(0.0, 1.0), float2(2.0, 3.0))");
+
+ Expression e9 = Float4(0, 1, Float2(2, 3));
+ REPORTER_ASSERT(r, e9.release()->description() == "float4(0.0, 1.0, float2(2.0, 3.0))");
+
+ Expression e10 = Float4(0, 1, 2, 3);
+ REPORTER_ASSERT(r, e10.release()->description() == "float4(0.0, 1.0, 2.0, 3.0)");
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'float2' constructor (expected 2 scalars,"
+ " but found 4)\n");
+ Float2(Float4(1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'float4' constructor (expected 4 scalars,"
+ " but found 3)\n");
+ Float4(Float3(1)).release();
+ }
+}
+
+DEF_TEST(DSLHalf, r) {
+ AutoDisableMangle disableMangle;
+ Expression e1 = Half(0);
+ REPORTER_ASSERT(r, e1.release()->description() == "0.0");
+
+ Expression e2 = Half2(0);
+ REPORTER_ASSERT(r, e2.release()->description() == "half2(0.0)");
+
+ Expression e3 = Half2(0, 1);
+ REPORTER_ASSERT(r, e3.release()->description() == "half2(0.0, 1.0)");
+
+ Expression e4 = Half3(0);
+ REPORTER_ASSERT(r, e4.release()->description() == "half3(0.0)");
+
+ Expression e5 = Half3(Half2(0, 1), 2);
+ REPORTER_ASSERT(r, e5.release()->description() == "half3(half2(0.0, 1.0), 2.0)");
+
+ Expression e6 = Half3(0, 1, 2);
+ REPORTER_ASSERT(r, e6.release()->description() == "half3(0.0, 1.0, 2.0)");
+
+ Expression e7 = Half4(0);
+ REPORTER_ASSERT(r, e7.release()->description() == "half4(0.0)");
+
+ Expression e8 = Half4(Half2(0, 1), Half2(2, 3));
+ REPORTER_ASSERT(r, e8.release()->description() == "half4(half2(0.0, 1.0), half2(2.0, 3.0))");
+
+ Expression e9 = Half4(0, 1, Half2(2, 3));
+ REPORTER_ASSERT(r, e9.release()->description() == "half4(0.0, 1.0, half2(2.0, 3.0))");
+
+ Expression e10 = Half4(0, 1, 2, 3);
+ REPORTER_ASSERT(r, e10.release()->description() == "half4(0.0, 1.0, 2.0, 3.0)");
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'half2' constructor (expected 2 scalars,"
+ " but found 4)\n");
+ Half2(Half4(1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'half4' constructor (expected 4 scalars,"
+ " but found 3)\n");
+ Half4(Half3(1)).release();
+ }
+}
+
+DEF_TEST(DSLInt, r) {
+ AutoDisableMangle disableMangle;
+ Expression e1 = Int(0);
+ REPORTER_ASSERT(r, e1.release()->description() == "0");
+
+ Expression e2 = Int2(0);
+ REPORTER_ASSERT(r, e2.release()->description() == "int2(0)");
+
+ Expression e3 = Int2(0, 1);
+ REPORTER_ASSERT(r, e3.release()->description() == "int2(0, 1)");
+
+ Expression e4 = Int3(0);
+ REPORTER_ASSERT(r, e4.release()->description() == "int3(0)");
+
+ Expression e5 = Int3(Int2(0, 1), 2);
+ REPORTER_ASSERT(r, e5.release()->description() == "int3(int2(0, 1), 2)");
+
+ Expression e6 = Int3(0, 1, 2);
+ REPORTER_ASSERT(r, e6.release()->description() == "int3(0, 1, 2)");
+
+ Expression e7 = Int4(0);
+ REPORTER_ASSERT(r, e7.release()->description() == "int4(0)");
+
+ Expression e8 = Int4(Int2(0, 1), Int2(2, 3));
+ REPORTER_ASSERT(r, e8.release()->description() == "int4(int2(0, 1), int2(2, 3))");
+
+ Expression e9 = Int4(0, 1, Int2(2, 3));
+ REPORTER_ASSERT(r, e9.release()->description() == "int4(0, 1, int2(2, 3))");
+
+ Expression e10 = Int4(0, 1, 2, 3);
+ REPORTER_ASSERT(r, e10.release()->description() == "int4(0, 1, 2, 3)");
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'int2' constructor (expected 2 scalars,"
+ " but found 4)\n");
+ Int2(Int4(1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'int4' constructor (expected 4 scalars,"
+ " but found 3)\n");
+ Int4(Int3(1)).release();
+ }
+}
+
+DEF_TEST(DSLShort, r) {
+ AutoDisableMangle disableMangle;
+ Expression e1 = Short(0);
+ REPORTER_ASSERT(r, e1.release()->description() == "short(0)");
+
+ Expression e2 = Short2(0);
+ REPORTER_ASSERT(r, e2.release()->description() == "short2(short(0))");
+
+ Expression e3 = Short2(0, 1);
+ REPORTER_ASSERT(r, e3.release()->description() == "short2(short(0), short(1))");
+
+ Expression e4 = Short3(0);
+ REPORTER_ASSERT(r, e4.release()->description() == "short3(short(0))");
+
+ Expression e5 = Short3(Short2(0, 1), 2);
+ REPORTER_ASSERT(r, e5.release()->description() == "short3(short2(short(0), short(1)), "
+ "short(2))");
+
+ Expression e6 = Short3(0, 1, 2);
+ REPORTER_ASSERT(r, e6.release()->description() == "short3(short(0), short(1), short(2))");
+
+ Expression e7 = Short4(0);
+ REPORTER_ASSERT(r, e7.release()->description() == "short4(short(0))");
+
+ Expression e8 = Short4(Short2(0, 1), Short2(2, 3));
+ REPORTER_ASSERT(r, e8.release()->description() == "short4(short2(short(0), short(1)), "
+ "short2(short(2), short(3)))");
+
+ Expression e9 = Short4(0, 1, Short2(2, 3));
+ REPORTER_ASSERT(r, e9.release()->description() == "short4(short(0), short(1), short2(short(2), "
+ "short(3)))");
+
+ Expression e10 = Short4(0, 1, 2, 3);
+ REPORTER_ASSERT(r, e10.release()->description() == "short4(short(0), short(1), short(2), "
+ "short(3))");
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'short2' constructor (expected 2 scalars,"
+ " but found 4)\n");
+ Short2(Short4(1)).release();
+ }
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'short4' constructor (expected 4 scalars,"
+ " but found 3)\n");
+ Short4(Short3(1)).release();
+ }
+}
+
+DEF_TEST(DSLBool, r) {
+ AutoDisableMangle disableMangle;
+ Expression e1 = Bool2(false);
+ REPORTER_ASSERT(r, e1.release()->description() == "bool2(false)");
+
+ Expression e2 = Bool2(false, true);
+ REPORTER_ASSERT(r, e2.release()->description() == "bool2(false, true)");
+
+ Expression e3 = Bool3(false);
+ REPORTER_ASSERT(r, e3.release()->description() == "bool3(false)");
+
+ Expression e4 = Bool3(Bool2(false, true), false);
+ REPORTER_ASSERT(r, e4.release()->description() == "bool3(bool2(false, true), false)");
+
+ Expression e5 = Bool3(false, true, false);
+ REPORTER_ASSERT(r, e5.release()->description() == "bool3(false, true, false)");
+
+ Expression e6 = Bool4(false);
+ REPORTER_ASSERT(r, e6.release()->description() == "bool4(false)");
+
+ Expression e7 = Bool4(Bool2(false, true), Bool2(false, true));
+ REPORTER_ASSERT(r, e7.release()->description() == "bool4(bool2(false, true), "
+ "bool2(false, true))");
+
+ Expression e8 = Bool4(false, true, Bool2(false, true));
+ REPORTER_ASSERT(r, e8.release()->description() == "bool4(false, true, bool2(false, true))");
+
+ Expression e9 = Bool4(false, true, false, true);
+ REPORTER_ASSERT(r, e9.release()->description() == "bool4(false, true, false, true)");
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'bool2' constructor (expected 2 scalars,"
+ " but found 4)\n");
+ Bool2(Bool4(true)).release();
+ }
+
+ {
+ ExpectError error(r, "error: invalid arguments to 'bool4' constructor (expected 4 scalars,"
+ " but found 3)\n");
+ Bool4(Bool3(true)).release();
+ }
+}
+
+DEF_TEST(DSLBlock, r) {
+ AutoDisableMangle disableMangle;
+ Statement x = Block();
+ REPORTER_ASSERT(r, x.release()->description() == "{\n}\n");
+ Var a(kInt, "a"), b(kInt, "b");
+ Statement y = Block(Declare(a, 1), Declare(b, 2), a = b);
+ REPORTER_ASSERT(r, y.release()->description() == "{\nint a = 1;\nint b = 2;\n(a = b);\n}\n");
+}
+
+DEF_TEST(DSLDeclare, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kHalf4, "a"), b(kHalf4, "b");
+ Statement x = Declare(a);
+ REPORTER_ASSERT(r, x.release()->description() == "half4 a;");
+ Statement y = Declare(b, Half4(1));
+ REPORTER_ASSERT(r, y.release()->description() == "half4 b = half4(1.0);");
+
+ {
+ Var c(kHalf4, "c");
+ ExpectError error(r, "error: expected 'half4', but found 'int'\n");
+ Declare(c, 1).release();
+ }
+}
+
+DEF_TEST(DSLDo, r) {
+ AutoDisableMangle disableMangle;
+ Statement x = Do(Block(), true);
+ REPORTER_ASSERT(r, x.release()->description() == "do {\n}\n while (true);");
+
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Statement y = Do(Block(a++, --b), a != b);
+ REPORTER_ASSERT(r, y.release()->description() == "do {\na++;\n--b;\n}\n while ((a != b));");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ Do(Block(), 7).release();
+ }
+}
+
+DEF_TEST(DSLFor, r) {
+ AutoDisableMangle disableMangle;
+ Statement x = For(Statement(), Expression(), Expression(), Block());
+ REPORTER_ASSERT(r, x.release()->description() == "for (; ; ) {\n}\n");
+
+ Var i(kInt, "i");
+ Statement y = For(Declare(i, 0), i < 10, ++i, i += 5);
+ REPORTER_ASSERT(r, y.release()->description() == "for (int i = 0; (i < 10); ++i) (i += 5);");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ For(i = 0, i + 10, ++i, i += 5).release();
+ }
+}
+
+DEF_TEST(DSLFunction, r) {
+ AutoDisableMangle disableMangle;
+ DSLWriter::ProgramElements().clear();
+ Var coords(kHalf2, "coords");
+ DSLFunction(kVoid, "main", coords).define(
+ sk_FragColor() = Half4(coords, 0, 1)
+ );
+ REPORTER_ASSERT(r, DSLWriter::ProgramElements().size() == 1);
+ REPORTER_ASSERT(r, DSLWriter::ProgramElements()[0]->description() ==
+R"(void main(half2 coords) {
+(sk_FragColor = half4(coords, 0.0, 1.0));
+}
+)");
+}
+
+DEF_TEST(DSLIf, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Statement x = If(a > b, a -= b);
+ REPORTER_ASSERT(r, x.release()->description() == "if ((a > b)) (a -= b);");
+
+ Statement y = If(a > b, a -= b, b -= a);
+ REPORTER_ASSERT(r, y.release()->description() == "if ((a > b)) (a -= b); else (b -= a);");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'float'\n");
+ If(a + b, a -= b).release();
+ }
+}
+
+DEF_TEST(DSLTernary, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kInt, "a");
+ Expression x = Ternary(a > 0, 1, -1);
+ REPORTER_ASSERT(r, x.release()->description() == "((a > 0) ? 1 : -1)");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ Ternary(a, 1, -1).release();
+ }
+
+ {
+ ExpectError error(r, "error: ternary operator result mismatch: 'float2', 'float3'\n");
+ Ternary(a > 0, Float2(1), Float3(1)).release();
+ }
+}
+
+DEF_TEST(DSLWhile, r) {
+ AutoDisableMangle disableMangle;
+ Statement x = While(true, Block());
+ REPORTER_ASSERT(r, x.release()->description() == "while (true) {\n}\n");
+
+ Var a(kFloat, "a"), b(kFloat, "b");
+ Statement y = While(a != b, Block(a++, --b));
+ REPORTER_ASSERT(r, y.release()->description() == "while ((a != b)) {\na++;\n--b;\n}\n");
+
+ {
+ ExpectError error(r, "error: expected 'bool', but found 'int'\n");
+ While(7, Block()).release();
+ }
+}
+
+DEF_TEST(DSLBuiltins, r) {
+ AutoDisableMangle disableMangle;
+ Var a(kHalf4, "a"), b(kHalf4, "b");
+ REPORTER_ASSERT(r, ceil(a).release()->description() == "ceil(a)");
+ REPORTER_ASSERT(r, clamp(a, 0, 1).release()->description() == "clamp(a, 0.0, 1.0)");
+ REPORTER_ASSERT(r, dot(a, b).release()->description() == "dot(a, b)");
+ REPORTER_ASSERT(r, floor(a).release()->description() == "floor(a)");
+ REPORTER_ASSERT(r, saturate(a).release()->description() == "saturate(a)");
+ REPORTER_ASSERT(r, unpremul(a).release()->description() == "unpremul(a)");
+
+ // these calls all go through the normal channels, so it ought to be sufficient to prove that
+ // one of them reports errors correctly
+ {
+ ExpectError error(r, "error: no match for ceil(bool)\n");
+ ceil(a == b).release();
+ }
+}
diff --git a/tests/sksl/errors/golden/BinaryTypeMismatch.glsl b/tests/sksl/errors/golden/BinaryTypeMismatch.glsl
index 2d4cad1..774a5fd 100644
--- a/tests/sksl/errors/golden/BinaryTypeMismatch.glsl
+++ b/tests/sksl/errors/golden/BinaryTypeMismatch.glsl
@@ -1,10 +1,11 @@
### Compilation failed:
error: 1: type mismatch: '*' cannot operate on 'int', 'bool'
-error: 2: type mismatch: '||' cannot operate on 'int', 'float'
+error: 2: expected 'bool', but found 'int'
+error: 2: expected 'bool', but found 'float'
error: 3: type mismatch: '==' cannot operate on 'float2', 'int'
error: 4: type mismatch: '!=' cannot operate on 'float2', 'int'
error: 6: type mismatch: '<' cannot operate on 'float2', 'float2'
error: 7: type mismatch: '<' cannot operate on 'float2', 'float'
error: 8: type mismatch: '<' cannot operate on 'float', 'float2'
-7 errors
+8 errors