Implement operator== and != for Metal structs and arrays.
GLSL/SkSL allows == and != to be applied to struct and array types, but
Metal does not provide these operators by default. We need to emit
equality and inequality operators whenever we find code that compares a
struct or array.
Structs and arrays can be arbitrarily nested, and either type can
contain a matrix. Structs, arrays, and matrices all need custom equality
operators in Metal. Therefore, we need to recursively generate comparison
operators whenever any of these types is encountered.
For arrays we get lucky: a single pair of templated operator== and
operator!= functions covers all possible array types and sizes. Structs
and matrices have no such luck; their operators are generated separately
on a per-type basis.
For each of these types, operator== is implemented as an equality check
on each struct field, array element, or matrix column, and operator!= is
implemented as the negation of operator== (i.e. `!(left == right)`).
Equality and inequality are always emitted together. (Previously, matrix
equality and inequality were emitted and implemented independently, but
this is no longer the case.)
Change-Id: I69ee01c0a390d7db6bcb2253ed6336ab20cc4d1d
Bug: skia:11908, skia:11924
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/402016
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index 283e2a3..31e55d3 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -1251,7 +1251,7 @@
void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
const Type& result) {
- String key = "TimesEqual" + this->typeName(left) + ":" + this->typeName(right);
+ String key = "TimesEqual " + this->typeName(left) + ":" + this->typeName(right);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
@@ -1264,47 +1264,114 @@
}
}
-void MetalCodeGenerator::writeMatrixEqualityHelper(const Type& left, const Type& right) {
- SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
- left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
+ SkASSERT(left.isMatrix());
+ SkASSERT(right.isMatrix());
+ SkASSERT(left.rows() == right.rows());
+ SkASSERT(left.columns() == right.columns());
- String key = "Equality" + this->typeName(left) + ":" + this->typeName(right);
+ String key = "MatrixEquality " + this->typeName(left) + ":" + this->typeName(right);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
fExtraFunctions.printf(
"thread bool operator==(const %s left, const %s right) {\n"
- " return",
+ " return ",
this->typeName(left).c_str(), this->typeName(right).c_str());
+ const char* separator = "";
for (int index=0; index<left.columns(); ++index) {
- fExtraFunctions.printf("%s all(left[%d] == right[%d])",
- index == 0 ? "" : " &&", index, index);
+ fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
+ separator = " &&\n ";
}
- fExtraFunctions.printf(";\n"
- "}\n");
+
+ fExtraFunctions.printf(
+ ";\n"
+ "}\n"
+ "thread bool operator!=(const %s left, const %s right) {\n"
+ " return !(left == right);\n"
+ "}\n",
+ this->typeName(left).c_str(), this->typeName(right).c_str());
}
}
-void MetalCodeGenerator::writeMatrixInequalityHelper(const Type& left, const Type& right) {
- SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
- left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
+ SkASSERT(type.isArray());
- String key = "Inequality" + this->typeName(left) + ":" + this->typeName(right);
+ // If the array's component type needs a helper as well, we need to emit that one first.
+ this->writeEqualityHelpers(type.componentType(), type.componentType());
+
+ auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
+ if (wasInserted) {
+ fExtraFunctions.writeText(R"(
+template <typename T, size_t N>
+bool operator==(thread const array<T, N>& left, thread const array<T, N>& right) {
+ for (size_t index = 0; index < N; ++index) {
+ if (!(left[index] == right[index])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+template <typename T, size_t N>
+bool operator!=(thread const array<T, N>& left, thread const array<T, N>& right) {
+ return !(left == right);
+}
+)");
+ }
+}
+
+void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
+ SkASSERT(type.isStruct());
+ String key = "StructEquality " + this->typeName(type);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
- fExtraFunctions.printf(
- "thread bool operator!=(const %s left, const %s right) {\n"
- " return",
- this->typeName(left).c_str(), this->typeName(right).c_str());
-
- for (int index=0; index<left.columns(); ++index) {
- fExtraFunctions.printf("%s any(left[%d] != right[%d])",
- index == 0 ? "" : " ||", index, index);
+ // If one of the struct's fields needs a helper as well, we need to emit that one first.
+ for (const Type::Field& field : type.fields()) {
+ this->writeEqualityHelpers(*field.fType, *field.fType);
}
- fExtraFunctions.printf(";\n"
- "}\n");
+
+ // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
+ // and GLSL but do not exist by default in Metal.
+ fExtraFunctions.printf(
+ "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
+ " return ",
+ this->typeName(type).c_str(),
+ this->typeName(type).c_str());
+
+ const char* separator = "";
+ for (const Type::Field& field : type.fields()) {
+ fExtraFunctions.printf("%s(left.%.*s == right.%.*s)",
+ separator,
+ (int)field.fName.size(), field.fName.data(),
+ (int)field.fName.size(), field.fName.data());
+ separator = " &&\n ";
+ }
+ fExtraFunctions.printf(
+ ";\n"
+ "}\n"
+ "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
+ " return !(left == right);\n"
+ "}\n",
+ this->typeName(type).c_str(),
+ this->typeName(type).c_str());
+ }
+}
+
+void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
+ if (leftType.isArray() && rightType.isArray()) {
+ this->writeArrayEqualityHelpers(leftType);
+ return;
+ }
+ if (leftType.isStruct() && rightType.isStruct()) {
+ this->writeStructEqualityHelpers(leftType);
+ return;
+ }
+ if (leftType.isMatrix() && rightType.isMatrix()) {
+ this->writeMatrixEqualityHelpers(leftType, rightType);
+ return;
}
}
@@ -1319,12 +1386,14 @@
bool needParens = precedence >= parentPrecedence;
switch (op.kind()) {
case Token::Kind::TK_EQEQ:
+ this->writeEqualityHelpers(leftType, rightType);
if (leftType.isVector()) {
this->write("all");
needParens = true;
}
break;
case Token::Kind::TK_NEQ:
+ this->writeEqualityHelpers(leftType, rightType);
if (leftType.isVector()) {
this->write("any");
needParens = true;
@@ -1336,14 +1405,8 @@
if (needParens) {
this->write("(");
}
- if (leftType.isMatrix() && rightType.isMatrix()) {
- if (op.kind() == Token::Kind::TK_STAREQ) {
- this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
- } else if (op.kind() == Token::Kind::TK_EQEQ) {
- this->writeMatrixEqualityHelper(leftType, rightType);
- } else if (op.kind() == Token::Kind::TK_NEQ) {
- this->writeMatrixInequalityHelper(leftType, rightType);
- }
+ if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
+ this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
}
this->writeExpression(left, precedence);
if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
@@ -1720,7 +1783,7 @@
const InterfaceBlock* parentIntf) {
MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
int currentOffset = 0;
- for (const auto& field: fields) {
+ for (const Type::Field& field : fields) {
int fieldOffset = field.fModifiers.fLayout.fOffset;
const Type* fieldType = field.fType;
if (!MemoryLayout::LayoutIsSupported(*fieldType)) {