Represent splat constructors with a dedicated ConstructorSplat class.
Change-Id: Ic9c3d688b571591d057ab6a4e998f1f9712a1b58
Bug: skia:11032
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/392117
Commit-Queue: Brian Osman <brianosman@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index d624cdf..bb8c0e0 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -718,10 +718,10 @@
IsTrivialExpression(*expr.as<Swizzle>().base())) ||
(expr.is<FieldAccess>() &&
IsTrivialExpression(*expr.as<FieldAccess>().base())) ||
- (expr.is<Constructor>() &&
- expr.as<Constructor>().arguments().size() == 1 &&
- IsTrivialExpression(*expr.as<Constructor>().arguments().front())) ||
- (expr.is<Constructor>() &&
+ (expr.isAnyConstructor() &&
+ expr.asAnyConstructor().argumentSpan().size() == 1 &&
+ IsTrivialExpression(*expr.asAnyConstructor().argumentSpan().front())) ||
+ (expr.isAnyConstructor() &&
expr.isConstantOrUniform()) ||
(expr.is<IndexExpression>() &&
expr.as<IndexExpression>().index()->is<IntLiteral>() &&
@@ -749,7 +749,8 @@
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
- case Expression::Kind::kConstructorDiagonalMatrix: {
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat: {
const AnyConstructor& leftCtor = left.asAnyConstructor();
const AnyConstructor& rightCtor = right.asAnyConstructor();
const auto leftSpan = leftCtor.argumentSpan();
@@ -1019,6 +1020,7 @@
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat:
case Expression::Kind::kFieldAccess:
case Expression::Kind::kIndex:
case Expression::Kind::kPrefix:
@@ -1142,7 +1144,8 @@
}
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
- case Expression::Kind::kConstructorDiagonalMatrix: {
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat: {
auto& c = e.asAnyConstructor();
for (auto& arg : c.argumentSpan()) {
if (this->visitExpressionPtr(arg)) { return true; }
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index 7105065..440b4cc 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -14,6 +14,7 @@
#include "src/sksl/ir/SkSLBinaryExpression.h"
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLExpression.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
@@ -126,14 +127,12 @@
return ctor;
}
-static Constructor splat_scalar(const Expression& scalar, const Type& type) {
+static ConstructorSplat splat_scalar(const Expression& scalar, const Type& type) {
SkASSERT(type.isVector());
SkASSERT(type.componentType() == scalar.type());
- // Use a Constructor to splat the scalar expression across a vector.
- ExpressionArray arg;
- arg.push_back(scalar.clone());
- return Constructor{scalar.fOffset, type, std::move(arg)};
+ // Build a ConstructorSplat that repeats a clone of `scalar` in every component of `type`.
+ return ConstructorSplat{scalar.fOffset, type, scalar.clone()};
}
bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
@@ -161,8 +160,8 @@
}
static bool contains_constant_zero(const Expression& expr) {
- if (expr.is<Constructor>()) {
- for (const auto& arg : expr.as<Constructor>().arguments()) {
+ if (expr.isAnyConstructor()) {
+ for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
if (contains_constant_zero(*arg)) {
return true;
}
@@ -176,8 +175,8 @@
// This check only supports scalars and vectors (and in particular, not matrices).
SkASSERT(expr.type().isScalar() || expr.type().isVector());
- if (expr.is<Constructor>()) {
- for (const auto& arg : expr.as<Constructor>().arguments()) {
+ if (expr.isAnyConstructor()) {
+ for (const auto& arg : expr.asAnyConstructor().argumentSpan()) {
if (!is_constant_value(*arg, value)) {
return false;
}
diff --git a/src/sksl/SkSLDehydrator.cpp b/src/sksl/SkSLDehydrator.cpp
index e62e27b..0a79483 100644
--- a/src/sksl/SkSLDehydrator.cpp
+++ b/src/sksl/SkSLDehydrator.cpp
@@ -18,6 +18,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -301,6 +302,12 @@
this->writeExpressionSpan(e->as<ConstructorDiagonalMatrix>().argumentSpan());
break;
+ case Expression::Kind::kConstructorSplat:
+ this->writeCommand(Rehydrator::kConstructorSplat_Command);
+ this->write(e->type());
+ this->writeExpressionSpan(e->as<ConstructorSplat>().argumentSpan());
+ break;
+
case Expression::Kind::kExternalFunctionCall:
case Expression::Kind::kExternalFunctionReference:
SkDEBUGFAIL("unimplemented--not expected to be used from within an include file");
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index ab8a897..b475575 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -203,6 +203,7 @@
break;
case Expression::Kind::kConstructorArray:
case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat:
this->writeAnyConstructor(expr.asAnyConstructor(), parentPrecedence);
break;
case Expression::Kind::kIntLiteral:
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index 422013b..2d61939 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -19,6 +19,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -326,6 +327,12 @@
*ctor.type().clone(symbolTableForExpression),
expr(ctor.argument()));
}
+ case Expression::Kind::kConstructorSplat: {
+ const ConstructorSplat& ctor = expression.as<ConstructorSplat>();
+ return ConstructorSplat::Make(*fContext, offset,
+ *ctor.type().clone(symbolTableForExpression),
+ expr(ctor.argument()));
+ }
case Expression::Kind::kExternalFunctionCall: {
const ExternalFunctionCall& externalCall = expression.as<ExternalFunctionCall>();
return std::make_unique<ExternalFunctionCall>(offset, &externalCall.function(),
@@ -922,7 +929,8 @@
}
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
- case Expression::Kind::kConstructorDiagonalMatrix: {
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat: {
AnyConstructor& constructorExpr = (*expr)->asAnyConstructor();
for (std::unique_ptr<Expression>& arg : constructorExpr.argumentSpan()) {
this->visitExpression(&arg);
diff --git a/src/sksl/SkSLMetalCodeGenerator.cpp b/src/sksl/SkSLMetalCodeGenerator.cpp
index 60148cd..54bcae3 100644
--- a/src/sksl/SkSLMetalCodeGenerator.cpp
+++ b/src/sksl/SkSLMetalCodeGenerator.cpp
@@ -12,6 +12,7 @@
#include "src/sksl/SkSLMemoryLayout.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
#include "src/sksl/ir/SkSLExtension.h"
#include "src/sksl/ir/SkSLIndexExpression.h"
@@ -180,6 +181,9 @@
this->writeSingleArgumentConstructor(expr.as<ConstructorDiagonalMatrix>(),
parentPrecedence);
break;
+ case Expression::Kind::kConstructorSplat:
+ this->writeSingleArgumentConstructor(expr.as<ConstructorSplat>(), parentPrecedence);
+ break;
case Expression::Kind::kIntLiteral:
this->writeIntLiteral(expr.as<IntLiteral>());
break;
@@ -2260,7 +2264,8 @@
}
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
- case Expression::Kind::kConstructorDiagonalMatrix: {
+ case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat: {
const AnyConstructor& c = e->asAnyConstructor();
Requirements result = kNo_Requirements;
for (const auto& arg : c.argumentSpan()) {
diff --git a/src/sksl/SkSLPipelineStageCodeGenerator.cpp b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
index 93a7bde..2a50909 100644
--- a/src/sksl/SkSLPipelineStageCodeGenerator.cpp
+++ b/src/sksl/SkSLPipelineStageCodeGenerator.cpp
@@ -412,6 +412,7 @@
case Expression::Kind::kConstructor:
case Expression::Kind::kConstructorArray:
case Expression::Kind::kConstructorDiagonalMatrix:
+ case Expression::Kind::kConstructorSplat:
this->writeAnyConstructor(expr.asAnyConstructor(), parentPrecedence);
break;
case Expression::Kind::kFieldAccess:
diff --git a/src/sksl/SkSLRehydrator.cpp b/src/sksl/SkSLRehydrator.cpp
index 57eb9ea..e3eb784 100644
--- a/src/sksl/SkSLRehydrator.cpp
+++ b/src/sksl/SkSLRehydrator.cpp
@@ -18,6 +18,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDiscardStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
@@ -470,6 +471,12 @@
return ConstructorDiagonalMatrix::Make(fContext, /*offset=*/-1, *type,
std::move(args[0]));
}
+ case Rehydrator::kConstructorSplat_Command: {
+ const Type* type = this->type();
+ ExpressionArray args = this->expressionArray();
+ SkASSERT(args.size() == 1);
+ return ConstructorSplat::Make(fContext, /*offset=*/-1, *type, std::move(args[0]));
+ }
case Rehydrator::kFieldAccess_Command: {
std::unique_ptr<Expression> base = this->expression();
int index = this->readU8();
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.cpp b/src/sksl/SkSLSPIRVCodeGenerator.cpp
index 7bdafc0..128d3d5 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.cpp
+++ b/src/sksl/SkSLSPIRVCodeGenerator.cpp
@@ -715,6 +715,8 @@
return this->writeArrayConstructor(expr.as<ConstructorArray>(), out);
case Expression::Kind::kConstructorDiagonalMatrix:
return this->writeConstructorDiagonalMatrix(expr.as<ConstructorDiagonalMatrix>(), out);
+ case Expression::Kind::kConstructorSplat:
+ return this->writeConstructorSplat(expr.as<ConstructorSplat>(), out);
case Expression::Kind::kIntLiteral:
return this->writeIntLiteral(expr.as<IntLiteral>());
case Expression::Kind::kFieldAccess:
@@ -1172,7 +1174,7 @@
return result;
}
-SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
+SpvId SPIRVCodeGenerator::writeConstantVector(const AnyConstructor& c) {
const Type& type = c.type();
SkASSERT(type.isVector() && c.isCompileTimeConstant());
@@ -1625,26 +1627,39 @@
arguments.push_back(this->writeExpression(*c.arguments()[i], out));
}
}
+ SkASSERT((int)arguments.size() == type.columns());
+
SpvId result = this->nextId(&type);
- if (arguments.size() == 1 && c.arguments()[0]->type().isScalar()) {
- this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), out);
- this->writeWord(this->getType(type), out);
- this->writeWord(result, out);
- for (int i = 0; i < type.columns(); i++) {
- this->writeWord(arguments[0], out);
- }
- } else {
- SkASSERT(arguments.size() > 1);
- this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
- this->writeWord(this->getType(type), out);
- this->writeWord(result, out);
- for (SpvId id : arguments) {
- this->writeWord(id, out);
- }
+ this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
+ this->writeWord(this->getType(type), out);
+ this->writeWord(result, out);
+ for (SpvId id : arguments) {
+ this->writeWord(id, out);
}
return result;
}
+SpvId SPIRVCodeGenerator::writeConstructorSplat(const ConstructorSplat& c, OutputStream& out) {
+    // Constant splats route through writeConstantVector, which deduplicates identical constants.
+    if (c.isCompileTimeConstant()) {
+        return this->writeConstantVector(c);
+    }
+    const Type& vecType = c.type();
+    const int width = vecType.columns();
+    // Evaluate the scalar once; its id is then referenced once per vector component.
+    const SpvId scalarId = this->writeExpression(*c.argument(), out);
+    const SpvId resultId = this->nextId(&vecType);
+    // Instruction layout: opcode, result type, result id, then one constituent per component.
+    this->writeOpCode(SpvOpCompositeConstruct, 3 + width, out);
+    this->writeWord(this->getType(vecType), out);
+    this->writeWord(resultId, out);
+    for (int remaining = width; remaining > 0; --remaining) {
+        this->writeWord(scalarId, out);
+    }
+    return resultId;
+}
+
+
SpvId SPIRVCodeGenerator::writeArrayConstructor(const ConstructorArray& c, OutputStream& out) {
const Type& type = c.type();
SkASSERT(type.isArray());
diff --git a/src/sksl/SkSLSPIRVCodeGenerator.h b/src/sksl/SkSLSPIRVCodeGenerator.h
index 0789b2d..9552957 100644
--- a/src/sksl/SkSLSPIRVCodeGenerator.h
+++ b/src/sksl/SkSLSPIRVCodeGenerator.h
@@ -24,6 +24,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLFieldAccess.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
@@ -241,7 +242,7 @@
SpvId writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, OutputStream& out);
- SpvId writeConstantVector(const Constructor& c);
+ SpvId writeConstantVector(const AnyConstructor& c);
SpvId writeFloatConstructor(const Constructor& c, OutputStream& out);
@@ -291,6 +292,8 @@
SpvId writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c, OutputStream& out);
+ SpvId writeConstructorSplat(const ConstructorSplat& c, OutputStream& out);
+
SpvId writeFieldAccess(const FieldAccess& f, OutputStream& out);
SpvId writeSwizzle(const Swizzle& swizzle, OutputStream& out);
diff --git a/src/sksl/SkSLVMGenerator.cpp b/src/sksl/SkSLVMGenerator.cpp
index 8348f58..3bd5331 100644
--- a/src/sksl/SkSLVMGenerator.cpp
+++ b/src/sksl/SkSLVMGenerator.cpp
@@ -20,6 +20,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLContinueStatement.h"
#include "src/sksl/ir/SkSLDoStatement.h"
#include "src/sksl/ir/SkSLExpressionStatement.h"
@@ -247,6 +248,7 @@
Value writeConstructor(const Constructor& c);
Value writeMultiArgumentConstructor(const MultiArgumentConstructor& c);
Value writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c);
+ Value writeConstructorSplat(const ConstructorSplat& c);
Value writeFunctionCall(const FunctionCall& c);
Value writeExternalFunctionCall(const ExternalFunctionCall& c);
Value writeFieldAccess(const FieldAccess& expr);
@@ -768,19 +770,24 @@
return dst;
}
- // We can splat scalars to all components of a vector
- if (dstType.isVector() && srcType.isScalar()) {
- Value dst(dstType.columns());
- for (int i = 0; i < dstType.columns(); ++i) {
- dst[i] = src[0];
- }
- return dst;
- }
-
SkDEBUGFAIL("Invalid constructor");
return {};
}
+Value SkVMGenerator::writeConstructorSplat(const ConstructorSplat& c) {
+    // A splat builds a vector from exactly one scalar.
+    SkASSERT(c.type().isVector());
+    SkASSERT(c.argument()->type().isScalar());
+    // Evaluate the scalar once, then copy its only slot into every component of the result.
+    Value scalar = this->writeExpression(*c.argument());
+    const int componentCount = c.type().columns();
+    Value splatted(componentCount);
+    for (int index = 0; index < componentCount; ++index) {
+        splatted[index] = scalar[0];
+    }
+    return splatted;
+}
+
Value SkVMGenerator::writeConstructorDiagonalMatrix(const ConstructorDiagonalMatrix& c) {
const Type& dstType = c.type();
SkASSERT(dstType.isMatrix());
@@ -1457,6 +1464,8 @@
return this->writeMultiArgumentConstructor(e.as<ConstructorArray>());
case Expression::Kind::kConstructorDiagonalMatrix:
return this->writeConstructorDiagonalMatrix(e.as<ConstructorDiagonalMatrix>());
+ case Expression::Kind::kConstructorSplat:
+ return this->writeConstructorSplat(e.as<ConstructorSplat>());
case Expression::Kind::kFieldAccess:
return this->writeFieldAccess(e.as<FieldAccess>());
case Expression::Kind::kIndex:
diff --git a/src/sksl/generated/sksl_gpu.dehydrated.sksl b/src/sksl/generated/sksl_gpu.dehydrated.sksl
index a2b6c7f..f6ff911 100644
--- a/src/sksl/generated/sksl_gpu.dehydrated.sksl
+++ b/src/sksl/generated/sksl_gpu.dehydrated.sksl
@@ -3899,7 +3899,7 @@
2,
45,0,0,0,0,1,
37,
-6,
+9,
43,15,2,1,
22,
43,176,0,0,0,0,0,1,0,
@@ -3948,11 +3948,11 @@
47,
1,
52,167,3,0,66,
-6,
+9,
43,15,2,1,
22,
43,176,0,0,0,0,0,
-6,
+9,
43,15,2,1,
22,
43,176,0,0,0,0,0,
@@ -5115,7 +5115,7 @@
2,
45,0,0,0,0,1,
37,
-6,
+9,
43,172,1,1,
22,
43,176,0,0,0,0,0,1,1,1,211,3,
@@ -5823,7 +5823,7 @@
52,18,4,0,
53,
37,
-6,
+9,
43,15,2,1,
22,
43,176,0,0,0,0,0,1,29,154,3,157,3,160,3,163,3,166,3,169,3,172,3,175,3,178,3,181,3,184,3,187,3,190,3,193,3,196,3,202,3,205,3,208,3,221,3,227,3,230,3,236,3,239,3,242,3,245,3,6,4,9,4,12,4,15,4,
diff --git a/src/sksl/ir/SkSLConstructor.cpp b/src/sksl/ir/SkSLConstructor.cpp
index a704651..d16fcaf 100644
--- a/src/sksl/ir/SkSLConstructor.cpp
+++ b/src/sksl/ir/SkSLConstructor.cpp
@@ -10,6 +10,7 @@
#include "src/sksl/ir/SkSLBoolLiteral.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
#include "src/sksl/ir/SkSLPrefixExpression.h"
@@ -87,14 +88,10 @@
std::move(args));
SkASSERT(typecast);
- if (type.isMatrix()) {
- // Matrix-from-scalar creates a diagonal matrix.
- return ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast));
- }
-
- ExpressionArray typecastArgs;
- typecastArgs.push_back(std::move(typecast));
- return std::make_unique<Constructor>(offset, type, std::move(typecastArgs));
+ // Matrix-from-scalar creates a diagonal matrix; vector-from-scalar creates a splat.
+ return type.isMatrix()
+ ? ConstructorDiagonalMatrix::Make(context, offset, type, std::move(typecast))
+ : ConstructorSplat::Make(context, offset, type, std::move(typecast));
}
int expected = type.rows() * type.columns();
@@ -229,6 +226,9 @@
if (other.is<ConstructorDiagonalMatrix>()) {
return other.compareConstant(*this);
}
+ if (other.is<ConstructorSplat>()) {
+ return other.compareConstant(*this);
+ }
if (!other.is<Constructor>()) {
return ComparisonResult::kUnknown;
}
@@ -373,8 +373,7 @@
return col == row ? this->getConstantValue<SKSL_FLOAT>(*this->arguments()[0]) : 0.0;
}
if (argType.isMatrix()) {
- SkASSERT(this->arguments()[0]->is<Constructor>() ||
- this->arguments()[0]->is<ConstructorDiagonalMatrix>());
+ SkASSERT(this->arguments()[0]->isAnyConstructor());
// single matrix argument. make sure we're within the argument's bounds.
if (col < argType.columns() && row < argType.rows()) {
// within bounds, defer to argument
diff --git a/src/sksl/ir/SkSLConstructorSplat.cpp b/src/sksl/ir/SkSLConstructorSplat.cpp
new file mode 100644
index 0000000..f0306bf
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorSplat.cpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/ir/SkSLConstructorSplat.h"
+
+namespace SkSL {
+
+// Splats the scalar `arg` across every component of the vector `type`. `type` must be a vector
+// whose component type matches the type of `arg`.
+std::unique_ptr<Expression> ConstructorSplat::Make(const Context& context,
+                                                   int offset,
+                                                   const Type& type,
+                                                   std::unique_ptr<Expression> arg) {
+    SkASSERT(type.isVector());
+    SkASSERT(arg->type() == type.componentType());
+    return std::make_unique<ConstructorSplat>(offset, type, std::move(arg));
+}
+
+// Compares this splat against another expression of the same type. Only constructors can be
+// compared structurally; any other expression yields kUnknown.
+Expression::ComparisonResult ConstructorSplat::compareConstant(const Expression& other) const {
+    SkASSERT(this->type() == other.type());
+    if (!other.isAnyConstructor()) {
+        return ComparisonResult::kUnknown;
+    }
+
+    return this->compareConstantConstructor(other.asAnyConstructor());
+}
+
+// Compares the splatted scalar against every scalar supplied (directly or via nested
+// constructors) by `other`. Returns kEqual only if all of them match; otherwise returns the
+// first result that is not kEqual.
+Expression::ComparisonResult ConstructorSplat::compareConstantConstructor(
+        const AnyConstructor& other) const {
+    ComparisonResult check = ComparisonResult::kEqual;
+    for (const std::unique_ptr<Expression>& expr : other.argumentSpan()) {
+        if (expr->type().isMatrix()) {
+            // A matrix argument like `half2x2(1)` in `half4(half2x2(1))` has components that do
+            // not appear in its argument span (e.g. a diagonal matrix's off-diagonal zeros), so
+            // comparing its arguments directly against the splat would be unsound. Give up
+            // rather than fold to the wrong answer.
+            return ComparisonResult::kUnknown;
+        }
+        // We need to recurse to handle nested constructors like `half4(1) == half4(half2(1), 1, 1)`
+        check = expr->isAnyConstructor()
+                        ? this->compareConstantConstructor(expr->asAnyConstructor())
+                        : this->argument()->compareConstant(*expr);
+        if (check != ComparisonResult::kEqual) {
+            break;
+        }
+    }
+
+    return check;
+}
+
+}  // namespace SkSL
diff --git a/src/sksl/ir/SkSLConstructorSplat.h b/src/sksl/ir/SkSLConstructorSplat.h
new file mode 100644
index 0000000..a9da4fc
--- /dev/null
+++ b/src/sksl/ir/SkSLConstructorSplat.h
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#ifndef SKSL_CONSTRUCTOR_SPLAT
+#define SKSL_CONSTRUCTOR_SPLAT
+
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLExpression.h"
+
+#include <memory>
+
+namespace SkSL {
+
+/**
+ * Represents the construction of a vector splat, such as `half3(n)`.
+ *
+ * These always contain exactly 1 scalar; every component of the resulting vector holds that same
+ * scalar value.
+ */
+class ConstructorSplat final : public SingleArgumentConstructor {
+public:
+ static constexpr Kind kExpressionKind = Kind::kConstructorSplat;
+
+ // Builds the node directly; unlike Make(), this performs no assertions on `type` or `arg`.
+ ConstructorSplat(int offset, const Type& type, std::unique_ptr<Expression> arg)
+ : INHERITED(offset, kExpressionKind, &type, std::move(arg)) {}
+
+ // Splats `arg` across the vector `type`. Asserts that `type` is a vector whose component type
+ // matches the type of `arg`.
+ static std::unique_ptr<Expression> Make(const Context& context,
+ int offset,
+ const Type& type,
+ std::unique_ptr<Expression> arg);
+
+ // Deep-copies this node, including its scalar argument.
+ std::unique_ptr<Expression> clone() const override {
+ return std::make_unique<ConstructorSplat>(fOffset, this->type(), argument()->clone());
+ }
+
+ // Compares against a constructor of the same type; any non-constructor yields kUnknown.
+ ComparisonResult compareConstant(const Expression& other) const override;
+
+ // Constant component accessors: each returns the splatted scalar regardless of the index.
+ SKSL_FLOAT getFVecComponent(int) const override {
+ return this->argument()->getConstantFloat();
+ }
+
+ SKSL_INT getIVecComponent(int) const override {
+ return this->argument()->getConstantInt();
+ }
+
+ bool getBVecComponent(int) const override {
+ return this->argument()->getConstantBool();
+ }
+
+private:
+ // Compares the splatted scalar against every argument of `other`, recursing into nested
+ // constructors.
+ Expression::ComparisonResult compareConstantConstructor(const AnyConstructor& other) const;
+
+ using INHERITED = SingleArgumentConstructor;
+};
+
+} // namespace SkSL
+
+#endif
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 300224a..71eeacf 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -33,6 +33,7 @@
kConstructor,
kConstructorArray,
kConstructorDiagonalMatrix,
+ kConstructorSplat,
kDefined,
kExternalFunctionCall,
kExternalFunctionReference,
@@ -83,9 +84,8 @@
}
bool isAnyConstructor() const {
- static_assert((int)Kind::kConstructorDiagonalMatrix + 1 == (int)Kind::kDefined);
- return this->kind() >= Kind::kConstructor &&
- this->kind() <= Kind::kConstructorDiagonalMatrix;
+ static_assert((int)Kind::kConstructorSplat + 1 == (int)Kind::kDefined);  // kConstructorSplat must be the last constructor Kind.
+ return this->kind() >= Kind::kConstructor && this->kind() <= Kind::kConstructorSplat;
}
/**
diff --git a/src/sksl/ir/SkSLPrefixExpression.cpp b/src/sksl/ir/SkSLPrefixExpression.cpp
index 0fe9502..0853e4d 100644
--- a/src/sksl/ir/SkSLPrefixExpression.cpp
+++ b/src/sksl/ir/SkSLPrefixExpression.cpp
@@ -12,6 +12,7 @@
#include "src/sksl/ir/SkSLConstructor.h"
#include "src/sksl/ir/SkSLConstructorArray.h"
#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLFloatLiteral.h"
#include "src/sksl/ir/SkSLIntLiteral.h"
@@ -65,6 +66,15 @@
}
break;
+ case Expression::Kind::kConstructorSplat:
+ // Convert `-vector(literal)` into `vector(-literal)`.
+ if (context.fConfig->fSettings.fOptimize && value->isCompileTimeConstant()) {
+ ConstructorSplat& ctor = operand->as<ConstructorSplat>();
+ return ConstructorSplat::Make(context, ctor.fOffset, ctor.type(),
+ negate_operand(context, std::move(ctor.argument())));
+ }
+ break;
+
case Expression::Kind::kConstructor:
// To be consistent with prior behavior, the conversion of a negated constructor into a
// constructor of negative values is only performed when optimization is on.
diff --git a/src/sksl/ir/SkSLSwizzle.cpp b/src/sksl/ir/SkSLSwizzle.cpp
index 655a4c3..9b8bd8d 100644
--- a/src/sksl/ir/SkSLSwizzle.cpp
+++ b/src/sksl/ir/SkSLSwizzle.cpp
@@ -6,6 +6,7 @@
*/
#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLConstructorSplat.h"
#include "src/sksl/ir/SkSLSwizzle.h"
namespace SkSL {
@@ -189,6 +190,21 @@
return Swizzle::Make(context, std::move(base.base()), combined);
}
+ // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
+ // optimized to just `scalar`. The swizzle components don't actually matter, as every field
+ // in a splat constructor holds the same value.
+ if (expr->is<ConstructorSplat>()) {
+ ConstructorSplat& splat = expr->as<ConstructorSplat>();
+ ExpressionArray ctorArgs;
+ ctorArgs.push_back(std::move(splat.argument()));
+ auto ctor = Constructor::Convert(
+ context, splat.fOffset,
+ splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
+ std::move(ctorArgs));
+ SkASSERT(ctor);
+ return ctor;
+ }
+
// Optimize swizzles of constructors.
if (expr->is<Constructor>()) {
Constructor& base = expr->as<Constructor>();
@@ -196,17 +212,6 @@
const Type& componentType = exprType.componentType();
int swizzleSize = components.size();
- // `half4(scalar).zyy` can be optimized to `half3(scalar)`. The swizzle components don't
- // actually matter since all fields are the same.
- if (base.arguments().size() == 1 && base.arguments().front()->type().isScalar()) {
- auto ctor = Constructor::Convert(
- context, base.fOffset,
- componentType.toCompound(context, swizzleSize, /*rows=*/1),
- std::move(base.arguments()));
- SkASSERT(ctor);
- return ctor;
- }
-
// Swizzles can duplicate some elements and discard others, e.g.
// `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
// - Expressions with side effects need to occur exactly once, even if they