Reland "Implement operator== and != for Metal structs and arrays."

This is a reland of 830c69ca66d339067cdc06775f59443a06fc15a2

Original change's description:
> Implement operator== and != for Metal structs and arrays.
>
> GLSL/SkSL assumes that == and != on struct/array types should work.
> We need to emit equality and inequality operators whenever we find code
> that compares a struct or array.
>
> Structs and arrays can be arbitrarily nested, and either type can
> contain a matrix. All of these things need custom equality operators in
> Metal. Therefore, we need to recursively generate comparison operators
> when any of these types are encountered.
>
> For arrays we get lucky, and we can cover all possible array types and
> sizes with a single templated operator== method. Structs and matrices
> have no such luck, and are generated separately on a per-type basis.
>
> For each of these types, operator== is implemented as an equality check
> on each field, and operator!= is implemented in terms of operator==.
> Equality and inequality are always emitted together. (Previously, matrix
> equality and inequality were emitted and implemented independently, but
> this is no longer the case.)
>
> Change-Id: I69ee01c0a390d7db6bcb2253ed6336ab20cc4d1d
> Bug: skia:11908, skia:11924
> Reviewed-on: https://skia-review.googlesource.com/c/skia/+/402016
> Auto-Submit: John Stiles <johnstiles@google.com>
> Commit-Queue: Brian Osman <brianosman@google.com>
> Reviewed-by: Brian Osman <brianosman@google.com>

Bug: skia:11908, skia:11924, skia:11929
Change-Id: I6336b6125e9774c1ca73e3d497e3466f11f6f25f
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/402559
Commit-Queue: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index 283e2a3..31e55d3 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -1251,7 +1251,7 @@
 
 void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
                                                      const Type& result) {
-    String key = "TimesEqual" + this->typeName(left) + ":" + this->typeName(right);
+    String key = "TimesEqual " + this->typeName(left) + ":" + this->typeName(right);
 
     auto [iter, wasInserted] = fHelpers.insert(key);
     if (wasInserted) {
@@ -1264,47 +1264,114 @@
     }
 }
 
-void MetalCodeGenerator::writeMatrixEqualityHelper(const Type& left, const Type& right) {
-    SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
-              left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
+    SkASSERT(left.isMatrix());
+    SkASSERT(right.isMatrix());
+    SkASSERT(left.rows() == right.rows());
+    SkASSERT(left.columns() == right.columns());
 
-    String key = "Equality" + this->typeName(left) + ":" + this->typeName(right);
+    String key = "MatrixEquality " + this->typeName(left) + ":" + this->typeName(right);
 
     auto [iter, wasInserted] = fHelpers.insert(key);
     if (wasInserted) {
         fExtraFunctions.printf(
                 "thread bool operator==(const %s left, const %s right) {\n"
-                "    return",
+                "    return ",
                 this->typeName(left).c_str(), this->typeName(right).c_str());
 
+        const char* separator = "";
         for (int index=0; index<left.columns(); ++index) {
-            fExtraFunctions.printf("%s all(left[%d] == right[%d])",
-                                   index == 0 ? "" : " &&", index, index);
+            fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
+            separator = " &&\n           ";
         }
-        fExtraFunctions.printf(";\n"
-                               "}\n");
+        // operator!= is written in terms of operator==, so the two can never disagree.
+        fExtraFunctions.printf(
+                ";\n"
+                "}\n"
+                "thread bool operator!=(const %s left, const %s right) {\n"
+                "    return !(left == right);\n"
+                "}\n",
+                this->typeName(left).c_str(), this->typeName(right).c_str());
     }
 }
 
-void MetalCodeGenerator::writeMatrixInequalityHelper(const Type& left, const Type& right) {
-    SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
-              left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
+    SkASSERT(type.isArray());
 
-    String key = "Inequality" + this->typeName(left) + ":" + this->typeName(right);
+    // If the array's component type needs a helper as well, we need to emit that one first.
+    this->writeEqualityHelpers(type.componentType(), type.componentType());
+    // One template covers every array type; all() reduces boolN element results to bool.
+    auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
+    if (wasInserted) {
+        fExtraFunctions.writeText(R"(
+template <typename T, size_t N>
+bool operator==(thread const array<T, N>& left, thread const array<T, N>& right) {
+    for (size_t index = 0; index < N; ++index) {
+        if (!all(left[index] == right[index])) {
+            return false;
+        }
+    }
+    return true;
+}
+
+template <typename T, size_t N>
+bool operator!=(thread const array<T, N>& left, thread const array<T, N>& right) {
+    return !(left == right);
+}
+)");
+    }
+}
+
+void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
+    SkASSERT(type.isStruct());
+    String key = "StructEquality " + this->typeName(type);
 
     auto [iter, wasInserted] = fHelpers.insert(key);
     if (wasInserted) {
-        fExtraFunctions.printf(
-                "thread bool operator!=(const %s left, const %s right) {\n"
-                "    return",
-                this->typeName(left).c_str(), this->typeName(right).c_str());
-
-        for (int index=0; index<left.columns(); ++index) {
-            fExtraFunctions.printf("%s any(left[%d] != right[%d])",
-                                   index == 0 ? "" : " ||", index, index);
+        // If one of the struct's fields needs a helper as well, we need to emit that one first.
+        for (const Type::Field& field : type.fields()) {
+            this->writeEqualityHelpers(*field.fType, *field.fType);
         }
-        fExtraFunctions.printf(";\n"
-                               "}\n");
+
+        // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
+        // and GLSL but do not exist by default in Metal.
+        fExtraFunctions.printf(
+                "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
+                "    return ",
+                this->typeName(type).c_str(),
+                this->typeName(type).c_str());
+        // all() collapses vector-field comparisons (which yield boolN) down to a single bool.
+        const char* separator = "";
+        for (const Type::Field& field : type.fields()) {
+            fExtraFunctions.printf("%sall(left.%.*s == right.%.*s)",
+                                   separator,
+                                   (int)field.fName.size(), field.fName.data(),
+                                   (int)field.fName.size(), field.fName.data());
+            separator = " &&\n           ";
+        }
+        fExtraFunctions.printf(
+                ";\n"
+                "}\n"
+                "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
+                "    return !(left == right);\n"
+                "}\n",
+                this->typeName(type).c_str(),
+                this->typeName(type).c_str());
+    }
+}
+
+void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
+    // Metal has no built-in == or != for arrays, structs or matrices, so emit helper operators
+    // for whichever of those kinds is being compared. Any other pairing emits nothing here.
+    if (leftType.isArray() && rightType.isArray()) {
+        // A single shared template covers every array type.
+        this->writeArrayEqualityHelpers(leftType);
+    } else if (leftType.isStruct() && rightType.isStruct()) {
+        // Struct helpers are emitted once per struct type name.
+        this->writeStructEqualityHelpers(leftType);
+    } else if (leftType.isMatrix() && rightType.isMatrix()) {
+        // Matrix helpers are emitted once per (left, right) type-name pair.
+        this->writeMatrixEqualityHelpers(leftType, rightType);
     }
 }
 
@@ -1319,12 +1386,14 @@
     bool needParens = precedence >= parentPrecedence;
     switch (op.kind()) {
         case Token::Kind::TK_EQEQ:
+            this->writeEqualityHelpers(leftType, rightType);
             if (leftType.isVector()) {
                 this->write("all");
                 needParens = true;
             }
             break;
         case Token::Kind::TK_NEQ:
+            this->writeEqualityHelpers(leftType, rightType);
             if (leftType.isVector()) {
                 this->write("any");
                 needParens = true;
@@ -1336,14 +1405,8 @@
     if (needParens) {
         this->write("(");
     }
-    if (leftType.isMatrix() && rightType.isMatrix()) {
-        if (op.kind() == Token::Kind::TK_STAREQ) {
-            this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
-        } else if (op.kind() == Token::Kind::TK_EQEQ) {
-            this->writeMatrixEqualityHelper(leftType, rightType);
-        } else if (op.kind() == Token::Kind::TK_NEQ) {
-            this->writeMatrixInequalityHelper(leftType, rightType);
-        }
+    if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
+        this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
     }
     this->writeExpression(left, precedence);
     if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
@@ -1720,7 +1783,7 @@
                                      const InterfaceBlock* parentIntf) {
     MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
     int currentOffset = 0;
-    for (const auto& field: fields) {
+    for (const Type::Field& field : fields) {
         int fieldOffset = field.fModifiers.fLayout.fOffset;
         const Type* fieldType = field.fType;
         if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.h b/src/sksl/codegen/SkSLMetalCodeGenerator.h
index 8f9d524..96b0ca9 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.h
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.h
@@ -217,9 +217,13 @@
 
     void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
 
-    void writeMatrixEqualityHelper(const Type& left, const Type& right);
+    void writeMatrixEqualityHelpers(const Type& left, const Type& right);
 
-    void writeMatrixInequalityHelper(const Type& left, const Type& right);
+    void writeArrayEqualityHelpers(const Type& type);
+
+    void writeStructEqualityHelpers(const Type& type);
+
+    void writeEqualityHelpers(const Type& leftType, const Type& rightType);
 
     void writeArgumentList(const ExpressionArray& arguments);