Add support for matrix + scalar to Metal codegen.
The Metal code generator will now detect matrix-op-scalar (and scalar-op-matrix)
expressions and splat the scalar across a matrix. This allows a scalar to be
added to, or subtracted from, a matrix. (Division is not fixed: splatting does
not help there, because Metal does not natively support componentwise division
of matrices either.)
Change-Id: I7d5b0c5bd35393475c524e34cad789bf4f72a103
Bug: skia:11125
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/407616
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index 87204f4..a08133e 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -1339,6 +1339,27 @@
}
}
+void MetalCodeGenerator::writeNumberAsMatrix(const Expression& expr, const Type& matrixType) {
+ SkASSERT(expr.type().isNumber());
+ SkASSERT(matrixType.isMatrix());
+ // Emit `(MatrixType(1.0, ..., 1.0) * expr)`: scaling an all-ones matrix splats the scalar
+ // into every slot while evaluating `expr` only once, so its side effects are not repeated.
+ this->write("(");
+ this->writeType(matrixType);
+ this->write("(");
+ const int slotCount = matrixType.slotCount();
+ for (int slot = 0; slot < slotCount; ++slot) {
+ this->write(slot > 0 ? ", " : "");
+ this->write("1.0");
+ }
+ this->write(") * ");
+ // `expr` is the right-hand operand of the `*` written above, so it is emitted with
+ // multiplicative parent precedence.
+ this->writeExpression(expr, Precedence::kMultiplicative);
+ // Close the outer parenthesis opened first, so the whole splat acts as one operand.
+ this->write(")");
+}
+
void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
Precedence parentPrecedence) {
const Expression& left = *b.left();
@@ -1372,7 +1393,14 @@
if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
}
- this->writeExpression(left, precedence);
+
+ bool needMatrixSplatOnScalar = rightType.isMatrix() && leftType.isScalar() &&
+ op.removeAssignment().kind() != Token::Kind::TK_STAR;
+ if (needMatrixSplatOnScalar) {
+ this->writeNumberAsMatrix(left, rightType);
+ } else {
+ this->writeExpression(left, precedence);
+ }
if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
left.kind() == Expression::Kind::kSwizzle && !left.hasSideEffects()) {
// This doesn't compile in Metal:
@@ -1391,7 +1419,14 @@
} else {
this->write(String(" ") + OperatorName(op) + " ");
}
- this->writeExpression(right, precedence);
+
+ needMatrixSplatOnScalar = leftType.isMatrix() && rightType.isScalar() &&
+ op.removeAssignment().kind() != Token::Kind::TK_STAR;
+ if (needMatrixSplatOnScalar) {
+ this->writeNumberAsMatrix(right, leftType);
+ } else {
+ this->writeExpression(right, precedence);
+ }
if (needParens) {
this->write(")");
}