Implement constant folding for vector-scalar and scalar-vector ops.
Change-Id: I96b547de4fe4b73096fb26d0ef21a4e7555ca06a
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/352238
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index 94aa5d0..e85a7c2 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -53,6 +53,7 @@
const Expression& left,
Token::Kind op,
const Expression& right) {
+ SkASSERT(left.type().isVector());
SkASSERT(left.type() == right.type());
const Type& type = left.type();
@@ -110,6 +111,16 @@
}
}
+static Constructor splat_scalar(const Expression& scalar, const Type& type) {
+    // Wraps a clone of `scalar` as `type(scalar)`, e.g. splatting 2 into half4 gives half4(2).
+    SkASSERT(type.isVector());
+    SkASSERT(type.componentType() == scalar.type());
+    ExpressionArray args;
+    args.push_back(scalar.clone());
+    // A single-argument vector Constructor splats that one scalar across every component.
+    return Constructor{scalar.fOffset, &type, std::move(args)};
+}
+
std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
ErrorReporter& errors,
const Expression& left,
@@ -246,6 +257,32 @@
return nullptr;
}
+ // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
+ if (leftType.isVector() && leftType.componentType() == rightType) {
+ if (rightType.isFloat()) {
+ return simplify_vector<SKSL_FLOAT>(context, errors,
+ left, op, splat_scalar(right, left.type()));
+ }
+ if (rightType.isInteger()) {
+ return simplify_vector<SKSL_INT>(context, errors,
+ left, op, splat_scalar(right, left.type()));
+ }
+ return nullptr;
+ }
+
+ // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
+ if (rightType.isVector() && rightType.componentType() == leftType) {
+ if (leftType.isFloat()) {
+ return simplify_vector<SKSL_FLOAT>(context, errors,
+ splat_scalar(left, right.type()), op, right);
+ }
+ if (leftType.isInteger()) {
+ return simplify_vector<SKSL_INT>(context, errors,
+ splat_scalar(left, right.type()), op, right);
+ }
+ return nullptr;
+ }
+
// Perform constant folding on pairs of matrices.
if (leftType.isMatrix() && rightType.isMatrix()) {
bool equality;