Use swizzles in Metal matrix helper functions.
Not necessarily a big deal, but I'm considering using this logic in GLSL
as well, and I'm less confident that your average GLSL ES driver will
optimize away the separate array loads.
Change-Id: I6a9f0d18c0fac138f64ad6426670f615e17f3492
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/449099
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index bb34f30..c0cd482 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -947,13 +947,14 @@
int argPosition = 0;
auto args = ctor.argumentSpan();
+ static constexpr char kSwizzle[] = "xyzw";
const char* separator = "";
for (int c = 0; c < columns; ++c) {
fExtraFunctions.printf("%s%s%d(", separator, matrixType.c_str(), rows);
separator = "), ";
const char* columnSeparator = "";
- for (int r = 0; r < rows; ++r) {
+ for (int r = 0; r < rows;) {
fExtraFunctions.writeText(columnSeparator);
columnSeparator = ", ";
@@ -962,16 +963,26 @@
switch (argType.typeKind()) {
case Type::TypeKind::kScalar: {
fExtraFunctions.printf("x%zu", argIndex);
+ ++r;
+ ++argPosition;
break;
}
case Type::TypeKind::kVector: {
- fExtraFunctions.printf("x%zu[%d]", argIndex, argPosition);
+ fExtraFunctions.printf("x%zu.", argIndex);
+ do {
+ fExtraFunctions.write8(kSwizzle[argPosition]);
+ ++r;
+ ++argPosition;
+ } while (r < rows && argPosition < argType.columns());
break;
}
case Type::TypeKind::kMatrix: {
- fExtraFunctions.printf("x%zu[%d][%d]", argIndex,
- argPosition / argType.rows(),
- argPosition % argType.rows());
+ fExtraFunctions.printf("x%zu[%d].", argIndex, argPosition / argType.rows());
+ do {
+ fExtraFunctions.write8(kSwizzle[argPosition]);
+ ++r;
+ ++argPosition;
+ } while (r < rows && (argPosition % argType.rows()) != 0);
break;
}
default: {
@@ -981,7 +992,6 @@
}
}
- ++argPosition;
if (argPosition >= argType.columns() * argType.rows()) {
++argIndex;
argPosition = 0;