Create a diagonal-matrix Constructor class.
This constructor takes a single argument and splats it diagonally across
an otherwise-zero matrix. Such a matrix is also sometimes referred to as a
uniform-scale matrix.
Change-Id: I1ed8140f55f5cad4029015807b220d6475401daa
Bug: skia:11032
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/390716
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index e6f505a..47be736 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -41,6 +41,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLExternalFunctionCall.h"
#include "src/sksl/ir/SkSLExternalFunctionReference.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
@@ -99,8 +100,9 @@
} else if (lastArg->type() == *fContext.fTypes.fFloat3x3) {
// Determine the type of matrix for this call site
if (lastArg->isConstantOrUniform()) {
- if (lastArg->kind() == Expression::Kind::kVariableReference ||
- lastArg->kind() == Expression::Kind::kConstructor) {
+ if (lastArg->is<VariableReference>() ||
+ lastArg->is<Constructor>() ||
+ lastArg->is<ConstructorDiagonalMatrix>()) {
// FIXME if this is a constant, we should parse the float3x3 constructor
// and determine if the resulting matrix introduces perspective.
fUsage.merge(SampleUsage::UniformMatrix(lastArg->description()));
@@ -759,6 +761,11 @@
}
return true;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ const ConstructorDiagonalMatrix& leftCtor = left.as<ConstructorDiagonalMatrix>();
+ const ConstructorDiagonalMatrix& rightCtor = right.as<ConstructorDiagonalMatrix>();
+ return IsSameExpressionTree(*leftCtor.argument(), *rightCtor.argument());
+ }
case Expression::Kind::kFieldAccess:
return left.as<FieldAccess>().fieldIndex() == right.as<FieldAccess>().fieldIndex() &&
IsSameExpressionTree(*left.as<FieldAccess>().base(),
@@ -1012,6 +1019,7 @@
// ... expressions composed of both of the above
case Expression::Kind::kBinary:
case Expression::Kind::kConstructor:
+ case Expression::Kind::kConstructorDiagonalMatrix:
case Expression::Kind::kFieldAccess:
case Expression::Kind::kIndex:
case Expression::Kind::kPrefix:
@@ -1140,6 +1148,9 @@
}
return false;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ return this->visitExpressionPtr(e.template as<ConstructorDiagonalMatrix>().argument());
+ }
case Expression::Kind::kExternalFunctionCall: {
auto& c = e.template as<ExternalFunctionCall>();
for (auto& arg : c.arguments()) {
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index 88b9c0e..51d6b46 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -16,6 +16,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBreakStatement.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -283,6 +284,13 @@
}
break;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ const ConstructorDiagonalMatrix& c = e->as<ConstructorDiagonalMatrix>();
+ this->writeCommand(Rehydrator::kConstructorDiagonalMatrix_Command);
+ this->write(c.type());
+ this->write(c.argument().get());
+ break;
+ }
case Expression::Kind::kExternalFunctionCall:
case Expression::Kind::kExternalFunctionReference:
SkDEBUGFAIL("unimplemented--not expected to be used from within an include file");
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index 68102d2..69d29cd 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -199,6 +199,10 @@
case Expression::Kind::kConstructor:
this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
break;
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+ parentPrecedence);
+ break;
case Expression::Kind::kIntLiteral:
this->writeIntLiteral(expr.as<IntLiteral>());
break;
@@ -740,6 +744,14 @@
this->write(")");
}
+// Writes a diagonal-matrix constructor as `<matrix type>(<argument>)`.
+// `parentPrecedence` is not consulted; the argument is always written at kSequence precedence.
+void GLSLCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                       Precedence parentPrecedence) {
+    const Type& matrixType = c.type();
+    const Expression& scalar = *c.argument();
+
+    this->writeType(matrixType);
+    this->write("(");
+    this->writeExpression(scalar, Precedence::kSequence);
+    this->write(")");
+}
+
void GLSLCodeGenerator::writeFragCoord() {
if (!this->caps().canUseFragCoord()) {
if (!fSetupFragCoordWorkaround) {
diff --git a/src/sksl/SkSLGLSLCodeGenerator.h b/src/sksl/SkSLGLSLCodeGenerator.h
index 3c81c39..33638e1 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.h
+++ b/src/sksl/SkSLGLSLCodeGenerator.h
@@ -21,6 +21,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLExtension.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
@@ -136,6 +137,9 @@
void writeConstructor(const Constructor& c, Precedence parentPrecedence);
+ void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+ Precedence parentPrecedence);
+
virtual void writeFieldAccess(const FieldAccess& f);
virtual void writeSwizzle(const Swizzle& swizzle);
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 41c7972..7915776 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -17,6 +17,7 @@
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLBreakStatement.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -312,6 +313,12 @@
SkASSERT(inlinedCtor);
return inlinedCtor;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ const ConstructorDiagonalMatrix& ctor = expression.as<ConstructorDiagonalMatrix>();
+ return ConstructorDiagonalMatrix::Make(*fContext, offset,
+ *ctor.type().clone(symbolTableForExpression),
+ expr(ctor.argument()));
+ }
case Expression::Kind::kExternalFunctionCall: {
const ExternalFunctionCall& externalCall = expression.as<ExternalFunctionCall>();
return std::make_unique<ExternalFunctionCall>(offset, &externalCall.function(),
@@ -913,6 +920,11 @@
}
break;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ ConstructorDiagonalMatrix& ctorExpr = (*expr)->as<ConstructorDiagonalMatrix>();
+ this->visitExpression(&ctorExpr.argument());
+ break;
+ }
case Expression::Kind::kExternalFunctionCall: {
ExternalFunctionCall& funcCallExpr = (*expr)->as<ExternalFunctionCall>();
for (std::unique_ptr<Expression>& arg : funcCallExpr.arguments()) {
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 0fddebe..ee87d9b 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -171,6 +171,10 @@
case Expression::Kind::kConstructor:
this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
break;
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+ parentPrecedence);
+ break;
case Expression::Kind::kIntLiteral:
this->writeIntLiteral(expr.as<IntLiteral>());
break;
@@ -1104,6 +1108,14 @@
this->write(constructorType.isArray() ? "}" : ")");
}
+// Writes a diagonal-matrix constructor as `<matrix type>(<argument>)`.
+// `parentPrecedence` is not consulted; the argument is always written at kSequence precedence.
+void MetalCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                        Precedence parentPrecedence) {
+    // Opening half: the matrix type name followed by the parenthesis.
+    const Type& matrixType = c.type();
+    this->writeType(matrixType);
+    this->write("(");
+
+    // The lone argument, then the closing parenthesis.
+    const Expression& scalar = *c.argument();
+    this->writeExpression(scalar, Precedence::kSequence);
+    this->write(")");
+}
+
void MetalCodeGenerator::writeFragCoord() {
if (fRTHeightName.length()) {
this->write("float4(_fragCoord.x, ");
@@ -2234,6 +2246,9 @@
}
return result;
}
+ case Expression::Kind::kConstructorDiagonalMatrix: {
+ return this->requirements(e->as<ConstructorDiagonalMatrix>().argument().get());
+ }
case Expression::Kind::kFieldAccess: {
const FieldAccess& f = e->as<FieldAccess>();
if (FieldAccess::OwnerKind::kAnonymousInterfaceBlock == f.ownerKind()) {
diff --git a/src/sksl/SkSLMetalCodeGenerator.h b/src/sksl/SkSLMetalCodeGenerator.h
index 6b38c03..9bd500f 100644
--- a/src/sksl/SkSLMetalCodeGenerator.h
+++ b/src/sksl/SkSLMetalCodeGenerator.h
@@ -22,6 +22,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLExtension.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
@@ -229,6 +230,9 @@
void writeConstructor(const Constructor& c, Precedence parentPrecedence);
+ void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+ Precedence parentPrecedence);
+
void writeFieldAccess(const FieldAccess& f);
void writeSwizzle(const Swizzle& swizzle);
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 187ba20..c67d634 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -14,6 +14,7 @@
#include "src/sksl/SkSLStringStream.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLForStatement.h"
@@ -74,6 +75,8 @@
void writeExpression(const Expression& expr, Precedence parentPrecedence);
void writeFunctionCall(const FunctionCall& c);
void writeConstructor(const Constructor& c, Precedence parentPrecedence);
+ void writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+ Precedence parentPrecedence);
void writeFieldAccess(const FieldAccess& f);
void writeSwizzle(const Swizzle& swizzle);
void writeBinaryExpression(const BinaryExpression& b, Precedence parentPrecedence);
@@ -411,6 +414,10 @@
case Expression::Kind::kConstructor:
this->writeConstructor(expr.as<Constructor>(), parentPrecedence);
break;
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(),
+ parentPrecedence);
+ break;
case Expression::Kind::kFieldAccess:
this->writeFieldAccess(expr.as<FieldAccess>());
break;
@@ -455,6 +462,14 @@
this->write(")");
}
+// Writes a diagonal-matrix constructor as `<matrix type>(<argument>)`.
+// `parentPrecedence` is not consulted; the argument is always written at kSequence precedence.
+void PipelineStageCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                                Precedence parentPrecedence) {
+    const Type& resultType = c.type();
+    const Expression& diagonalValue = *c.argument();
+
+    this->writeType(resultType);
+    this->write("(");
+    this->writeExpression(diagonalValue, Precedence::kSequence);
+    this->write(")");
+}
+
void PipelineStageCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
this->writeExpression(*expr.base(), Precedence::kPostfix);
this->write("[");
diff --git a/src/sksl/SkSLRehydrator.cpp b/src/sksl/SkSLRehydrator.cpp
index e289792..a8cc648 100644
--- a/src/sksl/SkSLRehydrator.cpp
+++ b/src/sksl/SkSLRehydrator.cpp
@@ -15,6 +15,8 @@
#include "include/private/SkSLStatement.h"
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBreakStatement.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -451,6 +453,11 @@
SkASSERT(ctor);
return ctor;
}
+ case Rehydrator::kConstructorDiagonalMatrix_Command: {
+ const Type* type = this->type();
+ return ConstructorDiagonalMatrix::Make(fContext, /*offset=*/-1, *type,
+ this->expression());
+ }
case Rehydrator::kFieldAccess_Command: {
std::unique_ptr<Expression> base = this->expression();
int index = this->readU8();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index daab463..9bedcee 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -711,6 +711,8 @@
return this->writeBoolLiteral(expr.as<BoolLiteral>());
case Expression::Kind::kConstructor:
return this->writeConstructor(expr.as<Constructor>(), out);
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
case Expression::Kind::kIntLiteral:
return this->writeIntLiteral(expr.as<IntLiteral>());
case Expression::Kind::kFieldAccess:
@@ -1688,6 +1690,21 @@
}
}
+// Writes a diagonal-matrix constructor and returns the id assigned to the resulting matrix.
+// The constructor's type must be a matrix and its argument a scalar (both asserted).
+SpvId SPIRVCodeGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c,
+                                                         OutputStream& out) {
+    const Type& matrixType = c.type();
+    SkASSERT(matrixType.isMatrix());
+    SkASSERT(c.argument()->type().isScalar());
+
+    // The scalar is written first; only afterwards is an id reserved for the matrix itself.
+    const SpvId scalarId = this->writeExpression(*c.argument(), out);
+    const SpvId matrixId = this->nextId(&matrixType);
+
+    // Delegate the actual matrix assembly to the uniform-scale helper.
+    this->writeUniformScaleMatrix(matrixId, scalarId, matrixType, out);
+    return matrixId;
+}
+
static SpvStorageClass_ get_storage_class(const Variable& var,
SpvStorageClass_ fallbackStorageClass) {
const Modifiers& modifiers = var.modifiers();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 7b032c4..88fe274 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -22,6 +22,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
@@ -287,6 +288,8 @@
SpvId writeConstructor(const Constructor& c, OutputStream& out);
+ SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
+
SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
diff --git a/src/sksl/SkSLVMGenerator.cpp b/src/sksl/SkSLVMGenerator.cpp
index 250e123..3130ec7 100644
--- a/src/sksl/SkSLVMGenerator.cpp
+++ b/src/sksl/SkSLVMGenerator.cpp
@@ -18,6 +18,7 @@
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLBreakStatement.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
@@ -243,6 +244,7 @@
Value writeExpression(const Expression& expr);
Value writeBinaryExpression(const BinaryExpression& b);
Value writeConstructor(const Constructor& c);
+ Value writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
Value writeFunctionCall(const FunctionCall& c);
Value writeExternalFunctionCall(const ExternalFunctionCall& c);
Value writeFieldAccess(const FieldAccess& expr);
@@ -770,6 +772,26 @@
return {};
}
+// Builds the slots of a diagonal matrix from a single scalar argument: the scalar is stored in
+// every cell where column == row, and 0.0 in every other cell.
+//
+// The constructor's type must be a matrix whose component type equals the argument's type (both
+// asserted). Returns a Value holding rows * columns slots.
+Value SkVMGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c) {
+    const Type& dstType = c.type();
+    SkASSERT(dstType.isMatrix());
+    SkASSERT(c.argument()->type() == dstType.componentType());
+
+    Value src = this->writeExpression(*c.argument());
+    Value dst(dstType.rows() * dstType.columns());
+    size_t dstIndex = 0;
+
+    // Matrix-from-scalar builds a diagonal scale matrix. Slots are filled one column at a time
+    // (the row index varies fastest). The indices are named `col`/`row` so that they stay distinct
+    // from the parameter `c`.
+    for (int col = 0; col < dstType.columns(); ++col) {
+        for (int row = 0; row < dstType.rows(); ++row) {
+            dst[dstIndex++] = (col == row ? f32(src) : fBuilder->splat(0.0f));
+        }
+    }
+
+    // Every slot of the destination must have been written exactly once.
+    SkASSERT(dstIndex == dst.slots());
+    return dst;
+}
+
size_t SkVMGenerator::fieldSlotOffset(const FieldAccess& expr) {
size_t offset = 0;
for (int i = 0; i < expr.fieldIndex(); ++i) {
@@ -1422,6 +1444,8 @@
return fBuilder->splat(e.as<BoolLiteral>().value() ? ~0 : 0);
case Expression::Kind::kConstructor:
return this->writeConstructor(e.as<Constructor>());
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ return this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
case Expression::Kind::kFieldAccess:
return this->writeFieldAccess(e.as<FieldAccess>());
case Expression::Kind::kIndex:
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index 55e0218..e56a7cb 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -8,6 +8,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
#include "src/sksl/ir/SkSLPrefixExpression.h"
@@ -80,11 +81,19 @@
// A constructor containing a single scalar is a splat (for vectors) or diagonal matrix (for
// matrices). In either event, it's legal regardless of the scalar's type. Synthesize an
// explicit conversion to the proper type (this is a no-op if it's unnecessary).
- ExpressionArray castArgs;
- castArgs.push_back(Constructor::Convert(context, offset, type.componentType(),
- std::move(args)));
- SkASSERT(castArgs.front());
- return std::make_unique<Constructor>(offset, type, std::move(castArgs));
+ std::unique_ptr<Expression> typecast = Constructor::Convert(context, offset,
+ type.componentType(),
+ std::move(args));
+ SkASSERT(typecast);
+
+ if (type.isMatrix()) {
+ // Matrix-from-scalar creates a diagonal matrix.
+ return ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast));
+ }
+
+ ExpressionArray typecastArgs;
+ typecastArgs.push_back(std::move(typecast));
+ return std::make_unique<Constructor>(offset, type, std::move(typecastArgs));
}
int expected = type.rows() * type.columns();
@@ -249,6 +258,9 @@
}
Expression::ComparisonResult Constructor::compareConstant(const Expression& other) const {
+ if (other.is<ConstructorDiagonalMatrix>()) {
+ return other.compareConstant(*this);
+ }
if (!other.is<Constructor>()) {
return ComparisonResult::kUnknown;
}
@@ -404,11 +416,12 @@
return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
}
if (argType.isMatrix()) {
- SkASSERT(this->arguments()[0]->is<Constructor>());
+ SkASSERT(this->arguments()[0]->is<Constructor>() ||
+ this->arguments()[0]->is<ConstructorDiagonalMatrix>());
// single matrix argument. make sure we're within the argument's bounds.
if (col < argType.columns() && row < argType.rows()) {
// within bounds, defer to argument
- return this->arguments()[0]->as<Constructor>().getMatComponent(col, row);
+ return this->arguments()[0]->getMatComponent(col, row);
}
// out of bounds
return 0.0;
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
new file mode 100644
index 0000000..4645fc4
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.cpp
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLType.h"
+
+namespace SkSL {
+
+// Creates a ConstructorDiagonalMatrix node of matrix type `type` at `offset`, taking ownership of
+// `arg`. `type` must be a matrix and `arg` must already have the matrix's component type (both
+// asserted). `context` is not used by this function.
+std::unique_ptr<Expression> ConstructorDiagonalMatrix::Make(const Context& context,
+                                                            int offset,
+                                                            const Type& type,
+                                                            std::unique_ptr<Expression> arg) {
+    SkASSERT(type.isMatrix());
+    SkASSERT(arg->type() == type.componentType());
+
+    std::unique_ptr<Expression> diagonalMatrix =
+            std::make_unique<ConstructorDiagonalMatrix>(offset, type, std::move(arg));
+    return diagonalMatrix;
+}
+
+// Compares this constant diagonal matrix against `other`, one cell at a time.
+// Returns kNotEqual when `other` is neither a Constructor nor a ConstructorDiagonalMatrix, or when
+// any cell differs; returns kEqual otherwise. Both operands must have the same type (asserted).
+Expression::ComparisonResult ConstructorDiagonalMatrix::compareConstant(
+        const Expression& other) const {
+    SkASSERT(this->type() == other.type());
+
+    // We know that these are the only SkSL constant expressions which can hold a matrix.
+    const bool otherIsMatrixCtor = other.is<Constructor>() ||
+                                   other.is<ConstructorDiagonalMatrix>();
+    if (!otherIsMatrixCtor) {
+        return ComparisonResult::kNotEqual;
+    }
+
+    // `other` is not necessarily diagonal, so every cell is checked rather than just the diagonal.
+    const Type& matrixType = this->type();
+    for (int c = 0; c < matrixType.columns(); ++c) {
+        for (int r = 0; r < matrixType.rows(); ++r) {
+            const bool cellsMatch = (this->getMatComponent(c, r) == other.getMatComponent(c, r));
+            if (!cellsMatch) {
+                return ComparisonResult::kNotEqual;
+            }
+        }
+    }
+
+    return ComparisonResult::kEqual;
+}
+
+// Returns the constant value stored at (`col`, `row`): the argument's constant float on the
+// diagonal, zero everywhere else. The expression must be a compile-time constant and the indices
+// must lie inside the matrix (all asserted).
+SKSL_FLOAT ConstructorDiagonalMatrix::getMatComponent(int col, int row) const {
+    SkASSERT(this->isCompileTimeConstant());
+    SkASSERT(col >= 0);
+    SkASSERT(row >= 0);
+    SkASSERT(col < this->type().columns());
+    SkASSERT(row < this->type().rows());
+
+    // Our matrix is of the form:
+    // |x 0 0|
+    // |0 x 0|
+    // |0 0 x|
+    if (col != row) {
+        // Off-diagonal cell.
+        return 0.0;
+    }
+    return this->argument()->getConstantFloat();
+}
+
+} // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorDiagonalMatrix.h b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
new file mode 100644
index 0000000..ae76848
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorDiagonalMatrix.h
@@ -0,0 +1,82 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CONSTRUCTOR_DIAGONAL_MATRIX
+#define SKSL_CONSTRUCTOR_DIAGONAL_MATRIX
+
+#include "include/private/SkSLDefines.h"
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/ir/SkSLExpression.h"
+
+#include <memory>
+
+namespace SkSL {
+
+/**
+ * Expression node for a matrix constructed from a single scalar, such as `half3x3(n)`.
+ *
+ * The node owns exactly one argument expression. When queried as a constant matrix, it yields that
+ * argument's value in every cell where column == row and zero in every other cell.
+ */
+class ConstructorDiagonalMatrix final : public Expression {
+public:
+    static constexpr Kind kExpressionKind = Kind::kConstructorDiagonalMatrix;
+
+    ConstructorDiagonalMatrix(int offset, const Type& type, std::unique_ptr<Expression> arg)
+        : INHERITED(offset, kExpressionKind, &type)
+        , fArgument(std::move(arg)) {}
+
+    // Factory for these nodes; `type` must be a matrix and `arg` must have its component type.
+    static std::unique_ptr<Expression> Make(const Context& context,
+                                            int offset,
+                                            const Type& type,
+                                            std::unique_ptr<Expression> arg);
+
+    // Mutable access to the owned argument expression.
+    std::unique_ptr<Expression>& argument() {
+        return fArgument;
+    }
+
+    // Read-only access to the owned argument expression.
+    const std::unique_ptr<Expression>& argument() const {
+        return fArgument;
+    }
+
+    // The type of the argument expression.
+    const Type& componentType() const {
+        return fArgument->type();
+    }
+
+    // Property, constant and uniform queries are all answered by the argument.
+    bool hasProperty(Property property) const override {
+        return fArgument->hasProperty(property);
+    }
+
+    bool isCompileTimeConstant() const override {
+        return fArgument->isCompileTimeConstant();
+    }
+
+    bool isConstantOrUniform() const override {
+        return fArgument->isConstantOrUniform();
+    }
+
+    // Produces a new node with the same offset and type and a clone of the argument.
+    std::unique_ptr<Expression> clone() const override {
+        std::unique_ptr<Expression> argumentCopy = fArgument->clone();
+        return std::make_unique<ConstructorDiagonalMatrix>(fOffset, this->type(),
+                                                           std::move(argumentCopy));
+    }
+
+    // Renders the node as `<type>(<argument>)`.
+    String description() const override {
+        return this->type().description() + "(" + fArgument->description() + ")";
+    }
+
+    ComparisonResult compareConstant(const Expression& other) const override;
+
+    SKSL_FLOAT getMatComponent(int col, int row) const override;
+
+private:
+    std::unique_ptr<Expression> fArgument;
+
+    using INHERITED = Expression;
+};
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index cf9cdb0..af9c774 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -30,6 +30,7 @@
kBoolLiteral,
kCodeString,
kConstructor,
+ kConstructorDiagonalMatrix,
kDefined,
kExternalFunctionCall,
kExternalFunctionReference,
diff --git a/src/sksl/ir/SkSLPrefixExpression.cpp b/src/sksl/ir/SkSLPrefixExpression.cpp
index a06fccf..becba3c 100644
--- a/src/sksl/ir/SkSLPrefixExpression.cpp
+++ b/src/sksl/ir/SkSLPrefixExpression.cpp
@@ -10,6 +10,7 @@
#include "src/sksl/SkSLConstantFolder.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
@@ -41,6 +42,16 @@
}
break;
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ // Convert `-matrix(literal)` into `matrix(-literal)`.
+ if (context.fConfig->fSettings.fOptimize && value->isCompileTimeConstant()) {
+ ConstructorDiagonalMatrix& ctor = operand->as<ConstructorDiagonalMatrix>();
+ return ConstructorDiagonalMatrix::Make(
+ context, ctor.fOffset, ctor.type(),
+ negate_operand(context, std::move(ctor.argument())));
+ }
+ break;
+
case Expression::Kind::kConstructor:
// To be consistent with prior behavior, the conversion of a negated constructor into a
// constructor of negative values is only performed when optimization is on.
@@ -166,11 +177,10 @@
break;
default:
- SK_ABORT("unsupported prefix operator\n");
+ SK_ABORT("unsupported prefix operator");
}
return PrefixExpression::Make(context, op, std::move(base));
-
}
std::unique_ptr<Expression> PrefixExpression::Make(const Context& context, Operator op,