Move constant folding to a separate file.
This doesn't change any logic; it just makes the IR generator a few
hundred lines shorter.
Change-Id: I92010191ee9283c33499c819d65fc85913f25824
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/352121
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
new file mode 100644
index 0000000..94aa5d0
--- /dev/null
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -0,0 +1,282 @@
+/*
+ * Copyright 2020 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/SkSLConstantFolder.h"
+
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/SkSLErrorReporter.h"
+#include "src/sksl/ir/SkSLBinaryExpression.h"
+#include "src/sksl/ir/SkSLBoolLiteral.h"
+#include "src/sksl/ir/SkSLConstructor.h"
+#include "src/sksl/ir/SkSLExpression.h"
+#include "src/sksl/ir/SkSLFloatLiteral.h"
+#include "src/sksl/ir/SkSLIntLiteral.h"
+#include "src/sksl/ir/SkSLType.h"
+#include "src/sksl/ir/SkSLVariable.h"
+#include "src/sksl/ir/SkSLVariableReference.h"
+
+namespace SkSL {
+
+/**
+ * Applies the logical short-circuit rules to `left OP right`, where `left` must be a BoolLiteral
+ * and `right` is an arbitrary expression.
+ *
+ * Returns the simplified expression (either a fresh BoolLiteral or a clone of `right`), or
+ * nullptr when no rule applies: (true ^^ expr) would need a negation, and non-logical operators
+ * are never rewritten here.
+ */
+static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
+                                                         Token::Kind op,
+                                                         const Expression& right) {
+    SkASSERT(left.is<BoolLiteral>());
+    const bool literal = left.as<BoolLiteral>().value();
+
+    switch (op) {
+        case Token::Kind::TK_LOGICALAND:
+            // (false && expr) -> (false), (true && expr) -> (expr)
+            if (!literal) {
+                return std::make_unique<BoolLiteral>(left.fOffset, /*value=*/false, &left.type());
+            }
+            return right.clone();
+
+        case Token::Kind::TK_LOGICALOR:
+            // (true || expr) -> (true), (false || expr) -> (expr)
+            if (literal) {
+                return std::make_unique<BoolLiteral>(left.fOffset, /*value=*/true, &left.type());
+            }
+            return right.clone();
+
+        case Token::Kind::TK_LOGICALXOR:
+            // (false ^^ expr) -> (expr); (true ^^ expr) is left untouched.
+            if (!literal) {
+                return right.clone();
+            }
+            return nullptr;
+
+        default:
+            return nullptr;
+    }
+}
+
+/**
+ * Folds a binary operation on two compile-time-constant vectors of identical type.
+ *
+ * T is the C++ type used to read each component (SKSL_FLOAT or SKSL_INT). Handles == and !=
+ * (via compareConstant) and componentwise + - * /. Division reports "division by zero" if any
+ * divisor component is zero, and integer division reports "arithmetic overflow" for
+ * min() / -1, matching the scalar checks in ConstantFolder::Simplify. Returns nullptr whenever
+ * the operation is not folded.
+ */
+template <typename T>
+static std::unique_ptr<Expression> simplify_vector(const Context& context,
+                                                   ErrorReporter& errors,
+                                                   const Expression& left,
+                                                   Token::Kind op,
+                                                   const Expression& right) {
+    SkASSERT(left.type() == right.type());
+    const Type& type = left.type();
+
+    // Handle boolean operations: == !=
+    if (op == Token::Kind::TK_EQEQ || op == Token::Kind::TK_NEQ) {
+        bool equality = (op == Token::Kind::TK_EQEQ);
+
+        switch (left.compareConstant(right)) {
+            case Expression::ComparisonResult::kNotEqual:
+                equality = !equality;
+                [[fallthrough]];
+
+            case Expression::ComparisonResult::kEqual:
+                return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
+
+            case Expression::ComparisonResult::kUnknown:
+                return nullptr;
+        }
+    }
+
+    // Builds a vector Constructor whose i-th component is foldFn(left[i], right[i]).
+    const auto vectorComponentwiseFold = [&](auto foldFn) -> std::unique_ptr<Constructor> {
+        const Type& componentType = type.componentType();
+        ExpressionArray args;
+        args.reserve_back(type.columns());
+        for (int i = 0; i < type.columns(); i++) {
+            T value = foldFn(left.getVecComponent<T>(i), right.getVecComponent<T>(i));
+            args.push_back(std::make_unique<Literal<T>>(left.fOffset, value, &componentType));
+        }
+        return std::make_unique<Constructor>(left.fOffset, &type, std::move(args));
+    };
+
+    // Wraps an arithmetic op so that integer components are computed in uint64_t, like the
+    // scalar URESULT path in ConstantFolder::Simplify. Overflow then wraps instead of being
+    // signed-overflow undefined behavior. Floating-point components are computed directly.
+    const auto wrapping = [](auto fn) {
+        return [fn](T a, T b) -> T {
+            if constexpr (std::is_integral<T>::value) {
+                return static_cast<T>(fn(static_cast<uint64_t>(a), static_cast<uint64_t>(b)));
+            } else {
+                return fn(a, b);
+            }
+        };
+    };
+
+    const auto isVectorDivisionByZero = [&]() -> bool {
+        for (int i = 0; i < type.columns(); i++) {
+            if (right.getVecComponent<T>(i) == 0) {
+                return true;
+            }
+        }
+        return false;
+    };
+
+    // Integer min() / -1 overflows: undefined behavior in C++, and it traps on x86. This mirrors
+    // the scalar check. Floating-point division has no such case.
+    const auto isVectorDivisionOverflow = [&]() -> bool {
+        if constexpr (std::is_integral<T>::value) {
+            for (int i = 0; i < type.columns(); i++) {
+                if (left.getVecComponent<T>(i) == std::numeric_limits<T>::min() &&
+                    right.getVecComponent<T>(i) == -1) {
+                    return true;
+                }
+            }
+        }
+        return false;
+    };
+
+    switch (op) {
+        case Token::Kind::TK_PLUS:
+            return vectorComponentwiseFold(wrapping([](auto a, auto b) { return a + b; }));
+        case Token::Kind::TK_MINUS:
+            return vectorComponentwiseFold(wrapping([](auto a, auto b) { return a - b; }));
+        case Token::Kind::TK_STAR:
+            return vectorComponentwiseFold(wrapping([](auto a, auto b) { return a * b; }));
+        case Token::Kind::TK_SLASH: {
+            if (isVectorDivisionByZero()) {
+                errors.error(right.fOffset, "division by zero");
+                return nullptr;
+            }
+            if (isVectorDivisionOverflow()) {
+                errors.error(right.fOffset, "arithmetic overflow");
+                return nullptr;
+            }
+            return vectorComponentwiseFold([](T a, T b) { return a / b; });
+        }
+        default:
+            return nullptr;
+    }
+}
+
+/**
+ * Folds `left OP right` when the Boolean literal `right` is the identity element for OP, which
+ * reduces the expression to just `left`: (expr && true), (expr || false), (expr ^^ false).
+ * `left` is always kept, so this is safe even when `left` has side effects.
+ * Returns nullptr for any other combination.
+ */
+static std::unique_ptr<Expression> eliminate_no_op_boolean(const Expression& left,
+                                                           Token::Kind op,
+                                                           const Expression& right) {
+    SkASSERT(right.is<BoolLiteral>());
+    const bool rightVal = right.as<BoolLiteral>().value();
+
+    if ((op == Token::Kind::TK_LOGICALAND &&  rightVal) ||   // (expr && true)  -> (expr)
+        (op == Token::Kind::TK_LOGICALOR  && !rightVal) ||   // (expr || false) -> (expr)
+        (op == Token::Kind::TK_LOGICALXOR && !rightVal)) {   // (expr ^^ false) -> (expr)
+        return left.clone();
+    }
+    return nullptr;
+}
+
+/**
+ * Attempts to constant-fold the binary expression `left OP right`.
+ *
+ * Returns the folded expression, or nullptr if it cannot be folded. Invalid constant operations
+ * (division by zero, min()/-1 overflow in integer / and %, shifts outside [0, 31]) report an
+ * error through `errors` and return nullptr.
+ */
+std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
+                                                     ErrorReporter& errors,
+                                                     const Expression& left,
+                                                     Token::Kind op,
+                                                     const Expression& right) {
+    // If the left side is a constant boolean literal, the right side does not need to be constant
+    // for short-circuit optimizations to allow the constant to be folded.
+    if (left.is<BoolLiteral>() && !right.isCompileTimeConstant()) {
+        return short_circuit_boolean(left, op, right);
+    }
+
+    if (right.is<BoolLiteral>() && !left.isCompileTimeConstant()) {
+        // With the literal on the right, `left` is always evaluated. A rewrite that drops it
+        // would discard its side effects, e.g. `(a = b) || true` -> `true`. A bare variable
+        // reference has no side effects, so the swapped short-circuit rules are still valid for
+        // it. For anything else, only rewrites that keep `left` are applied.
+        // TODO: use a general side-effect analysis here to fold more pure expressions.
+        if (left.is<VariableReference>()) {
+            return short_circuit_boolean(right, op, left);
+        }
+        return eliminate_no_op_boolean(left, op, right);
+    }
+
+    // Other than the short-circuit cases above, constant folding requires both sides to be constant
+    if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) {
+        return nullptr;
+    }
+
+    // Perform constant folding on pairs of Booleans.
+    if (left.is<BoolLiteral>() && right.is<BoolLiteral>()) {
+        bool leftVal = left.as<BoolLiteral>().value();
+        bool rightVal = right.as<BoolLiteral>().value();
+        bool result;
+        switch (op) {
+            case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
+            case Token::Kind::TK_LOGICALOR:  result = leftVal || rightVal; break;
+            case Token::Kind::TK_LOGICALXOR: result = leftVal ^  rightVal; break;
+            default: return nullptr;
+        }
+        return std::make_unique<BoolLiteral>(context, left.fOffset, result);
+    }
+
+    // Note that we expressly do not worry about precision and overflow here -- we use the maximum
+    // precision to calculate the results and hope the result makes sense.
+    // TODO: detect and handle integer overflow properly.
+    #define RESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
+                                                         leftVal op rightVal)
+    // URESULT computes in uint64_t so that overflow wraps rather than being signed-overflow UB.
+    #define URESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
+                                                          (uint64_t) leftVal op  \
+                                                          (uint64_t) rightVal)
+    if (left.is<IntLiteral>() && right.is<IntLiteral>()) {
+        SKSL_INT leftVal = left.as<IntLiteral>().value();
+        SKSL_INT rightVal = right.as<IntLiteral>().value();
+        switch (op) {
+            case Token::Kind::TK_PLUS:       return URESULT(Int, +);
+            case Token::Kind::TK_MINUS:      return URESULT(Int, -);
+            case Token::Kind::TK_STAR:       return URESULT(Int, *);
+            case Token::Kind::TK_SLASH:
+                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
+                    errors.error(right.fOffset, "arithmetic overflow");
+                    return nullptr;
+                }
+                if (!rightVal) {
+                    errors.error(right.fOffset, "division by zero");
+                    return nullptr;
+                }
+                return RESULT(Int, /);
+            case Token::Kind::TK_PERCENT:
+                if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
+                    errors.error(right.fOffset, "arithmetic overflow");
+                    return nullptr;
+                }
+                if (!rightVal) {
+                    errors.error(right.fOffset, "division by zero");
+                    return nullptr;
+                }
+                return RESULT(Int, %);
+            case Token::Kind::TK_BITWISEAND: return RESULT(Int,  &);
+            case Token::Kind::TK_BITWISEOR:  return RESULT(Int,  |);
+            case Token::Kind::TK_BITWISEXOR: return RESULT(Int,  ^);
+            case Token::Kind::TK_EQEQ:       return RESULT(Bool, ==);
+            case Token::Kind::TK_NEQ:        return RESULT(Bool, !=);
+            case Token::Kind::TK_GT:         return RESULT(Bool, >);
+            case Token::Kind::TK_GTEQ:       return RESULT(Bool, >=);
+            case Token::Kind::TK_LT:         return RESULT(Bool, <);
+            case Token::Kind::TK_LTEQ:       return RESULT(Bool, <=);
+            case Token::Kind::TK_SHL:
+                if (rightVal >= 0 && rightVal <= 31) {
+                    // Shift in uint64_t: left-shifting a negative signed value, or shifting bits
+                    // out of the signed range, is undefined behavior.
+                    return URESULT(Int, <<);
+                }
+                errors.error(right.fOffset, "shift value out of range");
+                return nullptr;
+            case Token::Kind::TK_SHR:
+                if (rightVal >= 0 && rightVal <= 31) {
+                    return RESULT(Int, >>);
+                }
+                errors.error(right.fOffset, "shift value out of range");
+                return nullptr;
+
+            default:
+                return nullptr;
+        }
+    }
+
+    // Perform constant folding on pairs of floating-point literals.
+    if (left.is<FloatLiteral>() && right.is<FloatLiteral>()) {
+        SKSL_FLOAT leftVal = left.as<FloatLiteral>().value();
+        SKSL_FLOAT rightVal = right.as<FloatLiteral>().value();
+        switch (op) {
+            case Token::Kind::TK_PLUS:  return RESULT(Float, +);
+            case Token::Kind::TK_MINUS: return RESULT(Float, -);
+            case Token::Kind::TK_STAR:  return RESULT(Float, *);
+            case Token::Kind::TK_SLASH:
+                if (rightVal) {
+                    return RESULT(Float, /);
+                }
+                errors.error(right.fOffset, "division by zero");
+                return nullptr;
+            case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
+            case Token::Kind::TK_NEQ:  return RESULT(Bool, !=);
+            case Token::Kind::TK_GT:   return RESULT(Bool, >);
+            case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
+            case Token::Kind::TK_LT:   return RESULT(Bool, <);
+            case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
+            default:                   return nullptr;
+        }
+    }
+
+    // Perform constant folding on pairs of vectors.
+    const Type& leftType = left.type();
+    const Type& rightType = right.type();
+    if (leftType.isVector() && leftType == rightType) {
+        if (leftType.componentType().isFloat()) {
+            return simplify_vector<SKSL_FLOAT>(context, errors, left, op, right);
+        }
+        if (leftType.componentType().isInteger()) {
+            return simplify_vector<SKSL_INT>(context, errors, left, op, right);
+        }
+        return nullptr;
+    }
+
+    // Perform constant folding on pairs of matrices.
+    if (leftType.isMatrix() && rightType.isMatrix()) {
+        bool equality;
+        switch (op) {
+            case Token::Kind::TK_EQEQ:
+                equality = true;
+                break;
+            case Token::Kind::TK_NEQ:
+                equality = false;
+                break;
+            default:
+                return nullptr;
+        }
+
+        switch (left.compareConstant(right)) {
+            case Expression::ComparisonResult::kNotEqual:
+                equality = !equality;
+                [[fallthrough]];
+
+            case Expression::ComparisonResult::kEqual:
+                return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
+
+            case Expression::ComparisonResult::kUnknown:
+                return nullptr;
+        }
+    }
+
+    // We aren't able to constant-fold.
+    #undef RESULT
+    #undef URESULT
+    return nullptr;
+}
+
+} // namespace SkSL