blob: 5ed77093f3d402189acf9cfed4e55d9fc75c2dad [file] [log] [blame]
/*
 * Copyright 2021 Google LLC
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */
7
8#include "src/sksl/SkSLAnalysis.h"
9#include "src/sksl/SkSLConstantFolder.h"
10#include "src/sksl/ir/SkSLBinaryExpression.h"
John Stiles9e13fe82021-03-23 18:45:25 -040011#include "src/sksl/ir/SkSLBoolLiteral.h"
12#include "src/sksl/ir/SkSLIndexExpression.h"
13#include "src/sksl/ir/SkSLSetting.h"
14#include "src/sksl/ir/SkSLSwizzle.h"
15#include "src/sksl/ir/SkSLTernaryExpression.h"
John Stilese2aec432021-03-01 09:27:48 -050016#include "src/sksl/ir/SkSLType.h"
17
18namespace SkSL {
19
John Stiles70cd5982021-03-25 11:09:53 -040020static bool is_low_precision_matrix_vector_multiply(const Expression& left,
21 const Operator& op,
22 const Expression& right,
23 const Type& resultType) {
24 return !resultType.highPrecision() &&
25 op.kind() == Token::Kind::TK_STAR &&
26 left.type().isMatrix() &&
27 right.type().isVector() &&
28 left.type().rows() == right.type().columns() &&
29 Analysis::IsTrivialExpression(left) &&
30 Analysis::IsTrivialExpression(right);
31}
32
33static std::unique_ptr<Expression> rewrite_matrix_vector_multiply(const Context& context,
34 const Expression& left,
35 const Operator& op,
36 const Expression& right,
37 const Type& resultType) {
38 // Rewrite m33 * v3 as (m[0] * v[0] + m[1] * v[1] + m[2] * v[2])
39 std::unique_ptr<Expression> sum;
40 for (int n = 0; n < left.type().rows(); ++n) {
41 // Get mat[N] with an index expression.
42 std::unique_ptr<Expression> matN = IndexExpression::Make(
43 context, left.clone(), IntLiteral::Make(context, left.fOffset, n));
44 // Get vec[N] with a swizzle expression.
45 std::unique_ptr<Expression> vecN = Swizzle::Make(
46 context, right.clone(), ComponentArray{(SkSL::SwizzleComponent::Type)n});
47 // Multiply them together.
48 const Type* matNType = &matN->type();
49 std::unique_ptr<Expression> product =
50 BinaryExpression::Make(context, std::move(matN), op, std::move(vecN), matNType);
51 // Sum all the components together.
52 if (!sum) {
53 sum = std::move(product);
54 } else {
55 sum = BinaryExpression::Make(context,
56 std::move(sum),
57 Operator(Token::Kind::TK_PLUS),
58 std::move(product),
59 matNType);
60 }
61 }
62
63 return sum;
64}
65
John Stiles23521a82021-03-02 17:02:51 -050066std::unique_ptr<Expression> BinaryExpression::Convert(const Context& context,
67 std::unique_ptr<Expression> left,
68 Operator op,
69 std::unique_ptr<Expression> right) {
John Stilese2aec432021-03-01 09:27:48 -050070 if (!left || !right) {
71 return nullptr;
72 }
John Stiles23521a82021-03-02 17:02:51 -050073 const int offset = left->fOffset;
John Stilese2aec432021-03-01 09:27:48 -050074
75 const Type* rawLeftType;
76 if (left->is<IntLiteral>() && right->type().isInteger()) {
77 rawLeftType = &right->type();
78 } else {
79 rawLeftType = &left->type();
80 }
81
82 const Type* rawRightType;
83 if (right->is<IntLiteral>() && left->type().isInteger()) {
84 rawRightType = &left->type();
85 } else {
86 rawRightType = &right->type();
87 }
88
John Stilese2aec432021-03-01 09:27:48 -050089 bool isAssignment = op.isAssignment();
90 if (isAssignment &&
91 !Analysis::MakeAssignmentExpr(left.get(),
92 op.kind() != Token::Kind::TK_EQ
93 ? VariableReference::RefKind::kReadWrite
94 : VariableReference::RefKind::kWrite,
95 &context.fErrors)) {
96 return nullptr;
97 }
98
99 const Type* leftType;
100 const Type* rightType;
101 const Type* resultType;
102 if (!op.determineBinaryType(context, *rawLeftType, *rawRightType,
103 &leftType, &rightType, &resultType)) {
104 context.fErrors.error(offset, String("type mismatch: '") + op.operatorName() +
105 "' cannot operate on '" + left->type().displayName() +
106 "', '" + right->type().displayName() + "'");
107 return nullptr;
108 }
John Stiles23521a82021-03-02 17:02:51 -0500109
John Stilese2aec432021-03-01 09:27:48 -0500110 if (isAssignment && leftType->componentType().isOpaque()) {
111 context.fErrors.error(offset, "assignments to opaque type '" + left->type().displayName() +
112 "' are not permitted");
John Stilese2aec432021-03-01 09:27:48 -0500113 return nullptr;
114 }
John Stiles23521a82021-03-02 17:02:51 -0500115 if (context.fConfig->strictES2Mode()) {
116 if (!op.isAllowedInStrictES2Mode()) {
117 context.fErrors.error(offset, String("operator '") + op.operatorName() +
118 "' is not allowed");
119 return nullptr;
120 }
121 if (leftType->isOrContainsArray()) {
122 // Most operators are already rejected on arrays, but GLSL ES 1.0 is very explicit that
123 // the *only* operator allowed on arrays is subscripting (and the rules against
124 // assignment, comparison, and even sequence apply to structs containing arrays as well)
125 context.fErrors.error(offset, String("operator '") + op.operatorName() + "' can not "
126 "operate on arrays (or structs containing arrays)");
127 return nullptr;
128 }
129 }
John Stilese2aec432021-03-01 09:27:48 -0500130
131 left = leftType->coerceExpression(std::move(left), context);
132 right = rightType->coerceExpression(std::move(right), context);
133 if (!left || !right) {
134 return nullptr;
135 }
136
John Stiles23521a82021-03-02 17:02:51 -0500137 return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
138}
139
140std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
141 std::unique_ptr<Expression> left,
142 Operator op,
143 std::unique_ptr<Expression> right) {
144 // Determine the result type of the binary expression.
145 const Type* leftType;
146 const Type* rightType;
147 const Type* resultType;
148 SkAssertResult(op.determineBinaryType(context, left->type(), right->type(),
149 &leftType, &rightType, &resultType));
150
151 return BinaryExpression::Make(context, std::move(left), op, std::move(right), resultType);
152}
153
154std::unique_ptr<Expression> BinaryExpression::Make(const Context& context,
155 std::unique_ptr<Expression> left,
156 Operator op,
157 std::unique_ptr<Expression> right,
158 const Type* resultType) {
159 // We should have detected non-ES2 compliant behavior in Convert.
160 SkASSERT(!context.fConfig->strictES2Mode() || op.isAllowedInStrictES2Mode());
161 SkASSERT(!context.fConfig->strictES2Mode() || !left->type().isOrContainsArray());
162
163 // We should have detected non-assignable assignment expressions in Convert.
164 SkASSERT(!op.isAssignment() || Analysis::IsAssignable(*left));
165 SkASSERT(!op.isAssignment() || !left->type().componentType().isOpaque());
166
167 // If we can detect division-by-zero, we should synthesize an error, but our caller is still
168 // expecting to receive a binary expression back; don't return nullptr.
169 const int offset = left->fOffset;
John Stilese2aec432021-03-01 09:27:48 -0500170 if (!ConstantFolder::ErrorOnDivideByZero(context, offset, op, *right)) {
John Stiles8f440b42021-03-05 16:48:56 -0500171 std::unique_ptr<Expression> result = ConstantFolder::Simplify(context, offset, *left,
172 op, *right, *resultType);
John Stiles23521a82021-03-02 17:02:51 -0500173 if (result) {
174 return result;
175 }
John Stilese2aec432021-03-01 09:27:48 -0500176 }
John Stiles23521a82021-03-02 17:02:51 -0500177
John Stiles70cd5982021-03-25 11:09:53 -0400178 if (context.fConfig->fSettings.fOptimize) {
179 // When sk_Caps.rewriteMatrixVectorMultiply is set, we rewrite medium-precision
180 // matrix * vector multiplication as:
181 // (sk_Caps.rewriteMatrixVectorMultiply ? (mat[0]*vec[0] + ... + mat[N]*vec[N])
182 // : mat * vec)
183 if (is_low_precision_matrix_vector_multiply(*left, op, *right, *resultType)) {
184 // Look up `sk_Caps.rewriteMatrixVectorMultiply`.
185 auto caps = Setting::Convert(context, left->fOffset, "rewriteMatrixVectorMultiply");
John Stiles9e13fe82021-03-23 18:45:25 -0400186
John Stiles70cd5982021-03-25 11:09:53 -0400187 bool capsBitIsTrue = caps->is<BoolLiteral>() && caps->as<BoolLiteral>().value();
188 if (capsBitIsTrue || !caps->is<BoolLiteral>()) {
189 // Rewrite the multiplication as a sum of vector-scalar products.
190 std::unique_ptr<Expression> rewrite =
191 rewrite_matrix_vector_multiply(context, *left, op, *right, *resultType);
192
193 // If we know the caps bit is true, return the rewritten expression directly.
194 if (capsBitIsTrue) {
195 return rewrite;
John Stiles9e13fe82021-03-23 18:45:25 -0400196 }
John Stiles9e13fe82021-03-23 18:45:25 -0400197
John Stiles70cd5982021-03-25 11:09:53 -0400198 // Return a ternary expression:
199 // sk_Caps.rewriteMatrixVectorMultiply ? (rewrite) : (mat * vec)
200 return TernaryExpression::Make(
201 context,
202 std::move(caps),
203 std::move(rewrite),
204 std::make_unique<BinaryExpression>(offset, std::move(left), op,
205 std::move(right), resultType));
John Stiles9e13fe82021-03-23 18:45:25 -0400206 }
John Stiles9e13fe82021-03-23 18:45:25 -0400207 }
208 }
209
210 return std::make_unique<BinaryExpression>(offset, std::move(left), op,
211 std::move(right), resultType);
John Stilese2aec432021-03-01 09:27:48 -0500212}
213
214bool BinaryExpression::CheckRef(const Expression& expr) {
215 switch (expr.kind()) {
216 case Expression::Kind::kFieldAccess:
217 return CheckRef(*expr.as<FieldAccess>().base());
218
219 case Expression::Kind::kIndex:
220 return CheckRef(*expr.as<IndexExpression>().base());
221
222 case Expression::Kind::kSwizzle:
223 return CheckRef(*expr.as<Swizzle>().base());
224
225 case Expression::Kind::kTernary: {
226 const TernaryExpression& t = expr.as<TernaryExpression>();
227 return CheckRef(*t.ifTrue()) && CheckRef(*t.ifFalse());
228 }
229 case Expression::Kind::kVariableReference: {
230 const VariableReference& ref = expr.as<VariableReference>();
231 return ref.refKind() == VariableRefKind::kWrite ||
232 ref.refKind() == VariableRefKind::kReadWrite;
233 }
234 default:
235 return false;
236 }
237}
238
239std::unique_ptr<Expression> BinaryExpression::clone() const {
240 return std::make_unique<BinaryExpression>(fOffset,
241 this->left()->clone(),
242 this->getOperator(),
243 this->right()->clone(),
244 &this->type());
245}
246
247String BinaryExpression::description() const {
248 return "(" + this->left()->description() +
249 " " + this->getOperator().operatorName() +
250 " " + this->right()->description() + ")";
251}
252
253} // namespace SkSL