blob: e85a7c2527f1babc8d088291af0f268729a1ba89 [file] [log] [blame]
John Stilesdc8ec312021-01-11 11:05:21 -05001/*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/SkSLConstantFolder.h"
9
10#include <limits>
11
12#include "src/sksl/SkSLContext.h"
13#include "src/sksl/SkSLErrorReporter.h"
14#include "src/sksl/ir/SkSLBinaryExpression.h"
15#include "src/sksl/ir/SkSLBoolLiteral.h"
16#include "src/sksl/ir/SkSLConstructor.h"
17#include "src/sksl/ir/SkSLExpression.h"
18#include "src/sksl/ir/SkSLFloatLiteral.h"
19#include "src/sksl/ir/SkSLIntLiteral.h"
20#include "src/sksl/ir/SkSLType.h"
21#include "src/sksl/ir/SkSLVariable.h"
22#include "src/sksl/ir/SkSLVariableReference.h"
23
24namespace SkSL {
25
26static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
27 Token::Kind op,
28 const Expression& right) {
29 SkASSERT(left.is<BoolLiteral>());
30 bool leftVal = left.as<BoolLiteral>().value();
31
32 if (op == Token::Kind::TK_LOGICALAND) {
33 // (true && expr) -> (expr) and (false && expr) -> (false)
34 return leftVal ? right.clone()
35 : std::make_unique<BoolLiteral>(left.fOffset, /*value=*/false, &left.type());
36 }
37 if (op == Token::Kind::TK_LOGICALOR) {
38 // (true || expr) -> (true) and (false || expr) -> (expr)
39 return leftVal ? std::make_unique<BoolLiteral>(left.fOffset, /*value=*/true, &left.type())
40 : right.clone();
41 }
42 if (op == Token::Kind::TK_LOGICALXOR && !leftVal) {
43 // (false ^^ expr) -> (expr)
44 return right.clone();
45 }
46
47 return nullptr;
48}
49
50template <typename T>
51static std::unique_ptr<Expression> simplify_vector(const Context& context,
52 ErrorReporter& errors,
53 const Expression& left,
54 Token::Kind op,
55 const Expression& right) {
John Stiles508eba72021-01-11 13:07:47 -050056 SkASSERT(left.type().isVector());
John Stilesdc8ec312021-01-11 11:05:21 -050057 SkASSERT(left.type() == right.type());
58 const Type& type = left.type();
59
60 // Handle boolean operations: == !=
61 if (op == Token::Kind::TK_EQEQ || op == Token::Kind::TK_NEQ) {
62 bool equality = (op == Token::Kind::TK_EQEQ);
63
64 switch (left.compareConstant(right)) {
65 case Expression::ComparisonResult::kNotEqual:
66 equality = !equality;
67 [[fallthrough]];
68
69 case Expression::ComparisonResult::kEqual:
70 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
71
72 case Expression::ComparisonResult::kUnknown:
73 return nullptr;
74 }
75 }
76
77 // Handle floating-point arithmetic: + - * /
78 const auto vectorComponentwiseFold = [&](auto foldFn) -> std::unique_ptr<Constructor> {
79 const Type& componentType = type.componentType();
80 ExpressionArray args;
81 args.reserve_back(type.columns());
82 for (int i = 0; i < type.columns(); i++) {
83 T value = foldFn(left.getVecComponent<T>(i), right.getVecComponent<T>(i));
84 args.push_back(std::make_unique<Literal<T>>(left.fOffset, value, &componentType));
85 }
86 return std::make_unique<Constructor>(left.fOffset, &type, std::move(args));
87 };
88
89 const auto isVectorDivisionByZero = [&]() -> bool {
90 for (int i = 0; i < type.columns(); i++) {
91 if (right.getVecComponent<T>(i) == 0) {
92 return true;
93 }
94 }
95 return false;
96 };
97
98 switch (op) {
99 case Token::Kind::TK_PLUS: return vectorComponentwiseFold([](T a, T b) { return a + b; });
100 case Token::Kind::TK_MINUS: return vectorComponentwiseFold([](T a, T b) { return a - b; });
101 case Token::Kind::TK_STAR: return vectorComponentwiseFold([](T a, T b) { return a * b; });
102 case Token::Kind::TK_SLASH: {
103 if (isVectorDivisionByZero()) {
104 errors.error(right.fOffset, "division by zero");
105 return nullptr;
106 }
107 return vectorComponentwiseFold([](T a, T b) { return a / b; });
108 }
109 default:
110 return nullptr;
111 }
112}
113
John Stiles508eba72021-01-11 13:07:47 -0500114static Constructor splat_scalar(const Expression& scalar, const Type& type) {
115 SkASSERT(type.isVector());
116 SkASSERT(type.componentType() == scalar.type());
117
118 // Use a Constructor to splat the scalar expression across a vector.
119 ExpressionArray arg;
120 arg.push_back(scalar.clone());
121 return Constructor{scalar.fOffset, &type, std::move(arg)};
122}
123
John Stilesdc8ec312021-01-11 11:05:21 -0500124std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
125 ErrorReporter& errors,
126 const Expression& left,
127 Token::Kind op,
128 const Expression& right) {
129 // If the left side is a constant boolean literal, the right side does not need to be constant
130 // for short-circuit optimizations to allow the constant to be folded.
131 if (left.is<BoolLiteral>() && !right.isCompileTimeConstant()) {
132 return short_circuit_boolean(left, op, right);
133 }
134
135 if (right.is<BoolLiteral>() && !left.isCompileTimeConstant()) {
136 // There aren't side effects in SkSL within expressions, so (left OP right) is equivalent to
137 // (right OP left) for short-circuit optimizations
138 // TODO: (true || (a=b)) seems to disqualify the above statement. Test this.
139 return short_circuit_boolean(right, op, left);
140 }
141
142 // Other than the short-circuit cases above, constant folding requires both sides to be constant
143 if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) {
144 return nullptr;
145 }
146
147 // Perform constant folding on pairs of Booleans.
148 if (left.is<BoolLiteral>() && right.is<BoolLiteral>()) {
149 bool leftVal = left.as<BoolLiteral>().value();
150 bool rightVal = right.as<BoolLiteral>().value();
151 bool result;
152 switch (op) {
153 case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
154 case Token::Kind::TK_LOGICALOR: result = leftVal || rightVal; break;
155 case Token::Kind::TK_LOGICALXOR: result = leftVal ^ rightVal; break;
156 default: return nullptr;
157 }
158 return std::make_unique<BoolLiteral>(context, left.fOffset, result);
159 }
160
161 // Note that we expressly do not worry about precision and overflow here -- we use the maximum
162 // precision to calculate the results and hope the result makes sense.
163 // TODO: detect and handle integer overflow properly.
164 #define RESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
165 leftVal op rightVal)
166 #define URESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
167 (uint64_t) leftVal op \
168 (uint64_t) rightVal)
169 if (left.is<IntLiteral>() && right.is<IntLiteral>()) {
170 SKSL_INT leftVal = left.as<IntLiteral>().value();
171 SKSL_INT rightVal = right.as<IntLiteral>().value();
172 switch (op) {
173 case Token::Kind::TK_PLUS: return URESULT(Int, +);
174 case Token::Kind::TK_MINUS: return URESULT(Int, -);
175 case Token::Kind::TK_STAR: return URESULT(Int, *);
176 case Token::Kind::TK_SLASH:
177 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
178 errors.error(right.fOffset, "arithmetic overflow");
179 return nullptr;
180 }
181 if (!rightVal) {
182 errors.error(right.fOffset, "division by zero");
183 return nullptr;
184 }
185 return RESULT(Int, /);
186 case Token::Kind::TK_PERCENT:
187 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
188 errors.error(right.fOffset, "arithmetic overflow");
189 return nullptr;
190 }
191 if (!rightVal) {
192 errors.error(right.fOffset, "division by zero");
193 return nullptr;
194 }
195 return RESULT(Int, %);
196 case Token::Kind::TK_BITWISEAND: return RESULT(Int, &);
197 case Token::Kind::TK_BITWISEOR: return RESULT(Int, |);
198 case Token::Kind::TK_BITWISEXOR: return RESULT(Int, ^);
199 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
200 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
201 case Token::Kind::TK_GT: return RESULT(Bool, >);
202 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
203 case Token::Kind::TK_LT: return RESULT(Bool, <);
204 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
205 case Token::Kind::TK_SHL:
206 if (rightVal >= 0 && rightVal <= 31) {
207 return RESULT(Int, <<);
208 }
209 errors.error(right.fOffset, "shift value out of range");
210 return nullptr;
211 case Token::Kind::TK_SHR:
212 if (rightVal >= 0 && rightVal <= 31) {
213 return RESULT(Int, >>);
214 }
215 errors.error(right.fOffset, "shift value out of range");
216 return nullptr;
217
218 default:
219 return nullptr;
220 }
221 }
222
223 // Perform constant folding on pairs of floating-point literals.
224 if (left.is<FloatLiteral>() && right.is<FloatLiteral>()) {
225 SKSL_FLOAT leftVal = left.as<FloatLiteral>().value();
226 SKSL_FLOAT rightVal = right.as<FloatLiteral>().value();
227 switch (op) {
228 case Token::Kind::TK_PLUS: return RESULT(Float, +);
229 case Token::Kind::TK_MINUS: return RESULT(Float, -);
230 case Token::Kind::TK_STAR: return RESULT(Float, *);
231 case Token::Kind::TK_SLASH:
232 if (rightVal) {
233 return RESULT(Float, /);
234 }
235 errors.error(right.fOffset, "division by zero");
236 return nullptr;
237 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
238 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
239 case Token::Kind::TK_GT: return RESULT(Bool, >);
240 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
241 case Token::Kind::TK_LT: return RESULT(Bool, <);
242 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
243 default: return nullptr;
244 }
245 }
246
247 // Perform constant folding on pairs of vectors.
248 const Type& leftType = left.type();
249 const Type& rightType = right.type();
250 if (leftType.isVector() && leftType == rightType) {
251 if (leftType.componentType().isFloat()) {
252 return simplify_vector<SKSL_FLOAT>(context, errors, left, op, right);
253 }
254 if (leftType.componentType().isInteger()) {
255 return simplify_vector<SKSL_INT>(context, errors, left, op, right);
256 }
257 return nullptr;
258 }
259
John Stiles508eba72021-01-11 13:07:47 -0500260 // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
261 if (leftType.isVector() && leftType.componentType() == rightType) {
262 if (rightType.isFloat()) {
263 return simplify_vector<SKSL_FLOAT>(context, errors,
264 left, op, splat_scalar(right, left.type()));
265 }
266 if (rightType.isInteger()) {
267 return simplify_vector<SKSL_INT>(context, errors,
268 left, op, splat_scalar(right, left.type()));
269 }
270 return nullptr;
271 }
272
273 // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
274 if (rightType.isVector() && rightType.componentType() == leftType) {
275 if (leftType.isFloat()) {
276 return simplify_vector<SKSL_FLOAT>(context, errors,
277 splat_scalar(left, right.type()), op, right);
278 }
279 if (leftType.isInteger()) {
280 return simplify_vector<SKSL_INT>(context, errors,
281 splat_scalar(left, right.type()), op, right);
282 }
283 return nullptr;
284 }
285
John Stilesdc8ec312021-01-11 11:05:21 -0500286 // Perform constant folding on pairs of matrices.
287 if (leftType.isMatrix() && rightType.isMatrix()) {
288 bool equality;
289 switch (op) {
290 case Token::Kind::TK_EQEQ:
291 equality = true;
292 break;
293 case Token::Kind::TK_NEQ:
294 equality = false;
295 break;
296 default:
297 return nullptr;
298 }
299
300 switch (left.compareConstant(right)) {
301 case Expression::ComparisonResult::kNotEqual:
302 equality = !equality;
303 [[fallthrough]];
304
305 case Expression::ComparisonResult::kEqual:
306 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
307
308 case Expression::ComparisonResult::kUnknown:
309 return nullptr;
310 }
311 }
312
313 // We aren't able to constant-fold.
314 #undef RESULT
315 #undef URESULT
316 return nullptr;
317}
318
319} // namespace SkSL