blob: 5b78893d8b7d4ddbea6981af3241f811aff793a7 [file] [log] [blame]
John Stilesdc8ec312021-01-11 11:05:21 -05001/*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/SkSLConstantFolder.h"
9
10#include <limits>
11
12#include "src/sksl/SkSLContext.h"
13#include "src/sksl/SkSLErrorReporter.h"
14#include "src/sksl/ir/SkSLBinaryExpression.h"
15#include "src/sksl/ir/SkSLBoolLiteral.h"
16#include "src/sksl/ir/SkSLConstructor.h"
17#include "src/sksl/ir/SkSLExpression.h"
18#include "src/sksl/ir/SkSLFloatLiteral.h"
19#include "src/sksl/ir/SkSLIntLiteral.h"
20#include "src/sksl/ir/SkSLType.h"
21#include "src/sksl/ir/SkSLVariable.h"
22#include "src/sksl/ir/SkSLVariableReference.h"
23
24namespace SkSL {
25
26static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
27 Token::Kind op,
28 const Expression& right) {
29 SkASSERT(left.is<BoolLiteral>());
30 bool leftVal = left.as<BoolLiteral>().value();
31
32 if (op == Token::Kind::TK_LOGICALAND) {
33 // (true && expr) -> (expr) and (false && expr) -> (false)
34 return leftVal ? right.clone()
35 : std::make_unique<BoolLiteral>(left.fOffset, /*value=*/false, &left.type());
36 }
37 if (op == Token::Kind::TK_LOGICALOR) {
38 // (true || expr) -> (true) and (false || expr) -> (expr)
39 return leftVal ? std::make_unique<BoolLiteral>(left.fOffset, /*value=*/true, &left.type())
40 : right.clone();
41 }
John Stiles81bfabe2021-01-20 12:05:03 -050042 if (op == Token::Kind::TK_EQEQ && leftVal) {
43 // (true == expr) -> (expr)
44 return right.clone();
45 }
46 if (op == Token::Kind::TK_NEQ && !leftVal) {
47 // (false != expr) -> (expr)
48 return right.clone();
49 }
John Stilesdc8ec312021-01-11 11:05:21 -050050 if (op == Token::Kind::TK_LOGICALXOR && !leftVal) {
51 // (false ^^ expr) -> (expr)
52 return right.clone();
53 }
54
55 return nullptr;
56}
57
58template <typename T>
59static std::unique_ptr<Expression> simplify_vector(const Context& context,
John Stilesdc8ec312021-01-11 11:05:21 -050060 const Expression& left,
61 Token::Kind op,
62 const Expression& right) {
John Stiles508eba72021-01-11 13:07:47 -050063 SkASSERT(left.type().isVector());
John Stilesdc8ec312021-01-11 11:05:21 -050064 SkASSERT(left.type() == right.type());
65 const Type& type = left.type();
66
67 // Handle boolean operations: == !=
68 if (op == Token::Kind::TK_EQEQ || op == Token::Kind::TK_NEQ) {
69 bool equality = (op == Token::Kind::TK_EQEQ);
70
71 switch (left.compareConstant(right)) {
72 case Expression::ComparisonResult::kNotEqual:
73 equality = !equality;
74 [[fallthrough]];
75
76 case Expression::ComparisonResult::kEqual:
77 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
78
79 case Expression::ComparisonResult::kUnknown:
80 return nullptr;
81 }
82 }
83
84 // Handle floating-point arithmetic: + - * /
85 const auto vectorComponentwiseFold = [&](auto foldFn) -> std::unique_ptr<Constructor> {
86 const Type& componentType = type.componentType();
87 ExpressionArray args;
88 args.reserve_back(type.columns());
89 for (int i = 0; i < type.columns(); i++) {
90 T value = foldFn(left.getVecComponent<T>(i), right.getVecComponent<T>(i));
91 args.push_back(std::make_unique<Literal<T>>(left.fOffset, value, &componentType));
92 }
93 return std::make_unique<Constructor>(left.fOffset, &type, std::move(args));
94 };
95
96 const auto isVectorDivisionByZero = [&]() -> bool {
97 for (int i = 0; i < type.columns(); i++) {
98 if (right.getVecComponent<T>(i) == 0) {
99 return true;
100 }
101 }
102 return false;
103 };
104
105 switch (op) {
106 case Token::Kind::TK_PLUS: return vectorComponentwiseFold([](T a, T b) { return a + b; });
107 case Token::Kind::TK_MINUS: return vectorComponentwiseFold([](T a, T b) { return a - b; });
108 case Token::Kind::TK_STAR: return vectorComponentwiseFold([](T a, T b) { return a * b; });
109 case Token::Kind::TK_SLASH: {
110 if (isVectorDivisionByZero()) {
John Stilesb30151e2021-01-11 16:13:08 -0500111 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500112 return nullptr;
113 }
114 return vectorComponentwiseFold([](T a, T b) { return a / b; });
115 }
116 default:
117 return nullptr;
118 }
119}
120
John Stiles508eba72021-01-11 13:07:47 -0500121static Constructor splat_scalar(const Expression& scalar, const Type& type) {
122 SkASSERT(type.isVector());
123 SkASSERT(type.componentType() == scalar.type());
124
125 // Use a Constructor to splat the scalar expression across a vector.
126 ExpressionArray arg;
127 arg.push_back(scalar.clone());
128 return Constructor{scalar.fOffset, &type, std::move(arg)};
129}
130
John Stilesdc8ec312021-01-11 11:05:21 -0500131std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
John Stilesdc8ec312021-01-11 11:05:21 -0500132 const Expression& left,
133 Token::Kind op,
134 const Expression& right) {
135 // If the left side is a constant boolean literal, the right side does not need to be constant
136 // for short-circuit optimizations to allow the constant to be folded.
137 if (left.is<BoolLiteral>() && !right.isCompileTimeConstant()) {
138 return short_circuit_boolean(left, op, right);
139 }
140
141 if (right.is<BoolLiteral>() && !left.isCompileTimeConstant()) {
142 // There aren't side effects in SkSL within expressions, so (left OP right) is equivalent to
143 // (right OP left) for short-circuit optimizations
144 // TODO: (true || (a=b)) seems to disqualify the above statement. Test this.
145 return short_circuit_boolean(right, op, left);
146 }
147
148 // Other than the short-circuit cases above, constant folding requires both sides to be constant
149 if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) {
150 return nullptr;
151 }
152
153 // Perform constant folding on pairs of Booleans.
154 if (left.is<BoolLiteral>() && right.is<BoolLiteral>()) {
155 bool leftVal = left.as<BoolLiteral>().value();
156 bool rightVal = right.as<BoolLiteral>().value();
157 bool result;
158 switch (op) {
159 case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
160 case Token::Kind::TK_LOGICALOR: result = leftVal || rightVal; break;
161 case Token::Kind::TK_LOGICALXOR: result = leftVal ^ rightVal; break;
John Stiles26fdcbb2021-01-19 19:00:31 -0500162 case Token::Kind::TK_EQEQ: result = leftVal == rightVal; break;
163 case Token::Kind::TK_NEQ: result = leftVal != rightVal; break;
John Stilesdc8ec312021-01-11 11:05:21 -0500164 default: return nullptr;
165 }
166 return std::make_unique<BoolLiteral>(context, left.fOffset, result);
167 }
168
169 // Note that we expressly do not worry about precision and overflow here -- we use the maximum
170 // precision to calculate the results and hope the result makes sense.
171 // TODO: detect and handle integer overflow properly.
172 #define RESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
173 leftVal op rightVal)
174 #define URESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
175 (uint64_t) leftVal op \
176 (uint64_t) rightVal)
177 if (left.is<IntLiteral>() && right.is<IntLiteral>()) {
178 SKSL_INT leftVal = left.as<IntLiteral>().value();
179 SKSL_INT rightVal = right.as<IntLiteral>().value();
180 switch (op) {
181 case Token::Kind::TK_PLUS: return URESULT(Int, +);
182 case Token::Kind::TK_MINUS: return URESULT(Int, -);
183 case Token::Kind::TK_STAR: return URESULT(Int, *);
184 case Token::Kind::TK_SLASH:
185 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
John Stilesb30151e2021-01-11 16:13:08 -0500186 context.fErrors.error(right.fOffset, "arithmetic overflow");
John Stilesdc8ec312021-01-11 11:05:21 -0500187 return nullptr;
188 }
189 if (!rightVal) {
John Stilesb30151e2021-01-11 16:13:08 -0500190 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500191 return nullptr;
192 }
193 return RESULT(Int, /);
194 case Token::Kind::TK_PERCENT:
195 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
John Stilesb30151e2021-01-11 16:13:08 -0500196 context.fErrors.error(right.fOffset, "arithmetic overflow");
John Stilesdc8ec312021-01-11 11:05:21 -0500197 return nullptr;
198 }
199 if (!rightVal) {
John Stilesb30151e2021-01-11 16:13:08 -0500200 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500201 return nullptr;
202 }
203 return RESULT(Int, %);
204 case Token::Kind::TK_BITWISEAND: return RESULT(Int, &);
205 case Token::Kind::TK_BITWISEOR: return RESULT(Int, |);
206 case Token::Kind::TK_BITWISEXOR: return RESULT(Int, ^);
207 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
208 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
209 case Token::Kind::TK_GT: return RESULT(Bool, >);
210 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
211 case Token::Kind::TK_LT: return RESULT(Bool, <);
212 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
213 case Token::Kind::TK_SHL:
214 if (rightVal >= 0 && rightVal <= 31) {
215 return RESULT(Int, <<);
216 }
John Stilesb30151e2021-01-11 16:13:08 -0500217 context.fErrors.error(right.fOffset, "shift value out of range");
John Stilesdc8ec312021-01-11 11:05:21 -0500218 return nullptr;
219 case Token::Kind::TK_SHR:
220 if (rightVal >= 0 && rightVal <= 31) {
221 return RESULT(Int, >>);
222 }
John Stilesb30151e2021-01-11 16:13:08 -0500223 context.fErrors.error(right.fOffset, "shift value out of range");
John Stilesdc8ec312021-01-11 11:05:21 -0500224 return nullptr;
225
226 default:
227 return nullptr;
228 }
229 }
230
231 // Perform constant folding on pairs of floating-point literals.
232 if (left.is<FloatLiteral>() && right.is<FloatLiteral>()) {
233 SKSL_FLOAT leftVal = left.as<FloatLiteral>().value();
234 SKSL_FLOAT rightVal = right.as<FloatLiteral>().value();
235 switch (op) {
236 case Token::Kind::TK_PLUS: return RESULT(Float, +);
237 case Token::Kind::TK_MINUS: return RESULT(Float, -);
238 case Token::Kind::TK_STAR: return RESULT(Float, *);
239 case Token::Kind::TK_SLASH:
240 if (rightVal) {
241 return RESULT(Float, /);
242 }
John Stilesb30151e2021-01-11 16:13:08 -0500243 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500244 return nullptr;
245 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
246 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
247 case Token::Kind::TK_GT: return RESULT(Bool, >);
248 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
249 case Token::Kind::TK_LT: return RESULT(Bool, <);
250 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
251 default: return nullptr;
252 }
253 }
254
255 // Perform constant folding on pairs of vectors.
256 const Type& leftType = left.type();
257 const Type& rightType = right.type();
258 if (leftType.isVector() && leftType == rightType) {
259 if (leftType.componentType().isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500260 return simplify_vector<SKSL_FLOAT>(context, left, op, right);
John Stilesdc8ec312021-01-11 11:05:21 -0500261 }
262 if (leftType.componentType().isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500263 return simplify_vector<SKSL_INT>(context, left, op, right);
John Stilesdc8ec312021-01-11 11:05:21 -0500264 }
265 return nullptr;
266 }
267
John Stiles508eba72021-01-11 13:07:47 -0500268 // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
269 if (leftType.isVector() && leftType.componentType() == rightType) {
270 if (rightType.isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500271 return simplify_vector<SKSL_FLOAT>(context, left, op, splat_scalar(right, left.type()));
John Stiles508eba72021-01-11 13:07:47 -0500272 }
273 if (rightType.isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500274 return simplify_vector<SKSL_INT>(context, left, op, splat_scalar(right, left.type()));
John Stiles508eba72021-01-11 13:07:47 -0500275 }
276 return nullptr;
277 }
278
279 // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
280 if (rightType.isVector() && rightType.componentType() == leftType) {
281 if (leftType.isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500282 return simplify_vector<SKSL_FLOAT>(context, splat_scalar(left, right.type()), op,
283 right);
John Stiles508eba72021-01-11 13:07:47 -0500284 }
285 if (leftType.isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500286 return simplify_vector<SKSL_INT>(context, splat_scalar(left, right.type()), op, right);
John Stiles508eba72021-01-11 13:07:47 -0500287 }
288 return nullptr;
289 }
290
John Stilesdc8ec312021-01-11 11:05:21 -0500291 // Perform constant folding on pairs of matrices.
292 if (leftType.isMatrix() && rightType.isMatrix()) {
293 bool equality;
294 switch (op) {
295 case Token::Kind::TK_EQEQ:
296 equality = true;
297 break;
298 case Token::Kind::TK_NEQ:
299 equality = false;
300 break;
301 default:
302 return nullptr;
303 }
304
305 switch (left.compareConstant(right)) {
306 case Expression::ComparisonResult::kNotEqual:
307 equality = !equality;
308 [[fallthrough]];
309
310 case Expression::ComparisonResult::kEqual:
311 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
312
313 case Expression::ComparisonResult::kUnknown:
314 return nullptr;
315 }
316 }
317
318 // We aren't able to constant-fold.
319 #undef RESULT
320 #undef URESULT
321 return nullptr;
322}
323
324} // namespace SkSL