Reland "Implement operator== and != for Metal structs and arrays."
This is a reland of 830c69ca66d339067cdc06775f59443a06fc15a2
Original change's description:
> Implement operator== and != for Metal structs and arrays.
>
> GLSL and SkSL assume that == and != work on struct and array types, but
> Metal does not provide these operators by default. We need to emit
> equality and inequality operators whenever we find code that compares a
> struct or array.
>
> Structs and arrays can be arbitrarily nested, and either type can
> contain a matrix. All of these things need custom equality operators in
> Metal. Therefore, we need to recursively generate comparison operators
> when any of these types are encountered.
>
> For arrays we get lucky, and we can cover all possible array types and
> sizes with a single templated operator== function. Structs and matrices
> have no such luck; their operators are generated separately on a
> per-type basis.
>
> For each of these types, operator== is implemented as an equality check
> on each field, and operator!= is implemented in terms of operator==.
> Equality and inequality are always emitted together. (Previously, matrix
> equality and inequality were emitted and implemented independently, but
> this is no longer the case.)
>
> Change-Id: I69ee01c0a390d7db6bcb2253ed6336ab20cc4d1d
> Bug: skia:11908, skia:11924
> Reviewed-on: https://skia-review.googlesource.com/c/skia/+/402016
> Auto-Submit: John Stiles <johnstiles@google.com>
> Commit-Queue: Brian Osman <brianosman@google.com>
> Reviewed-by: Brian Osman <brianosman@google.com>
Bug: skia:11908, skia:11924, skia:11929
Change-Id: I6336b6125e9774c1ca73e3d497e3466f11f6f25f
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/402559
Commit-Queue: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
index 283e2a3..31e55d3 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.cpp
@@ -1251,7 +1251,7 @@
void MetalCodeGenerator::writeMatrixTimesEqualHelper(const Type& left, const Type& right,
const Type& result) {
- String key = "TimesEqual" + this->typeName(left) + ":" + this->typeName(right);
+ String key = "TimesEqual " + this->typeName(left) + ":" + this->typeName(right);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
@@ -1264,47 +1264,114 @@
}
}
-void MetalCodeGenerator::writeMatrixEqualityHelper(const Type& left, const Type& right) {
- SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
- left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeMatrixEqualityHelpers(const Type& left, const Type& right) {
+ SkASSERT(left.isMatrix());
+ SkASSERT(right.isMatrix());
+ SkASSERT(left.rows() == right.rows());
+ SkASSERT(left.columns() == right.columns());
- String key = "Equality" + this->typeName(left) + ":" + this->typeName(right);
+ String key = "MatrixEquality " + this->typeName(left) + ":" + this->typeName(right);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
fExtraFunctions.printf(
"thread bool operator==(const %s left, const %s right) {\n"
- " return",
+ " return ",
this->typeName(left).c_str(), this->typeName(right).c_str());
+ const char* separator = "";
for (int index=0; index<left.columns(); ++index) {
- fExtraFunctions.printf("%s all(left[%d] == right[%d])",
- index == 0 ? "" : " &&", index, index);
+ fExtraFunctions.printf("%sall(left[%d] == right[%d])", separator, index, index);
+ separator = " &&\n ";
}
- fExtraFunctions.printf(";\n"
- "}\n");
+
+ fExtraFunctions.printf(
+ ";\n"
+ "}\n"
+ "thread bool operator!=(const %s left, const %s right) {\n"
+ " return !(left == right);\n"
+ "}\n",
+ this->typeName(left).c_str(), this->typeName(right).c_str());
}
}
-void MetalCodeGenerator::writeMatrixInequalityHelper(const Type& left, const Type& right) {
- SkASSERTF(left.rows() == right.rows() && left.columns() == right.columns(), "left=%s, right=%s",
- left.description().c_str(), right.description().c_str());
+void MetalCodeGenerator::writeArrayEqualityHelpers(const Type& type) {
+ SkASSERT(type.isArray());
- String key = "Inequality" + this->typeName(left) + ":" + this->typeName(right);
+ // If the array's component type needs a helper as well, we need to emit that one first.
+ this->writeEqualityHelpers(type.componentType(), type.componentType());
+
+ auto [iter, wasInserted] = fHelpers.insert("ArrayEquality []");
+ if (wasInserted) {
+ fExtraFunctions.writeText(R"(
+template <typename T, size_t N>
+bool operator==(thread const array<T, N>& left, thread const array<T, N>& right) {
+ for (size_t index = 0; index < N; ++index) {
+ if (!(left[index] == right[index])) {
+ return false;
+ }
+ }
+ return true;
+}
+
+template <typename T, size_t N>
+bool operator!=(thread const array<T, N>& left, thread const array<T, N>& right) {
+ return !(left == right);
+}
+)");
+ }
+}
+
+void MetalCodeGenerator::writeStructEqualityHelpers(const Type& type) {
+ SkASSERT(type.isStruct());
+ String key = "StructEquality " + this->typeName(type);
auto [iter, wasInserted] = fHelpers.insert(key);
if (wasInserted) {
- fExtraFunctions.printf(
- "thread bool operator!=(const %s left, const %s right) {\n"
- " return",
- this->typeName(left).c_str(), this->typeName(right).c_str());
-
- for (int index=0; index<left.columns(); ++index) {
- fExtraFunctions.printf("%s any(left[%d] != right[%d])",
- index == 0 ? "" : " ||", index, index);
+ // If one of the struct's fields needs a helper as well, we need to emit that one first.
+ for (const Type::Field& field : type.fields()) {
+ this->writeEqualityHelpers(*field.fType, *field.fType);
}
- fExtraFunctions.printf(";\n"
- "}\n");
+
+ // Write operator== and operator!= for this struct, since those are assumed to exist in SkSL
+ // and GLSL but do not exist by default in Metal.
+ fExtraFunctions.printf(
+ "thread bool operator==(thread const %s& left, thread const %s& right) {\n"
+ " return ",
+ this->typeName(type).c_str(),
+ this->typeName(type).c_str());
+
+ const char* separator = "";
+ for (const Type::Field& field : type.fields()) {
+ fExtraFunctions.printf("%s(left.%.*s == right.%.*s)",
+ separator,
+ (int)field.fName.size(), field.fName.data(),
+ (int)field.fName.size(), field.fName.data());
+ separator = " &&\n ";
+ }
+ fExtraFunctions.printf(
+ ";\n"
+ "}\n"
+ "thread bool operator!=(thread const %s& left, thread const %s& right) {\n"
+ " return !(left == right);\n"
+ "}\n",
+ this->typeName(type).c_str(),
+ this->typeName(type).c_str());
+ }
+}
+
+void MetalCodeGenerator::writeEqualityHelpers(const Type& leftType, const Type& rightType) {
+ if (leftType.isArray() && rightType.isArray()) {
+ this->writeArrayEqualityHelpers(leftType);
+ return;
+ }
+ if (leftType.isStruct() && rightType.isStruct()) {
+ this->writeStructEqualityHelpers(leftType);
+ return;
+ }
+ if (leftType.isMatrix() && rightType.isMatrix()) {
+ this->writeMatrixEqualityHelpers(leftType, rightType);
+ return;
}
}
@@ -1319,12 +1386,14 @@
bool needParens = precedence >= parentPrecedence;
switch (op.kind()) {
case Token::Kind::TK_EQEQ:
+ this->writeEqualityHelpers(leftType, rightType);
if (leftType.isVector()) {
this->write("all");
needParens = true;
}
break;
case Token::Kind::TK_NEQ:
+ this->writeEqualityHelpers(leftType, rightType);
if (leftType.isVector()) {
this->write("any");
needParens = true;
@@ -1336,14 +1405,8 @@
if (needParens) {
this->write("(");
}
- if (leftType.isMatrix() && rightType.isMatrix()) {
- if (op.kind() == Token::Kind::TK_STAREQ) {
- this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
- } else if (op.kind() == Token::Kind::TK_EQEQ) {
- this->writeMatrixEqualityHelper(leftType, rightType);
- } else if (op.kind() == Token::Kind::TK_NEQ) {
- this->writeMatrixInequalityHelper(leftType, rightType);
- }
+ if (leftType.isMatrix() && rightType.isMatrix() && op.kind() == Token::Kind::TK_STAREQ) {
+ this->writeMatrixTimesEqualHelper(leftType, rightType, b.type());
}
this->writeExpression(left, precedence);
if (op.kind() != Token::Kind::TK_EQ && op.isAssignment() &&
@@ -1720,7 +1783,7 @@
const InterfaceBlock* parentIntf) {
MemoryLayout memoryLayout(MemoryLayout::kMetal_Standard);
int currentOffset = 0;
- for (const auto& field: fields) {
+ for (const Type::Field& field : fields) {
int fieldOffset = field.fModifiers.fLayout.fOffset;
const Type* fieldType = field.fType;
if (!MemoryLayout::LayoutIsSupported(*fieldType)) {
diff --git a/src/sksl/codegen/SkSLMetalCodeGenerator.h b/src/sksl/codegen/SkSLMetalCodeGenerator.h
index 8f9d524..96b0ca9 100644
--- a/src/sksl/codegen/SkSLMetalCodeGenerator.h
+++ b/src/sksl/codegen/SkSLMetalCodeGenerator.h
@@ -217,9 +217,13 @@
void writeMatrixTimesEqualHelper(const Type& left, const Type& right, const Type& result);
- void writeMatrixEqualityHelper(const Type& left, const Type& right);
+ void writeMatrixEqualityHelpers(const Type& left, const Type& right);
- void writeMatrixInequalityHelper(const Type& left, const Type& right);
+ void writeArrayEqualityHelpers(const Type& type);
+
+ void writeStructEqualityHelpers(const Type& type);
+
+ void writeEqualityHelpers(const Type& leftType, const Type& rightType);
void writeArgumentList(const ExpressionArray& arguments);