Implement matrix casting in Metal.

Unlike GLSL/SkSL, Metal does not natively support casting a matrix from
one size to another; we need to synthesize a helper function which will
assemble a new matrix from the values in the old matrix, filling any
remaining cells from the identity matrix.

Previously, our matrix-conversion helpers understood how to combine
an arbitrary collection of scalars/vectors/matrices into a matrix
containing the same number of scalars, but they would fail when given
a single matrix of a different size.

Change-Id: I35eb161ed7c17b982b00ecceb7b525cbfb8f3bcb
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/308190
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index b090962..7aadc1a 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -388,6 +388,105 @@
     }
 }
 
+// Assembles a matrix of type floatCxR (C columns, R rows) by resizing another matrix named `x0`.
+// Column c keeps the overlapping cells of x0[c]; all other cells take identity-matrix values.
+void MetalCodeGenerator::assembleMatrixFromMatrix(const Type& sourceMatrix, int rows, int columns) {
+    SkASSERT(rows <= 4);     // The swizzle switch below handles at most four components.
+    SkASSERT(columns <= 4);
+
+    const char* columnSeparator = "";
+    for (int c = 0; c < columns; ++c) {
+        fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
+        columnSeparator = "), ";  // Later columns first close the previous float%d(.
+
+        // Determine how many values to take from x0[c], the matching source column (if any).
+        int swizzleLength = 0;
+        if (c < sourceMatrix.columns()) {
+            swizzleLength = std::min<>(rows, sourceMatrix.rows());
+        }
+
+        // Emit the overlapping values from the source column as one swizzle of x0[c].
+        bool firstItem;  // True until a value has been written into this column.
+        switch (swizzleLength) {
+            case 0:  firstItem = true;                                            break;
+            case 1:  firstItem = false; fExtraFunctions.printf("x0[%d].x", c);    break;
+            case 2:  firstItem = false; fExtraFunctions.printf("x0[%d].xy", c);   break;
+            case 3:  firstItem = false; fExtraFunctions.printf("x0[%d].xyz", c);  break;
+            case 4:  firstItem = false; fExtraFunctions.printf("x0[%d].xyzw", c); break;
+            default: SkUNREACHABLE;
+        }
+
+        // Pad the rest of the column with identity cells: 1.0 on the diagonal, 0.0 elsewhere.
+        for (int r = swizzleLength; r < rows; ++r) {
+            fExtraFunctions.printf("%s%s", firstItem ? "" : ", ", (r == c) ? "1.0" : "0.0");
+            firstItem = false;
+        }
+    }
+
+    fExtraFunctions.writeText(")");  // Closes the last column's float%d( constructor.
+}
+
+// Assembles a matrix of type floatCxR (C columns, R rows) by concatenating the components
+// of args `x0`, `x1`, etc. in order. Writes `<error>` if they don't hold exactly R*C scalars.
+void MetalCodeGenerator::assembleMatrixFromExpressions(
+        const std::vector<std::unique_ptr<Expression>>& args, int rows, int columns) {
+    size_t argIndex = 0;    // Index of the argument currently being consumed.
+    int argPosition = 0;    // Scalar offset within that argument.
+
+    const char* columnSeparator = "";
+    for (int c = 0; c < columns; ++c) {
+        fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
+        columnSeparator = "), ";  // Later columns first close the previous float%d(.
+
+        const char* rowSeparator = "";
+        for (int r = 0; r < rows; ++r) {
+            fExtraFunctions.writeText(rowSeparator);
+            rowSeparator = ", ";
+
+            if (argIndex < args.size()) {
+                const Type& argType = args[argIndex]->fType;
+                switch (argType.kind()) {
+                    case Type::kScalar_Kind: {
+                        fExtraFunctions.printf("x%zu", argIndex);
+                        break;
+                    }
+                    case Type::kVector_Kind: {
+                        fExtraFunctions.printf("x%zu[%d]", argIndex, argPosition);
+                        break;
+                    }
+                    case Type::kMatrix_Kind: {
+                        fExtraFunctions.printf("x%zu[%d][%d]", argIndex,
+                                               argPosition / argType.rows(),  // column
+                                               argPosition % argType.rows());  // row
+                        break;
+                    }
+                    default: {
+                        SkDEBUGFAIL("incorrect type of argument for matrix constructor");
+                        fExtraFunctions.writeText("<error>");
+                        break;
+                    }
+                }
+
+                ++argPosition;
+                if (argPosition >= argType.columns() * argType.rows()) {  // Argument used up.
+                    ++argIndex;
+                    argPosition = 0;
+                }
+            } else {  // Ran out of arguments before filling R*C cells.
+                SkDEBUGFAIL("not enough arguments for matrix constructor");
+                fExtraFunctions.writeText("<error>");
+            }
+        }
+    }
+
+    if (argPosition != 0 || argIndex != args.size()) {  // Leftover or partially consumed args.
+        SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
+        fExtraFunctions.writeText(", <error>");  // Deliberately invalid token in the output.
+    }
+
+    fExtraFunctions.writeText(")");  // Closes the last column's float%d( constructor.
+}
+
 // Generates a constructor for 'matrix' which reorganizes the input arguments into the proper shape.
 // Keeps track of previously generated constructors so that we won't generate more than one
 // constructor for any given permutation of input argument types. Returns the name of the
@@ -418,7 +517,7 @@
 
     size_t argIndex = 0;
     const char* argSeparator = "";
-    for (const std::unique_ptr<Expression>& expr : c.fArguments) {
+    for (const std::unique_ptr<Expression>& expr : args) {
         fExtraFunctions.printf("%s%s x%zu", argSeparator,
                                expr->fType.displayName().c_str(), argIndex++);
         argSeparator = ", ";
@@ -426,61 +525,13 @@
 
     fExtraFunctions.printf(") {\n    return float%dx%d(", columns, rows);
 
-    argIndex = 0;
-    int argPosition = 0;
-
-    const char* columnSeparator = "";
-    for (int c = 0; c < columns; ++c) {
-        fExtraFunctions.printf("%sfloat%d(", columnSeparator, rows);
-        columnSeparator = "), ";
-
-        const char* rowSeparator = "";
-        for (int r = 0; r < rows; ++r) {
-            fExtraFunctions.printf("%s", rowSeparator);
-            rowSeparator = ", ";
-
-            if (argIndex < args.size()) {
-                const Type& argType = args[argIndex]->fType;
-                switch (argType.kind()) {
-                    case Type::kScalar_Kind: {
-                        fExtraFunctions.printf("x%zu", argIndex);
-                        break;
-                    }
-                    case Type::kVector_Kind: {
-                        fExtraFunctions.printf("x%zu[%d]", argIndex, argPosition);
-                        break;
-                    }
-                    case Type::kMatrix_Kind: {
-                        fExtraFunctions.printf("x%zu[%d][%d]", argIndex,
-                                               argPosition / argType.rows(),
-                                               argPosition % argType.rows());
-                        break;
-                    }
-                    default: {
-                        SkDEBUGFAIL("incorrect type of argument for matrix constructor");
-                        fExtraFunctions.printf("<error>");
-                        break;
-                    }
-                }
-
-                ++argPosition;
-                if (argPosition >= argType.columns() * argType.rows()) {
-                    ++argIndex;
-                    argPosition = 0;
-                }
-            } else {
-                SkDEBUGFAIL("not enough arguments for matrix constructor");
-                fExtraFunctions.printf("<error>");
-            }
-        }
+    if (args.size() == 1 && args.front()->fType.kind() == Type::kMatrix_Kind) {
+        this->assembleMatrixFromMatrix(args.front()->fType, rows, columns);
+    } else {
+        this->assembleMatrixFromExpressions(args, rows, columns);
     }
 
-    if (argPosition != 0 || argIndex != args.size()) {
-        SkDEBUGFAIL("incorrect number of arguments for matrix constructor");
-        name = "<error>";
-    }
-
-    fExtraFunctions.printf("));\n}\n");
+    fExtraFunctions.writeText(");\n}\n");
     return name;
 }