SkSL Metal backend can now handle CCPR
Bug: skia:
Change-Id: I796a40db46174b405495af8234c5b8d7920a46d6
Reviewed-on: https://skia-review.googlesource.com/c/189985
Reviewed-by: Jim Van Verth <jvanverth@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 928fa23..da206b4 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -26,6 +26,8 @@
#define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic)
fIntrinsicMap[String("texture")] = SPECIAL(Texture);
fIntrinsicMap[String("mod")] = SPECIAL(Mod);
+ fIntrinsicMap[String("equal")] = METAL(Equal);
+ fIntrinsicMap[String("notEqual")] = METAL(NotEqual);
fIntrinsicMap[String("lessThan")] = METAL(LessThan);
fIntrinsicMap[String("lessThanEqual")] = METAL(LessThanEqual);
fIntrinsicMap[String("greaterThan")] = METAL(GreaterThan);
@@ -172,6 +174,12 @@
case kMetal_IntrinsicKind:
this->writeExpression(*c.fArguments[0], kSequence_Precedence);
switch ((MetalIntrinsic) intrinsicId) {
+ case kEqual_MetalIntrinsic:
+ this->write(" == ");
+ break;
+ case kNotEqual_MetalIntrinsic:
+ this->write(" != ");
+ break;
case kLessThan_MetalIntrinsic:
this->write(" < ");
break;
@@ -248,18 +256,89 @@
}
void MetalCodeGenerator::writeInverseHack(const Expression& mat) {
- String name = "ERROR_MatrixInverseNotImplementedFor_" + mat.fType.name();
- if (mat.fType == *fContext.fFloat2x2_Type) {
- name = "_inverse2";
+ // Writes the name of a helper function that inverts `mat`, first emitting that helper
+ // into fExtraFunctions (at most once per matrix type, tracked via fWrittenIntrinsics).
+ // Only float/half 2x2, 3x3 and 4x4 matrices get a definition; for any other type the
+ // name "<type>_inverse" is still written, but no function with that name is emitted.
+ String typeName = mat.fType.name();
+ String name = typeName + "_inverse";
+ if (mat.fType == *fContext.fFloat2x2_Type || mat.fType == *fContext.fHalf2x2_Type) {
if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
fWrittenIntrinsics.insert(name);
fExtraFunctions.writeText((
- "float2x2 " + name + "(float2x2 m) {"
- " return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
+ typeName + " " + name + "(" + typeName + " m) {"
+ // Build the result as typeName (not a hard-coded float2x2) so the half2x2 variant
+ // returns a half2x2; Metal does not implicitly convert between matrix types.
+ " return " + typeName + "(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
"}"
).c_str());
}
}
+ else if (mat.fType == *fContext.fFloat3x3_Type || mat.fType == *fContext.fHalf3x3_Type) {
+ if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
+ fWrittenIntrinsics.insert(name);
+ fExtraFunctions.writeText((
+ typeName + " " + name + "(" + typeName + " m) {"
+ " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];"
+ " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];"
+ " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];"
+ " float b01 = a22 * a11 - a12 * a21;"
+ " float b11 = -a22 * a10 + a12 * a20;"
+ " float b21 = a21 * a10 - a11 * a20;"
+ " float det = a00 * b01 + a01 * b11 + a02 * b21;"
+ " return " + typeName +
+ " (b01, (-a22 * a01 + a02 * a21), (a12 * a01 - a02 * a11),"
+ " b11, (a22 * a00 - a02 * a20), (-a12 * a00 + a02 * a10),"
+ " b21, (-a21 * a00 + a01 * a20), (a11 * a00 - a01 * a10)) * "
+ " (1/det);"
+ "}"
+ ).c_str());
+ }
+ }
+ else if (mat.fType == *fContext.fFloat4x4_Type || mat.fType == *fContext.fHalf4x4_Type) {
+ if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
+ fWrittenIntrinsics.insert(name);
+ fExtraFunctions.writeText((
+ typeName + " " + name + "(" + typeName + " m) {"
+ " float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];"
+ " float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];"
+ " float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];"
+ " float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];"
+ " float b00 = a00 * a11 - a01 * a10;"
+ " float b01 = a00 * a12 - a02 * a10;"
+ " float b02 = a00 * a13 - a03 * a10;"
+ " float b03 = a01 * a12 - a02 * a11;"
+ " float b04 = a01 * a13 - a03 * a11;"
+ " float b05 = a02 * a13 - a03 * a12;"
+ " float b06 = a20 * a31 - a21 * a30;"
+ " float b07 = a20 * a32 - a22 * a30;"
+ " float b08 = a20 * a33 - a23 * a30;"
+ " float b09 = a21 * a32 - a22 * a31;"
+ " float b10 = a21 * a33 - a23 * a31;"
+ " float b11 = a22 * a33 - a23 * a32;"
+ " float det = b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - "
+ " b04 * b07 + b05 * b06;"
+ // Scale by (1/det) with matrix * scalar, the same form the 2x2 and 3x3 variants use.
+ " return " + typeName + "(a11 * b11 - a12 * b10 + a13 * b09,"
+ " a02 * b10 - a01 * b11 - a03 * b09,"
+ " a31 * b05 - a32 * b04 + a33 * b03,"
+ " a22 * b04 - a21 * b05 - a23 * b03,"
+ " a12 * b08 - a10 * b11 - a13 * b07,"
+ " a00 * b11 - a02 * b08 + a03 * b07,"
+ " a32 * b02 - a30 * b05 - a33 * b01,"
+ " a20 * b05 - a22 * b02 + a23 * b01,"
+ " a10 * b10 - a11 * b08 + a13 * b06,"
+ " a01 * b08 - a00 * b10 - a03 * b06,"
+ " a30 * b04 - a31 * b02 + a33 * b00,"
+ " a21 * b02 - a20 * b04 - a23 * b00,"
+ " a11 * b07 - a10 * b09 - a12 * b06,"
+ " a00 * b09 - a01 * b07 + a02 * b06,"
+ " a31 * b01 - a30 * b03 - a32 * b00,"
+ " a20 * b03 - a21 * b01 + a22 * b00) * (1/det);"
+ "}"
+ ).c_str());
+ }
+ }
this->write(name);
}
@@ -300,8 +372,8 @@
// of type 'arg'.
String MetalCodeGenerator::getMatrixConstructHelper(const Type& matrix, const Type& arg) {
String key = matrix.name() + arg.name();
- auto found = fMatrixConstructHelpers.find(key);
- if (found != fMatrixConstructHelpers.end()) {
+ auto found = fHelpers.find(key);
+ if (found != fHelpers.end()) {
return found->second;
}
String name;
@@ -331,8 +403,39 @@
fExtraFunctions.writeText(")");
}
fExtraFunctions.writeText(");\n}\n");
- }
- else if (matrix.rows() == 2 && matrix.columns() == 2) {
+ } else if (arg.kind() == Type::kMatrix_Kind) {
+ // creating a matrix from another matrix. As in GLSL, each component present in the
+ // argument is copied and every other component comes from the identity matrix (1 on
+ // the diagonal, 0 elsewhere). Names and signatures use the real component types (e.g.
+ // half vs. float) because Metal does not implicitly convert between matrix types.
+ int argColumns = arg.columns();
+ int argRows = arg.rows();
+ String resultBase = matrix.componentType().name();
+ String argBase = arg.componentType().name();
+ name = resultBase + to_string(columns) + "x" + to_string(rows) + "_from_" + argBase +
+ to_string(argColumns) + "x" + to_string(argRows);
+ fExtraFunctions.printf("%s%dx%d %s(%s%dx%d m) {\n", resultBase.c_str(), columns, rows,
+ name.c_str(), argBase.c_str(), argColumns, argRows);
+ fExtraFunctions.printf(" return %s%dx%d(", resultBase.c_str(), columns, rows);
+ for (int i = 0; i < columns; ++i) {
+ if (i > 0) {
+ fExtraFunctions.writeText(", ");
+ }
+ fExtraFunctions.printf("%s%d(", resultBase.c_str(), rows);
+ for (int j = 0; j < rows; ++j) {
+ if (j > 0) {
+ fExtraFunctions.writeText(", ");
+ }
+ if (i < argColumns && j < argRows) {
+ fExtraFunctions.printf("m[%d][%d]", i, j);
+ } else {
+ fExtraFunctions.writeText(i == j ? "1" : "0");
+ }
+ }
+ fExtraFunctions.writeText(")");
+ }
+ fExtraFunctions.writeText(");\n}\n");
+ } else if (matrix.rows() == 2 && matrix.columns() == 2 && arg == *fContext.fFloat4_Type) {
// float2x2(float4) doesn't work, need to split it into float2x2(float2, float2)
name = "float2x2_from_float4";
fExtraFunctions.printf(
@@ -341,12 +439,11 @@
"}\n",
name.c_str()
);
- }
- else {
+ } else {
SkASSERT(false);
name = "<error>";
}
- fMatrixConstructHelpers[key] = name;
+ fHelpers[key] = name;
return name;
}
@@ -380,15 +477,14 @@
for (const auto& arg : c.fArguments) {
this->write(separator);
separator = ", ";
- if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
- // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge
- // to float2x2(float2, float2).
+ if (Type::kMatrix_Kind == c.fType.kind() && arg->fType.columns() != c.fType.rows()) {
+ // merge scalars and smaller vectors together
if (!scalarCount) {
this->writeType(c.fType.componentType());
this->write(to_string(c.fType.rows()));
this->write("(");
}
- ++scalarCount;
+ scalarCount += arg->fType.columns();
}
this->writeExpression(*arg, kSequence_Precedence);
if (scalarCount && scalarCount == c.fType.rows()) {
@@ -527,10 +623,45 @@
}
}
+void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
+ const Type& result) {
+ // Emits an operator*= overload for the (left, right) matrix types into fExtraFunctions.
+ // The key is recorded in fHelpers so the overload is emitted at most once per type pair;
+ // emitting it again would redefine the same function in the generated Metal source.
+ String key = "TimesEqual" + left.name() + right.name();
+ if (fHelpers.find(key) == fHelpers.end()) {
+ fHelpers[key] = String("operator*=");
+ fExtraFunctions.printf("%s operator*=(thread %s& left, thread const %s& right) {\n"
+ " left = left * right;\n"
+ " return left;\n"
+ "}\n", result.name().c_str(), left.name().c_str(),
+ right.name().c_str());
+ }
+}
+
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
Precedence parentPrecedence) {
Precedence precedence = GetBinaryPrecedence(b.fOperator);
- if (precedence >= parentPrecedence) {
+ bool needParens = precedence >= parentPrecedence;
+ // Metal's == and != compare vectors componentwise (producing a bool vector), whereas
+ // SkSL's yield a single bool: rewrite a == b as all(a == b) and a != b as any(a != b).
+ switch (b.fOperator) {
+ case Token::EQEQ:
+ if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+ this->write("all");
+ needParens = true;
+ }
+ break;
+ case Token::NEQ:
+ if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+ this->write("any");
+ needParens = true;
+ }
+ break;
+ default:
+ break;
+ }
+ if (needParens) {
this->write("(");
}
if (Compiler::IsAssignment(b.fOperator) &&
@@ -541,6 +666,10 @@
// dereference it here.
this->write("*");
}
+ if (b.fOperator == Token::STAREQ && b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+ b.fRight->fType.kind() == Type::kMatrix_Kind) {
+ this->writeMatrixTimesEqualHelper(b.fLeft->fType, b.fRight->fType, b.fType);
+ }
this->writeExpression(*b.fLeft, precedence);
if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
@@ -561,7 +690,7 @@
this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
}
this->writeExpression(*b.fRight, precedence);
- if (precedence >= parentPrecedence) {
+ if (needParens) {
this->write(")");
}
}