SkSL Metal backend can now handle CCPR

Bug: skia:
Change-Id: I796a40db46174b405495af8234c5b8d7920a46d6
Reviewed-on: https://skia-review.googlesource.com/c/189985
Reviewed-by: Jim Van Verth <jvanverth@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 928fa23..da206b4 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -26,6 +26,8 @@
 #define SPECIAL(x) std::make_pair(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic)
     fIntrinsicMap[String("texture")]            = SPECIAL(Texture);
     fIntrinsicMap[String("mod")]                = SPECIAL(Mod);
+    fIntrinsicMap[String("equal")]              = METAL(Equal);
+    fIntrinsicMap[String("notEqual")]           = METAL(NotEqual);
     fIntrinsicMap[String("lessThan")]           = METAL(LessThan);
     fIntrinsicMap[String("lessThanEqual")]      = METAL(LessThanEqual);
     fIntrinsicMap[String("greaterThan")]        = METAL(GreaterThan);
@@ -172,6 +174,12 @@
         case kMetal_IntrinsicKind:
             this->writeExpression(*c.fArguments[0], kSequence_Precedence);
             switch ((MetalIntrinsic) intrinsicId) {
+                case kEqual_MetalIntrinsic:
+                    this->write(" == ");
+                    break;
+                case kNotEqual_MetalIntrinsic:
+                    this->write(" != ");
+                    break;
                 case kLessThan_MetalIntrinsic:
                     this->write(" < ");
                     break;
@@ -248,18 +256,90 @@
 }
 
+// Writes the name of a helper that inverts 'mat'. For 2x2, 3x3 and 4x4 float/half matrices the
+// helper, named "<type>_inverse", is emitted into fExtraFunctions the first time it is needed
+// (tracked in fWrittenIntrinsics). For any other type only the name is written.
 void MetalCodeGenerator::writeInverseHack(const Expression& mat) {
-    String name = "ERROR_MatrixInverseNotImplementedFor_" + mat.fType.name();
-    if (mat.fType == *fContext.fFloat2x2_Type) {
-        name = "_inverse2";
+    String typeName = mat.fType.name();
+    String name = typeName + "_inverse";
+    if (mat.fType == *fContext.fFloat2x2_Type || mat.fType == *fContext.fHalf2x2_Type) {
         if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
             fWrittenIntrinsics.insert(name);
             fExtraFunctions.writeText((
-                "float2x2 " + name + "(float2x2 m) {"
+                typeName + " " + name + "(" + typeName + " m) {"
-                "    return float2x2(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
+                // construct the adjugate with the helper's own type so that the half2x2 variant
+                // does not build a float2x2 and return it from a half2x2 function
+                "    return " + typeName +
+                "(m[1][1], -m[0][1], -m[1][0], m[0][0]) * (1/determinant(m));"
                 "}"
             ).c_str());
         }
     }
+    else if (mat.fType == *fContext.fFloat3x3_Type || mat.fType == *fContext.fHalf3x3_Type) {
+        if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
+            fWrittenIntrinsics.insert(name);
+            fExtraFunctions.writeText((
+                typeName + " " +  name + "(" + typeName + " m) {"
+                "    float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2];"
+                "    float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2];"
+                "    float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2];"
+                "    float b01 = a22 * a11 - a12 * a21;"
+                "    float b11 = -a22 * a10 + a12 * a20;"
+                "    float b21 = a21 * a10 - a11 * a20;"
+                "    float det = a00 * b01 + a01 * b11 + a02 * b21;"
+                "    return " + typeName +
+                "                   (b01, (-a22 * a01 + a02 * a21), (a12 * a01 - a02 * a11),"
+                "                    b11, (a22 * a00 - a02 * a20), (-a12 * a00 + a02 * a10),"
+                "                    b21, (-a21 * a00 + a01 * a20), (a11 * a00 - a01 * a10)) * "
+                "                   (1/det);"
+                "}"
+            ).c_str());
+        }
+    }
+    else if (mat.fType == *fContext.fFloat4x4_Type || mat.fType == *fContext.fHalf4x4_Type) {
+        if (fWrittenIntrinsics.find(name) == fWrittenIntrinsics.end()) {
+            fWrittenIntrinsics.insert(name);
+            fExtraFunctions.writeText((
+                typeName + " " +  name + "(" + typeName + " m) {"
+                "    float a00 = m[0][0], a01 = m[0][1], a02 = m[0][2], a03 = m[0][3];"
+                "    float a10 = m[1][0], a11 = m[1][1], a12 = m[1][2], a13 = m[1][3];"
+                "    float a20 = m[2][0], a21 = m[2][1], a22 = m[2][2], a23 = m[2][3];"
+                "    float a30 = m[3][0], a31 = m[3][1], a32 = m[3][2], a33 = m[3][3];"
+                "    float b00 = a00 * a11 - a01 * a10;"
+                "    float b01 = a00 * a12 - a02 * a10;"
+                "    float b02 = a00 * a13 - a03 * a10;"
+                "    float b03 = a01 * a12 - a02 * a11;"
+                "    float b04 = a01 * a13 - a03 * a11;"
+                "    float b05 = a02 * a13 - a03 * a12;"
+                "    float b06 = a20 * a31 - a21 * a30;"
+                "    float b07 = a20 * a32 - a22 * a30;"
+                "    float b08 = a20 * a33 - a23 * a30;"
+                "    float b09 = a21 * a32 - a22 * a31;"
+                "    float b10 = a21 * a33 - a23 * a31;"
+                "    float b11 = a22 * a33 - a23 * a32;"
+                "    float det = b00 * b11 - b01 * b10 + b02 * b09 + b03 * b08 - "
+                "                b04 * b07 + b05 * b06;"
+                // multiply by the reciprocal of the determinant, as the 2x2 and 3x3 helpers do,
+                // instead of emitting a matrix / scalar division into the Metal source
+                "    return " + typeName + "(a11 * b11 - a12 * b10 + a13 * b09,"
+                "                            a02 * b10 - a01 * b11 - a03 * b09,"
+                "                            a31 * b05 - a32 * b04 + a33 * b03,"
+                "                            a22 * b04 - a21 * b05 - a23 * b03,"
+                "                            a12 * b08 - a10 * b11 - a13 * b07,"
+                "                            a00 * b11 - a02 * b08 + a03 * b07,"
+                "                            a32 * b02 - a30 * b05 - a33 * b01,"
+                "                            a20 * b05 - a22 * b02 + a23 * b01,"
+                "                            a10 * b10 - a11 * b08 + a13 * b06,"
+                "                            a01 * b08 - a00 * b10 - a03 * b06,"
+                "                            a30 * b04 - a31 * b02 + a33 * b00,"
+                "                            a21 * b02 - a20 * b04 - a23 * b00,"
+                "                            a11 * b07 - a10 * b09 - a12 * b06,"
+                "                            a00 * b09 - a01 * b07 + a02 * b06,"
+                "                            a31 * b01 - a30 * b03 - a32 * b00,"
+                "                            a20 * b03 - a21 * b01 + a22 * b00) * (1/det);"
+                "}"
+            ).c_str());
+        }
+    }
     this->write(name);
 }
 
@@ -300,8 +372,8 @@
 // of type 'arg'.
 String MetalCodeGenerator::getMatrixConstructHelper(const Type& matrix, const Type& arg) {
     String key = matrix.name() + arg.name();
-    auto found = fMatrixConstructHelpers.find(key);
-    if (found != fMatrixConstructHelpers.end()) {
+    auto found = fHelpers.find(key);
+    if (found != fHelpers.end()) {
         return found->second;
     }
     String name;
@@ -331,8 +403,34 @@
             fExtraFunctions.writeText(")");
         }
         fExtraFunctions.writeText(");\n}\n");
-    }
-    else if (matrix.rows() == 2 && matrix.columns() == 2) {
+    } else if (arg.kind() == Type::kMatrix_Kind) {
+        // creating a matrix from another matrix; cells the source lacks come from the identity
+        int argColumns = arg.columns();
+        int argRows = arg.rows();
+        name = "float" + to_string(columns) + "x" + to_string(rows) + "_from_float" +
+               to_string(argColumns) + "x" + to_string(argRows);
+        fExtraFunctions.printf("float%dx%d %s(float%dx%d m) {\n",
+                               columns, rows, name.c_str(), argColumns, argRows);
+        fExtraFunctions.printf("    return float%dx%d(", columns, rows);
+        for (int i = 0; i < columns; ++i) {
+            if (i > 0) {
+                fExtraFunctions.writeText(", ");
+            }
+            fExtraFunctions.printf("float%d(", rows);
+            for (int j = 0; j < rows; ++j) {
+                if (j > 0) {
+                    fExtraFunctions.writeText(", ");
+                }
+                if (i < argColumns && j < argRows) {
+                    fExtraFunctions.printf("m[%d][%d]", i, j);
+                } else {
+                    fExtraFunctions.writeText(i == j ? "1" : "0");
+                }
+            }
+            fExtraFunctions.writeText(")");
+        }
+        fExtraFunctions.writeText(");\n}\n");
+    } else if (matrix.rows() == 2 && matrix.columns() == 2 && arg == *fContext.fFloat4_Type) {
         // float2x2(float4) doesn't work, need to split it into float2x2(float2, float2)
         name = "float2x2_from_float4";
         fExtraFunctions.printf(
@@ -341,12 +439,11 @@
             "}\n",
             name.c_str()
         );
-    }
-    else {
+    } else {
         SkASSERT(false);
         name = "<error>";
     }
-    fMatrixConstructHelpers[key] = name;
+    fHelpers[key] = name;
     return name;
 }
 
@@ -380,15 +477,14 @@
         for (const auto& arg : c.fArguments) {
             this->write(separator);
             separator = ", ";
-            if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
-                // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge
-                // to float2x2(float2, float2).
+            if (Type::kMatrix_Kind == c.fType.kind() && arg->fType.columns() != c.fType.rows()) {
+                // merge scalars and smaller vectors together
                 if (!scalarCount) {
                     this->writeType(c.fType.componentType());
                     this->write(to_string(c.fType.rows()));
                     this->write("(");
                 }
-                ++scalarCount;
+                scalarCount += arg->fType.columns();
             }
             this->writeExpression(*arg, kSequence_Precedence);
             if (scalarCount && scalarCount == c.fType.rows()) {
@@ -527,10 +623,46 @@
     }
 }
 
+// Emits an MSL "operator*=" overload for a matrix-by-matrix compound multiply into
+// fExtraFunctions. The overload is written once per (left, right) type pair, tracked in fHelpers
+// under the key "TimesEqual<left><right>".
+void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
+                                                     const Type& result) {
+    String key = "TimesEqual" + left.name() + right.name();
+    if (fHelpers.find(key) != fHelpers.end()) {
+        // already emitted for this operand pair; a second definition would not compile
+        return;
+    }
+    // record the emission before writing so later *= expressions reuse the overload
+    fHelpers[key] = "operator*=";
+    fExtraFunctions.printf("%s operator*=(thread %s& left, thread const %s& right) {\n"
+                           "    left = left * right;\n"
+                           "    return left;\n"
+                           "}\n", result.name().c_str(), left.name().c_str(),
+                                  right.name().c_str());
+}
+
 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
                                                Precedence parentPrecedence) {
     Precedence precedence = GetBinaryPrecedence(b.fOperator);
-    if (precedence >= parentPrecedence) {
+    bool needParens = precedence >= parentPrecedence;
+    switch (b.fOperator) {  // reduce per-component vector comparisons to a single bool
+        case Token::EQEQ:
+            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+                this->write("all");  // equal only when every component compares equal
+                needParens = true;
+            }
+            break;
+        case Token::NEQ:
+            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
+                this->write("any");  // unequal as soon as one component differs
+                needParens = true;
+            }
+            break;
+        default:
+            break;
+    }
+    if (needParens) {
         this->write("(");
     }
     if (Compiler::IsAssignment(b.fOperator) &&
@@ -541,6 +666,10 @@
         // dereference it here.
         this->write("*");
     }
+    if (b.fOperator == Token::STAREQ && b.fLeft->fType.kind() == Type::kMatrix_Kind &&
+        b.fRight->fType.kind() == Type::kMatrix_Kind) {
+        this->writeMatrixTimesEqualHelper(b.fLeft->fType, b.fRight->fType, b.fType);
+    }
     this->writeExpression(*b.fLeft, precedence);
     if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
         Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
@@ -561,7 +690,7 @@
         this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
     }
     this->writeExpression(*b.fRight, precedence);
-    if (precedence >= parentPrecedence) {
+    if (needParens) {
         this->write(")");
     }
 }