blob: a4cfe7ae6c876991bdc821197a2118e93c6c8c10 [file] [log] [blame]
John Stilesdc8ec312021-01-11 11:05:21 -05001/*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/SkSLConstantFolder.h"
9
10#include <limits>
11
12#include "src/sksl/SkSLContext.h"
13#include "src/sksl/SkSLErrorReporter.h"
14#include "src/sksl/ir/SkSLBinaryExpression.h"
15#include "src/sksl/ir/SkSLBoolLiteral.h"
16#include "src/sksl/ir/SkSLConstructor.h"
17#include "src/sksl/ir/SkSLExpression.h"
18#include "src/sksl/ir/SkSLFloatLiteral.h"
19#include "src/sksl/ir/SkSLIntLiteral.h"
20#include "src/sksl/ir/SkSLType.h"
21#include "src/sksl/ir/SkSLVariable.h"
22#include "src/sksl/ir/SkSLVariableReference.h"
23
24namespace SkSL {
25
26static std::unique_ptr<Expression> short_circuit_boolean(const Expression& left,
27 Token::Kind op,
28 const Expression& right) {
29 SkASSERT(left.is<BoolLiteral>());
30 bool leftVal = left.as<BoolLiteral>().value();
31
32 if (op == Token::Kind::TK_LOGICALAND) {
33 // (true && expr) -> (expr) and (false && expr) -> (false)
34 return leftVal ? right.clone()
35 : std::make_unique<BoolLiteral>(left.fOffset, /*value=*/false, &left.type());
36 }
37 if (op == Token::Kind::TK_LOGICALOR) {
38 // (true || expr) -> (true) and (false || expr) -> (expr)
39 return leftVal ? std::make_unique<BoolLiteral>(left.fOffset, /*value=*/true, &left.type())
40 : right.clone();
41 }
42 if (op == Token::Kind::TK_LOGICALXOR && !leftVal) {
43 // (false ^^ expr) -> (expr)
44 return right.clone();
45 }
46
47 return nullptr;
48}
49
50template <typename T>
51static std::unique_ptr<Expression> simplify_vector(const Context& context,
John Stilesdc8ec312021-01-11 11:05:21 -050052 const Expression& left,
53 Token::Kind op,
54 const Expression& right) {
John Stiles508eba72021-01-11 13:07:47 -050055 SkASSERT(left.type().isVector());
John Stilesdc8ec312021-01-11 11:05:21 -050056 SkASSERT(left.type() == right.type());
57 const Type& type = left.type();
58
59 // Handle boolean operations: == !=
60 if (op == Token::Kind::TK_EQEQ || op == Token::Kind::TK_NEQ) {
61 bool equality = (op == Token::Kind::TK_EQEQ);
62
63 switch (left.compareConstant(right)) {
64 case Expression::ComparisonResult::kNotEqual:
65 equality = !equality;
66 [[fallthrough]];
67
68 case Expression::ComparisonResult::kEqual:
69 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
70
71 case Expression::ComparisonResult::kUnknown:
72 return nullptr;
73 }
74 }
75
76 // Handle floating-point arithmetic: + - * /
77 const auto vectorComponentwiseFold = [&](auto foldFn) -> std::unique_ptr<Constructor> {
78 const Type& componentType = type.componentType();
79 ExpressionArray args;
80 args.reserve_back(type.columns());
81 for (int i = 0; i < type.columns(); i++) {
82 T value = foldFn(left.getVecComponent<T>(i), right.getVecComponent<T>(i));
83 args.push_back(std::make_unique<Literal<T>>(left.fOffset, value, &componentType));
84 }
85 return std::make_unique<Constructor>(left.fOffset, &type, std::move(args));
86 };
87
88 const auto isVectorDivisionByZero = [&]() -> bool {
89 for (int i = 0; i < type.columns(); i++) {
90 if (right.getVecComponent<T>(i) == 0) {
91 return true;
92 }
93 }
94 return false;
95 };
96
97 switch (op) {
98 case Token::Kind::TK_PLUS: return vectorComponentwiseFold([](T a, T b) { return a + b; });
99 case Token::Kind::TK_MINUS: return vectorComponentwiseFold([](T a, T b) { return a - b; });
100 case Token::Kind::TK_STAR: return vectorComponentwiseFold([](T a, T b) { return a * b; });
101 case Token::Kind::TK_SLASH: {
102 if (isVectorDivisionByZero()) {
John Stilesb30151e2021-01-11 16:13:08 -0500103 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500104 return nullptr;
105 }
106 return vectorComponentwiseFold([](T a, T b) { return a / b; });
107 }
108 default:
109 return nullptr;
110 }
111}
112
John Stiles508eba72021-01-11 13:07:47 -0500113static Constructor splat_scalar(const Expression& scalar, const Type& type) {
114 SkASSERT(type.isVector());
115 SkASSERT(type.componentType() == scalar.type());
116
117 // Use a Constructor to splat the scalar expression across a vector.
118 ExpressionArray arg;
119 arg.push_back(scalar.clone());
120 return Constructor{scalar.fOffset, &type, std::move(arg)};
121}
122
John Stilesdc8ec312021-01-11 11:05:21 -0500123std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
John Stilesdc8ec312021-01-11 11:05:21 -0500124 const Expression& left,
125 Token::Kind op,
126 const Expression& right) {
127 // If the left side is a constant boolean literal, the right side does not need to be constant
128 // for short-circuit optimizations to allow the constant to be folded.
129 if (left.is<BoolLiteral>() && !right.isCompileTimeConstant()) {
130 return short_circuit_boolean(left, op, right);
131 }
132
133 if (right.is<BoolLiteral>() && !left.isCompileTimeConstant()) {
134 // There aren't side effects in SkSL within expressions, so (left OP right) is equivalent to
135 // (right OP left) for short-circuit optimizations
136 // TODO: (true || (a=b)) seems to disqualify the above statement. Test this.
137 return short_circuit_boolean(right, op, left);
138 }
139
140 // Other than the short-circuit cases above, constant folding requires both sides to be constant
141 if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) {
142 return nullptr;
143 }
144
145 // Perform constant folding on pairs of Booleans.
146 if (left.is<BoolLiteral>() && right.is<BoolLiteral>()) {
147 bool leftVal = left.as<BoolLiteral>().value();
148 bool rightVal = right.as<BoolLiteral>().value();
149 bool result;
150 switch (op) {
151 case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
152 case Token::Kind::TK_LOGICALOR: result = leftVal || rightVal; break;
153 case Token::Kind::TK_LOGICALXOR: result = leftVal ^ rightVal; break;
154 default: return nullptr;
155 }
156 return std::make_unique<BoolLiteral>(context, left.fOffset, result);
157 }
158
159 // Note that we expressly do not worry about precision and overflow here -- we use the maximum
160 // precision to calculate the results and hope the result makes sense.
161 // TODO: detect and handle integer overflow properly.
162 #define RESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
163 leftVal op rightVal)
164 #define URESULT(t, op) std::make_unique<t ## Literal>(context, left.fOffset, \
165 (uint64_t) leftVal op \
166 (uint64_t) rightVal)
167 if (left.is<IntLiteral>() && right.is<IntLiteral>()) {
168 SKSL_INT leftVal = left.as<IntLiteral>().value();
169 SKSL_INT rightVal = right.as<IntLiteral>().value();
170 switch (op) {
171 case Token::Kind::TK_PLUS: return URESULT(Int, +);
172 case Token::Kind::TK_MINUS: return URESULT(Int, -);
173 case Token::Kind::TK_STAR: return URESULT(Int, *);
174 case Token::Kind::TK_SLASH:
175 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
John Stilesb30151e2021-01-11 16:13:08 -0500176 context.fErrors.error(right.fOffset, "arithmetic overflow");
John Stilesdc8ec312021-01-11 11:05:21 -0500177 return nullptr;
178 }
179 if (!rightVal) {
John Stilesb30151e2021-01-11 16:13:08 -0500180 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500181 return nullptr;
182 }
183 return RESULT(Int, /);
184 case Token::Kind::TK_PERCENT:
185 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
John Stilesb30151e2021-01-11 16:13:08 -0500186 context.fErrors.error(right.fOffset, "arithmetic overflow");
John Stilesdc8ec312021-01-11 11:05:21 -0500187 return nullptr;
188 }
189 if (!rightVal) {
John Stilesb30151e2021-01-11 16:13:08 -0500190 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500191 return nullptr;
192 }
193 return RESULT(Int, %);
194 case Token::Kind::TK_BITWISEAND: return RESULT(Int, &);
195 case Token::Kind::TK_BITWISEOR: return RESULT(Int, |);
196 case Token::Kind::TK_BITWISEXOR: return RESULT(Int, ^);
197 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
198 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
199 case Token::Kind::TK_GT: return RESULT(Bool, >);
200 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
201 case Token::Kind::TK_LT: return RESULT(Bool, <);
202 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
203 case Token::Kind::TK_SHL:
204 if (rightVal >= 0 && rightVal <= 31) {
205 return RESULT(Int, <<);
206 }
John Stilesb30151e2021-01-11 16:13:08 -0500207 context.fErrors.error(right.fOffset, "shift value out of range");
John Stilesdc8ec312021-01-11 11:05:21 -0500208 return nullptr;
209 case Token::Kind::TK_SHR:
210 if (rightVal >= 0 && rightVal <= 31) {
211 return RESULT(Int, >>);
212 }
John Stilesb30151e2021-01-11 16:13:08 -0500213 context.fErrors.error(right.fOffset, "shift value out of range");
John Stilesdc8ec312021-01-11 11:05:21 -0500214 return nullptr;
215
216 default:
217 return nullptr;
218 }
219 }
220
221 // Perform constant folding on pairs of floating-point literals.
222 if (left.is<FloatLiteral>() && right.is<FloatLiteral>()) {
223 SKSL_FLOAT leftVal = left.as<FloatLiteral>().value();
224 SKSL_FLOAT rightVal = right.as<FloatLiteral>().value();
225 switch (op) {
226 case Token::Kind::TK_PLUS: return RESULT(Float, +);
227 case Token::Kind::TK_MINUS: return RESULT(Float, -);
228 case Token::Kind::TK_STAR: return RESULT(Float, *);
229 case Token::Kind::TK_SLASH:
230 if (rightVal) {
231 return RESULT(Float, /);
232 }
John Stilesb30151e2021-01-11 16:13:08 -0500233 context.fErrors.error(right.fOffset, "division by zero");
John Stilesdc8ec312021-01-11 11:05:21 -0500234 return nullptr;
235 case Token::Kind::TK_EQEQ: return RESULT(Bool, ==);
236 case Token::Kind::TK_NEQ: return RESULT(Bool, !=);
237 case Token::Kind::TK_GT: return RESULT(Bool, >);
238 case Token::Kind::TK_GTEQ: return RESULT(Bool, >=);
239 case Token::Kind::TK_LT: return RESULT(Bool, <);
240 case Token::Kind::TK_LTEQ: return RESULT(Bool, <=);
241 default: return nullptr;
242 }
243 }
244
245 // Perform constant folding on pairs of vectors.
246 const Type& leftType = left.type();
247 const Type& rightType = right.type();
248 if (leftType.isVector() && leftType == rightType) {
249 if (leftType.componentType().isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500250 return simplify_vector<SKSL_FLOAT>(context, left, op, right);
John Stilesdc8ec312021-01-11 11:05:21 -0500251 }
252 if (leftType.componentType().isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500253 return simplify_vector<SKSL_INT>(context, left, op, right);
John Stilesdc8ec312021-01-11 11:05:21 -0500254 }
255 return nullptr;
256 }
257
John Stiles508eba72021-01-11 13:07:47 -0500258 // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
259 if (leftType.isVector() && leftType.componentType() == rightType) {
260 if (rightType.isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500261 return simplify_vector<SKSL_FLOAT>(context, left, op, splat_scalar(right, left.type()));
John Stiles508eba72021-01-11 13:07:47 -0500262 }
263 if (rightType.isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500264 return simplify_vector<SKSL_INT>(context, left, op, splat_scalar(right, left.type()));
John Stiles508eba72021-01-11 13:07:47 -0500265 }
266 return nullptr;
267 }
268
269 // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
270 if (rightType.isVector() && rightType.componentType() == leftType) {
271 if (leftType.isFloat()) {
John Stilesb30151e2021-01-11 16:13:08 -0500272 return simplify_vector<SKSL_FLOAT>(context, splat_scalar(left, right.type()), op,
273 right);
John Stiles508eba72021-01-11 13:07:47 -0500274 }
275 if (leftType.isInteger()) {
John Stilesb30151e2021-01-11 16:13:08 -0500276 return simplify_vector<SKSL_INT>(context, splat_scalar(left, right.type()), op, right);
John Stiles508eba72021-01-11 13:07:47 -0500277 }
278 return nullptr;
279 }
280
John Stilesdc8ec312021-01-11 11:05:21 -0500281 // Perform constant folding on pairs of matrices.
282 if (leftType.isMatrix() && rightType.isMatrix()) {
283 bool equality;
284 switch (op) {
285 case Token::Kind::TK_EQEQ:
286 equality = true;
287 break;
288 case Token::Kind::TK_NEQ:
289 equality = false;
290 break;
291 default:
292 return nullptr;
293 }
294
295 switch (left.compareConstant(right)) {
296 case Expression::ComparisonResult::kNotEqual:
297 equality = !equality;
298 [[fallthrough]];
299
300 case Expression::ComparisonResult::kEqual:
301 return std::make_unique<BoolLiteral>(context, left.fOffset, equality);
302
303 case Expression::ComparisonResult::kUnknown:
304 return nullptr;
305 }
306 }
307
308 // We aren't able to constant-fold.
309 #undef RESULT
310 #undef URESULT
311 return nullptr;
312}
313
314} // namespace SkSL