Allow casting to lower precision types in runtime effects
Bug: skia:10679
Change-Id: If464c48b7c31d0d8440d1231d1983829d54ce598
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/315281
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/core/SkRuntimeEffect.cpp b/src/core/SkRuntimeEffect.cpp
index 9c1ae85..e23e3f9 100644
--- a/src/core/SkRuntimeEffect.cpp
+++ b/src/core/SkRuntimeEffect.cpp
@@ -119,6 +119,7 @@
SkSL::SharedCompiler compiler;
SkSL::Program::Settings settings;
settings.fInlineThreshold = compiler.getInlineThreshold();
+ settings.fAllowNarrowingConversions = true;
auto program = compiler->convertProgram(SkSL::Program::kPipelineStage_Kind,
SkSL::String(sksl.c_str(), sksl.size()),
settings);
@@ -306,6 +307,7 @@
SkSL::Program::Settings settings;
settings.fCaps = shaderCaps;
settings.fInlineThreshold = compiler.getInlineThreshold();
+ settings.fAllowNarrowingConversions = true;
auto program = compiler->convertProgram(SkSL::Program::kPipelineStage_Kind,
SkSL::String(fSkSL.c_str(), fSkSL.size()),
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index d2dec9b..7b38b5e 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1489,7 +1489,7 @@
if (expr->type() == *fContext.fInvalid_Type) {
return nullptr;
}
- if (expr->coercionCost(type) == INT_MAX) {
+ if (!expr->coercionCost(type).isPossible(fSettings->fAllowNarrowingConversions)) {
fErrors.error(expr->fOffset, "expected '" + type.displayName() + "', but found '" +
expr->type().displayName() + "'");
return nullptr;
@@ -1537,6 +1537,7 @@
* legal, false otherwise. If false, the values of the out parameters are undefined.
*/
static bool determine_binary_type(const Context& context,
+ bool allowNarrowing,
Token::Kind op,
const Type& left,
const Type& right,
@@ -1552,22 +1553,28 @@
*outLeftType = &left;
*outRightType = &left;
*outResultType = &left;
- return right.canCoerceTo(left);
+ return right.canCoerceTo(left, allowNarrowing);
case Token::Kind::TK_EQEQ: // fall through
- case Token::Kind::TK_NEQ:
- if (right.canCoerceTo(left)) {
- *outLeftType = &left;
- *outRightType = &left;
- *outResultType = context.fBool_Type.get();
- return true;
- }
- if (left.canCoerceTo(right)) {
- *outLeftType = &right;
- *outRightType = &right;
- *outResultType = context.fBool_Type.get();
- return true;
+ case Token::Kind::TK_NEQ: {
+ CoercionCost rightToLeft = right.coercionCost(left),
+ leftToRight = left.coercionCost(right);
+ if (rightToLeft < leftToRight) {
+ if (rightToLeft.isPossible(allowNarrowing)) {
+ *outLeftType = &left;
+ *outRightType = &left;
+ *outResultType = context.fBool_Type.get();
+ return true;
+ }
+ } else {
+ if (leftToRight.isPossible(allowNarrowing)) {
+ *outLeftType = &right;
+ *outRightType = &right;
+ *outResultType = context.fBool_Type.get();
+ return true;
+ }
}
return false;
+ }
case Token::Kind::TK_LT: // fall through
case Token::Kind::TK_GT: // fall through
case Token::Kind::TK_LTEQ: // fall through
@@ -1583,15 +1590,15 @@
*outLeftType = context.fBool_Type.get();
*outRightType = context.fBool_Type.get();
*outResultType = context.fBool_Type.get();
- return left.canCoerceTo(*context.fBool_Type) &&
- right.canCoerceTo(*context.fBool_Type);
+ return left.canCoerceTo(*context.fBool_Type, allowNarrowing) &&
+ right.canCoerceTo(*context.fBool_Type, allowNarrowing);
case Token::Kind::TK_STAREQ: // fall through
case Token::Kind::TK_STAR:
if (is_matrix_multiply(left, right)) {
// determine final component type
- if (determine_binary_type(context, Token::Kind::TK_STAR, left.componentType(),
- right.componentType(), outLeftType, outRightType,
- outResultType)) {
+ if (determine_binary_type(context, allowNarrowing, Token::Kind::TK_STAR,
+ left.componentType(), right.componentType(),
+ outLeftType, outRightType, outResultType)) {
*outLeftType = &(*outResultType)->toCompound(context, left.columns(),
left.rows());
*outRightType = &(*outResultType)->toCompound(context, right.columns(),
@@ -1657,8 +1664,8 @@
if (leftIsVectorOrMatrix && validMatrixOrVectorOp &&
right.typeKind() == Type::TypeKind::kScalar) {
- if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
- outRightType, outResultType)) {
+ if (determine_binary_type(context, allowNarrowing, op, left.componentType(), right,
+ outLeftType, outRightType, outResultType)) {
*outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
if (!isLogical) {
*outResultType =
@@ -1671,8 +1678,8 @@
if (!isAssignment && rightIsVectorOrMatrix && validMatrixOrVectorOp &&
left.typeKind() == Type::TypeKind::kScalar) {
- if (determine_binary_type(context, op, left, right.componentType(), outLeftType,
- outRightType, outResultType)) {
+ if (determine_binary_type(context, allowNarrowing, op, left, right.componentType(),
+ outLeftType, outRightType, outResultType)) {
*outRightType = &(*outRightType)->toCompound(context, right.columns(), right.rows());
if (!isLogical) {
*outResultType =
@@ -1683,18 +1690,19 @@
return false;
}
- int rightToLeftCost = right.coercionCost(left);
- int leftToRightCost = isAssignment ? INT_MAX : left.coercionCost(right);
+ CoercionCost rightToLeftCost = right.coercionCost(left);
+ CoercionCost leftToRightCost = isAssignment ? CoercionCost::Impossible()
+ : left.coercionCost(right);
if ((left.typeKind() == Type::TypeKind::kScalar &&
right.typeKind() == Type::TypeKind::kScalar) ||
(leftIsVectorOrMatrix && validMatrixOrVectorOp)) {
- if (rightToLeftCost < leftToRightCost) {
- // Right-to-Left conversion is cheaper (and therefore possible)
+ if (rightToLeftCost.isPossible(allowNarrowing) && rightToLeftCost < leftToRightCost) {
+ // Right-to-Left conversion is possible and cheaper
*outLeftType = &left;
*outRightType = &left;
*outResultType = &left;
- } else if (leftToRightCost != INT_MAX) {
+ } else if (leftToRightCost.isPossible(allowNarrowing)) {
// Left-to-Right conversion is possible (and at least as cheap as Right-to-Left)
*outLeftType = &right;
*outRightType = &right;
@@ -1942,8 +1950,8 @@
} else {
rawRightType = &right->type();
}
- if (!determine_binary_type(fContext, op, *rawLeftType, *rawRightType,
- &leftType, &rightType, &resultType)) {
+ if (!determine_binary_type(fContext, fSettings->fAllowNarrowingConversions, op,
+ *rawLeftType, *rawRightType, &leftType, &rightType, &resultType)) {
fErrors.error(expression.fOffset, String("type mismatch: '") +
Compiler::OperatorName(expression.getToken().fKind) +
"' cannot operate on '" + left->type().displayName() +
@@ -1994,8 +2002,10 @@
const Type* trueType;
const Type* falseType;
const Type* resultType;
- if (!determine_binary_type(fContext, Token::Kind::TK_EQEQ, ifTrue->type(), ifFalse->type(),
- &trueType, &falseType, &resultType) || trueType != falseType) {
+ if (!determine_binary_type(fContext, fSettings->fAllowNarrowingConversions,
+ Token::Kind::TK_EQEQ, ifTrue->type(), ifFalse->type(),
+ &trueType, &falseType, &resultType) ||
+ trueType != falseType) {
fErrors.error(node.fOffset, "ternary operator result mismatch: '" +
ifTrue->type().displayName() + "', '" +
ifFalse->type().displayName() + "'");
@@ -2109,27 +2119,22 @@
/**
* Determines the cost of coercing the arguments of a function to the required types. Cost has no
- * particular meaning other than "lower costs are preferred". Returns INT_MAX if the call is not
- * valid.
+ * particular meaning other than "lower costs are preferred". Returns CoercionCost::Impossible() if
+ * the call is not valid.
*/
-int IRGenerator::callCost(const FunctionDeclaration& function,
- const std::vector<std::unique_ptr<Expression>>& arguments) {
+CoercionCost IRGenerator::callCost(const FunctionDeclaration& function,
+ const std::vector<std::unique_ptr<Expression>>& arguments) {
if (function.fParameters.size() != arguments.size()) {
- return INT_MAX;
+ return CoercionCost::Impossible();
}
- int total = 0;
std::vector<const Type*> types;
const Type* ignored;
if (!function.determineFinalTypes(arguments, &types, &ignored)) {
- return INT_MAX;
+ return CoercionCost::Impossible();
}
+ CoercionCost total = CoercionCost::Free();
for (size_t i = 0; i < arguments.size(); i++) {
- int cost = arguments[i]->coercionCost(*types[i]);
- if (cost != INT_MAX) {
- total += cost;
- } else {
- return INT_MAX;
- }
+ total = total + arguments[i]->coercionCost(*types[i]);
}
return total;
}
@@ -2169,11 +2174,11 @@
}
case Expression::Kind::kFunctionReference: {
const FunctionReference& ref = functionValue->as<FunctionReference>();
- int bestCost = INT_MAX;
+ CoercionCost bestCost = CoercionCost::Impossible();
const FunctionDeclaration* best = nullptr;
if (ref.fFunctions.size() > 1) {
for (const auto& f : ref.fFunctions) {
- int cost = this->callCost(*f, arguments);
+ CoercionCost cost = this->callCost(*f, arguments);
if (cost < bestCost) {
bestCost = cost;
best = f;
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index a3011a5..ee3fa5b 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -103,11 +103,11 @@
std::unique_ptr<Expression> call(int offset,
const FunctionDeclaration& function,
std::vector<std::unique_ptr<Expression>> arguments);
- int callCost(const FunctionDeclaration& function,
- const std::vector<std::unique_ptr<Expression>>& arguments);
+ CoercionCost callCost(const FunctionDeclaration& function,
+ const std::vector<std::unique_ptr<Expression>>& arguments);
std::unique_ptr<Expression> call(int offset, std::unique_ptr<Expression> function,
std::vector<std::unique_ptr<Expression>> arguments);
- int coercionCost(const Expression& expr, const Type& type);
+ CoercionCost coercionCost(const Expression& expr, const Type& type);
std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
std::unique_ptr<Block> convertBlock(const ASTNode& block);
std::unique_ptr<Statement> convertBreak(const ASTNode& b);
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 4e6c17c..7beb341 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -154,7 +154,7 @@
return nullptr;
}
- virtual int coercionCost(const Type& target) const {
+ virtual CoercionCost coercionCost(const Type& target) const {
return this->type().coercionCost(target);
}
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 0bbdc14..1a0ca30 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -39,9 +39,9 @@
return true;
}
- int coercionCost(const Type& target) const override {
+ CoercionCost coercionCost(const Type& target) const override {
if (target.isFloat()) {
- return 0;
+ return CoercionCost::Free();
}
return INHERITED::coercionCost(target);
}
diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h
index 2635d9a..fb7afc7 100644
--- a/src/sksl/ir/SkSLFunctionDeclaration.h
+++ b/src/sksl/ir/SkSLFunctionDeclaration.h
@@ -74,6 +74,10 @@
* does not guarantee that the function can be successfully called with those arguments, merely
* indicates that an attempt should be made. If false is returned, the state of
* outParameterTypes and outReturnType are undefined.
+ *
+ * This always assumes narrowing conversions are *allowed*. The calling code needs to verify
+ * that each argument can actually be coerced to the final parameter type, respecting the
+ * narrowing-conversions flag. This is handled in callCost(), or in convertCall() (via coerce).
*/
bool determineFinalTypes(const std::vector<std::unique_ptr<Expression>>& arguments,
std::vector<const Type*>* outParameterTypes,
@@ -86,7 +90,7 @@
std::vector<const Type*> types = parameterType.coercibleTypes();
if (genericIndex == -1) {
for (size_t j = 0; j < types.size(); j++) {
- if (arguments[i]->type().canCoerceTo(*types[j])) {
+ if (arguments[i]->type().canCoerceTo(*types[j], /*allowNarrowing=*/true)) {
genericIndex = j;
break;
}
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 1ec16d5..27e3e0c 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -45,10 +45,10 @@
return fValue == other.as<IntLiteral>().fValue;
}
- int coercionCost(const Type& target) const override {
+ CoercionCost coercionCost(const Type& target) const override {
if (target.isSigned() || target.isUnsigned() || target.isFloat() ||
target.typeKind() == Type::TypeKind::kEnum) {
- return 0;
+ return CoercionCost::Free();
}
return INHERITED::coercionCost(target);
}
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index b1520ef..6ae2d73 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -122,6 +122,9 @@
int fInlineThreshold = 50;
// true to enable optimization passes
bool fOptimize = true;
+ // If true, implicit conversions to lower precision numeric types are allowed
+ // (eg, float to half)
+ bool fAllowNarrowingConversions = false;
};
struct Inputs {
diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp
index bd2c497..81cfdbf 100644
--- a/src/sksl/ir/SkSLType.cpp
+++ b/src/sksl/ir/SkSLType.cpp
@@ -10,41 +10,45 @@
namespace SkSL {
-int Type::coercionCost(const Type& other) const {
+CoercionCost Type::coercionCost(const Type& other) const {
if (*this == other) {
- return 0;
+ return CoercionCost::Free();
}
if (this->typeKind() == TypeKind::kNullable && other.typeKind() != TypeKind::kNullable) {
- int result = this->componentType().coercionCost(other);
- if (result != INT_MAX) {
- ++result;
+ CoercionCost result = this->componentType().coercionCost(other);
+ if (result.isPossible(/*allowNarrowing=*/true)) {
+ ++result.fNormalCost;
}
return result;
}
if (this->fName == "null" && other.typeKind() == TypeKind::kNullable) {
- return 0;
+ return CoercionCost::Free();
}
if (this->typeKind() == TypeKind::kVector && other.typeKind() == TypeKind::kVector) {
if (this->columns() == other.columns()) {
return this->componentType().coercionCost(other.componentType());
}
- return INT_MAX;
+ return CoercionCost::Impossible();
}
if (this->typeKind() == TypeKind::kMatrix) {
if (this->columns() == other.columns() && this->rows() == other.rows()) {
return this->componentType().coercionCost(other.componentType());
}
- return INT_MAX;
+ return CoercionCost::Impossible();
}
- if (this->isNumber() && other.isNumber() && other.priority() > this->priority()) {
- return other.priority() - this->priority();
+ if (this->isNumber() && other.isNumber()) {
+ if (other.priority() >= this->priority()) {
+ return CoercionCost::Normal(other.priority() - this->priority());
+ } else {
+ return CoercionCost::Narrowing(this->priority() - other.priority());
+ }
}
for (size_t i = 0; i < fCoercibleTypes.size(); i++) {
if (*fCoercibleTypes[i] == other) {
- return (int) i + 1;
+ return CoercionCost::Normal((int) i + 1);
}
}
- return INT_MAX;
+ return CoercionCost::Impossible();
}
const Type& Type::toCompound(const Context& context, int columns, int rows) const {
diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h
index 88faf78..c5f99f7 100644
--- a/src/sksl/ir/SkSLType.h
+++ b/src/sksl/ir/SkSLType.h
@@ -13,6 +13,7 @@
#include "src/sksl/ir/SkSLModifiers.h"
#include "src/sksl/ir/SkSLSymbol.h"
#include "src/sksl/spirv.h"
+#include <algorithm>
#include <climits>
#include <vector>
#include <memory>
@@ -21,6 +22,34 @@
class Context;
+struct CoercionCost {
+ static CoercionCost Free() { return { 0, 0, false }; }
+ static CoercionCost Normal(int cost) { return { cost, 0, false }; }
+ static CoercionCost Narrowing(int cost) { return { 0, cost, false }; }
+ static CoercionCost Impossible() { return { 0, 0, true }; }
+
+ bool isPossible(bool allowNarrowing) const {
+ return !fImpossible && (fNarrowingCost == 0 || allowNarrowing);
+ }
+
+ // Addition of two costs. Saturates at Impossible().
+ CoercionCost operator+(CoercionCost rhs) const {
+ if (fImpossible || rhs.fImpossible) {
+ return Impossible();
+ }
+ return { fNormalCost + rhs.fNormalCost, fNarrowingCost + rhs.fNarrowingCost, false };
+ }
+
+ bool operator<(CoercionCost rhs) const {
+ return std::tie( fImpossible, fNarrowingCost, fNormalCost) <
+ std::tie(rhs.fImpossible, rhs.fNarrowingCost, rhs.fNormalCost);
+ }
+
+ int fNormalCost;
+ int fNarrowingCost;
+ bool fImpossible;
+};
+
/**
* Represents a type, such as int or float4.
*/
@@ -310,8 +339,8 @@
* Returns true if an instance of this type can be freely coerced (implicitly converted) to
* another type.
*/
- bool canCoerceTo(const Type& other) const {
- return coercionCost(other) != INT_MAX;
+ bool canCoerceTo(const Type& other, bool allowNarrowing) const {
+ return this->coercionCost(other).isPossible(allowNarrowing);
}
/**
@@ -319,7 +348,7 @@
* is a number with no particular meaning other than that lower costs are preferable to higher
* costs. Returns INT_MAX if the coercion is not possible.
*/
- int coercionCost(const Type& other) const;
+ CoercionCost coercionCost(const Type& other) const;
/**
* For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component
diff --git a/tests/SkRuntimeEffectTest.cpp b/tests/SkRuntimeEffectTest.cpp
index 2ddf507..0f91c0b 100644
--- a/tests/SkRuntimeEffectTest.cpp
+++ b/tests/SkRuntimeEffectTest.cpp
@@ -218,6 +218,10 @@
effect.test(0xFF000000, 0xFF00007F, 0xFF007F00, 0xFF007F7F,
[](SkCanvas* canvas, SkPaint*) { canvas->rotate(45.0f); });
+ // Runtime effects should use relaxed precision rules by default
+ effect.build("", "return float4(p - 0.5, 0, 1);");
+ effect.test(0xFF000000, 0xFF0000FF, 0xFF00FF00, 0xFF00FFFF);
+
//
// Sampling children
//