Honor component type in Metal matrix helper functions.
Currently, the Metal code generator forces all types to full precision,
and the matrix helper functions baked in that assumption by hard-coding
"floatX" type names. The helpers now honor the matrix's component type:
if this->typeName() began returning "half", the helpers would be emitted
with "halfX" names instead. This allows half-precision and full-precision
helpers to coexist.
Change-Id: I1679e6e76d2cf3c27fd69c42a92fb24bff6b69ec
Bug: skia:12339
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/439396
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index 6bc4698..14348ee 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -825,10 +825,12 @@
SkASSERT(rows <= 4);
SkASSERT(columns <= 4);
- const char* columnSeparator = "";
+ std::string matrixType = this->typeName(sourceMatrix.componentType());
+
+ const char* separator = "";
for (int c = 0; c < columns; ++c) {
- fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
- columnSeparator = "), ";
+ fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
+ separator = "), ";
// Determine how many values to take from the source matrix for this row.
int swizzleLength = 0;
@@ -861,14 +863,18 @@
// `x1`, etc. An error is written if the expression list don't contain exactly C*R scalars.
void MetalCodeGenerator::assembleMatrixFromExpressions(const AnyConstructor& ctor,
int columns, int rows) {
+ SkASSERT(rows <= 4);
+ SkASSERT(columns <= 4);
+
+ std::string matrixType = this->typeName(ctor.type().componentType());
size_t argIndex = 0;
int argPosition = 0;
auto args = ctor.argumentSpan();
- const char* rowSeparator = "";
+ const char* separator = "";
for (int r = 0; r < rows; ++r) {
- fExtraFunctions.printf("%sfloat%d(", rowSeparator, rows);
- rowSeparator = "), ";
+ fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
+ separator = "), ";
const char* columnSeparator = "";
for (int c = 0; c < columns; ++c) {
@@ -924,14 +930,15 @@
// constructor for any given permutation of input argument types. Returns the name of the
// generated constructor method.
String MetalCodeGenerator::getMatrixConstructHelper(const AnyConstructor& c) {
- const Type& matrix = c.type();
- int columns = matrix.columns();
- int rows = matrix.rows();
+ const Type& type = c.type();
+ int columns = type.columns();
+ int rows = type.rows();
auto args = c.argumentSpan();
+ String typeName = this->typeName(type);
// Create the helper-method name and use it as our lookup key.
String name;
- name.appendf("float%dx%d_from", columns, rows);
+ name.appendf("%s_from", typeName.c_str());
for (const std::unique_ptr<Expression>& expr : args) {
name.appendf("_%s", this->typeName(expr->type()).c_str());
}
@@ -945,7 +952,7 @@
// Unlike GLSL, Metal requires that matrices are initialized with exactly R vectors of C
// components apiece. (In Metal 2.0, you can also supply R*C scalars, but you still cannot
// supply a mixture of scalars and vectors.)
- fExtraFunctions.printf("float%dx%d %s(", columns, rows, name.c_str());
+ fExtraFunctions.printf("%s %s(", typeName.c_str(), name.c_str());
size_t argIndex = 0;
const char* argSeparator = "";
@@ -955,7 +962,7 @@
argSeparator = ", ";
}
- fExtraFunctions.printf(") {\n return float%dx%d(", columns, rows);
+ fExtraFunctions.printf(") {\n return %s(", typeName.c_str());
if (args.size() == 1 && args.front()->type().isMatrix()) {
this->assembleMatrixFromMatrix(args.front()->type(), rows, columns);
@@ -1042,18 +1049,24 @@
}
}
-void MetalCodeGenerator::writeVectorFromMat2x2ConstructorHelper() {
- static constexpr char kCode[] =
-R"(float4 float4_from_float2x2(float2x2 x) {
- return float4(x[0].xy, x[1].xy);
-}
-)";
+String MetalCodeGenerator::getVectorFromMat2x2ConstructorHelper(const Type& matrixType) {
+ SkASSERT(matrixType.isMatrix());
+ SkASSERT(matrixType.rows() == 2);
+ SkASSERT(matrixType.columns() == 2);
- String name = "matrixCompMult";
- if (fHelpers.find("float4_from_float2x2") == fHelpers.end()) {
- fHelpers.insert("float4_from_float2x2");
- fExtraFunctions.writeText(kCode);
+ String baseType = this->typeName(matrixType.componentType());
+ String name = String::printf("%s4_from_%s2x2", baseType.c_str(), baseType.c_str());
+ if (fHelpers.find(name) == fHelpers.end()) {
+ fHelpers.insert(name);
+
+ fExtraFunctions.printf(R"(
+%s4 %s(%s2x2 x) {
+ return %s4(x[0].xy, x[1].xy);
+}
+)", baseType.c_str(), name.c_str(), baseType.c_str(), baseType.c_str());
}
+
+ return name;
}
void MetalCodeGenerator::writeConstructorCompoundVector(const ConstructorCompound& c,
@@ -1065,10 +1078,8 @@
if (c.type().columns() == 4 && c.argumentSpan().size() == 1) {
const Expression& expr = *c.argumentSpan().front();
if (expr.type().isMatrix()) {
- SkASSERT(expr.type().rows() == 2);
- SkASSERT(expr.type().columns() == 2);
- this->writeVectorFromMat2x2ConstructorHelper();
- this->write("float4_from_float2x2(");
+ this->write(this->getVectorFromMat2x2ConstructorHelper(expr.type()));
+ this->write("(");
this->writeExpression(expr, Precedence::kSequence);
this->write(")");
return;
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.h b/src/sksl/codegen/SkSLMetalCodeGenerator.h
index 088f642..6a459d5 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.h
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.h
@@ -178,7 +178,7 @@
void writeMatrixEqualityHelpers(const Type& left, const Type& right);
- void writeVectorFromMat2x2ConstructorHelper();
+ String getVectorFromMat2x2ConstructorHelper(const Type& matrixType);
void writeArrayEqualityHelpers(const Type& type);