Allow casting to lower precision types in runtime effects

Bug: skia:10679
Change-Id: If464c48b7c31d0d8440d1231d1983829d54ce598
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/315281
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/core/SkRuntimeEffect.cpp b/src/core/SkRuntimeEffect.cpp
index 9c1ae85..e23e3f9 100644
--- a/src/core/SkRuntimeEffect.cpp
+++ b/src/core/SkRuntimeEffect.cpp
@@ -119,6 +119,7 @@
     SkSL::SharedCompiler compiler;
     SkSL::Program::Settings settings;
     settings.fInlineThreshold = compiler.getInlineThreshold();
+    settings.fAllowNarrowingConversions = true;
     auto program = compiler->convertProgram(SkSL::Program::kPipelineStage_Kind,
                                             SkSL::String(sksl.c_str(), sksl.size()),
                                             settings);
@@ -306,6 +307,7 @@
     SkSL::Program::Settings settings;
     settings.fCaps = shaderCaps;
     settings.fInlineThreshold = compiler.getInlineThreshold();
+    settings.fAllowNarrowingConversions = true;
 
     auto program = compiler->convertProgram(SkSL::Program::kPipelineStage_Kind,
                                             SkSL::String(fSkSL.c_str(), fSkSL.size()),
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index d2dec9b..7b38b5e 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1489,7 +1489,7 @@
     if (expr->type() == *fContext.fInvalid_Type) {
         return nullptr;
     }
-    if (expr->coercionCost(type) == INT_MAX) {
+    if (!expr->coercionCost(type).isPossible(fSettings->fAllowNarrowingConversions)) {
         fErrors.error(expr->fOffset, "expected '" + type.displayName() + "', but found '" +
                                      expr->type().displayName() + "'");
         return nullptr;
@@ -1537,6 +1537,7 @@
  * legal, false otherwise. If false, the values of the out parameters are undefined.
  */
 static bool determine_binary_type(const Context& context,
+                                  bool allowNarrowing,
                                   Token::Kind op,
                                   const Type& left,
                                   const Type& right,
@@ -1552,22 +1553,28 @@
             *outLeftType = &left;
             *outRightType = &left;
             *outResultType = &left;
-            return right.canCoerceTo(left);
+            return right.canCoerceTo(left, allowNarrowing);
         case Token::Kind::TK_EQEQ: // fall through
-        case Token::Kind::TK_NEQ:
-            if (right.canCoerceTo(left)) {
-                *outLeftType = &left;
-                *outRightType = &left;
-                *outResultType = context.fBool_Type.get();
-                return true;
-            }
-            if (left.canCoerceTo(right)) {
-                *outLeftType = &right;
-                *outRightType = &right;
-                *outResultType = context.fBool_Type.get();
-                return true;
+        case Token::Kind::TK_NEQ: {
+            CoercionCost rightToLeft = right.coercionCost(left),
+                         leftToRight = left.coercionCost(right);
+            if (rightToLeft < leftToRight) {
+                if (rightToLeft.isPossible(allowNarrowing)) {
+                    *outLeftType = &left;
+                    *outRightType = &left;
+                    *outResultType = context.fBool_Type.get();
+                    return true;
+                }
+            } else {
+                if (leftToRight.isPossible(allowNarrowing)) {
+                    *outLeftType = &right;
+                    *outRightType = &right;
+                    *outResultType = context.fBool_Type.get();
+                    return true;
+                }
             }
             return false;
+        }
         case Token::Kind::TK_LT:   // fall through
         case Token::Kind::TK_GT:   // fall through
         case Token::Kind::TK_LTEQ: // fall through
@@ -1583,15 +1590,15 @@
             *outLeftType = context.fBool_Type.get();
             *outRightType = context.fBool_Type.get();
             *outResultType = context.fBool_Type.get();
-            return left.canCoerceTo(*context.fBool_Type) &&
-                   right.canCoerceTo(*context.fBool_Type);
+            return left.canCoerceTo(*context.fBool_Type, allowNarrowing) &&
+                   right.canCoerceTo(*context.fBool_Type, allowNarrowing);
         case Token::Kind::TK_STAREQ: // fall through
         case Token::Kind::TK_STAR:
             if (is_matrix_multiply(left, right)) {
                 // determine final component type
-                if (determine_binary_type(context, Token::Kind::TK_STAR, left.componentType(),
-                                          right.componentType(), outLeftType, outRightType,
-                                          outResultType)) {
+                if (determine_binary_type(context, allowNarrowing, Token::Kind::TK_STAR,
+                                          left.componentType(), right.componentType(),
+                                          outLeftType, outRightType, outResultType)) {
                     *outLeftType = &(*outResultType)->toCompound(context, left.columns(),
                                                                  left.rows());
                     *outRightType = &(*outResultType)->toCompound(context, right.columns(),
@@ -1657,8 +1664,8 @@
 
     if (leftIsVectorOrMatrix && validMatrixOrVectorOp &&
         right.typeKind() == Type::TypeKind::kScalar) {
-        if (determine_binary_type(context, op, left.componentType(), right, outLeftType,
-                                  outRightType, outResultType)) {
+        if (determine_binary_type(context, allowNarrowing, op, left.componentType(), right,
+                                  outLeftType, outRightType, outResultType)) {
             *outLeftType = &(*outLeftType)->toCompound(context, left.columns(), left.rows());
             if (!isLogical) {
                 *outResultType =
@@ -1671,8 +1678,8 @@
 
     if (!isAssignment && rightIsVectorOrMatrix && validMatrixOrVectorOp &&
         left.typeKind() == Type::TypeKind::kScalar) {
-        if (determine_binary_type(context, op, left, right.componentType(), outLeftType,
-                                  outRightType, outResultType)) {
+        if (determine_binary_type(context, allowNarrowing, op, left, right.componentType(),
+                                  outLeftType, outRightType, outResultType)) {
             *outRightType = &(*outRightType)->toCompound(context, right.columns(), right.rows());
             if (!isLogical) {
                 *outResultType =
@@ -1683,18 +1690,19 @@
         return false;
     }
 
-    int rightToLeftCost = right.coercionCost(left);
-    int leftToRightCost = isAssignment ? INT_MAX : left.coercionCost(right);
+    CoercionCost rightToLeftCost = right.coercionCost(left);
+    CoercionCost leftToRightCost = isAssignment ? CoercionCost::Impossible()
+                                                : left.coercionCost(right);
 
     if ((left.typeKind() == Type::TypeKind::kScalar &&
          right.typeKind() == Type::TypeKind::kScalar) ||
         (leftIsVectorOrMatrix && validMatrixOrVectorOp)) {
-        if (rightToLeftCost < leftToRightCost) {
-            // Right-to-Left conversion is cheaper (and therefore possible)
+        if (rightToLeftCost.isPossible(allowNarrowing) && rightToLeftCost < leftToRightCost) {
+            // Right-to-Left conversion is possible and cheaper
             *outLeftType = &left;
             *outRightType = &left;
             *outResultType = &left;
-        } else if (leftToRightCost != INT_MAX) {
+        } else if (leftToRightCost.isPossible(allowNarrowing)) {
             // Left-to-Right conversion is possible (and at least as cheap as Right-to-Left)
             *outLeftType = &right;
             *outRightType = &right;
@@ -1942,8 +1950,8 @@
     } else {
         rawRightType = &right->type();
     }
-    if (!determine_binary_type(fContext, op, *rawLeftType, *rawRightType,
-                               &leftType, &rightType, &resultType)) {
+    if (!determine_binary_type(fContext, fSettings->fAllowNarrowingConversions, op,
+                               *rawLeftType, *rawRightType, &leftType, &rightType, &resultType)) {
         fErrors.error(expression.fOffset, String("type mismatch: '") +
                                           Compiler::OperatorName(expression.getToken().fKind) +
                                           "' cannot operate on '" + left->type().displayName() +
@@ -1994,8 +2002,10 @@
     const Type* trueType;
     const Type* falseType;
     const Type* resultType;
-    if (!determine_binary_type(fContext, Token::Kind::TK_EQEQ, ifTrue->type(), ifFalse->type(),
-                               &trueType, &falseType, &resultType) || trueType != falseType) {
+    if (!determine_binary_type(fContext, fSettings->fAllowNarrowingConversions,
+                               Token::Kind::TK_EQEQ, ifTrue->type(), ifFalse->type(),
+                               &trueType, &falseType, &resultType) ||
+        trueType != falseType) {
         fErrors.error(node.fOffset, "ternary operator result mismatch: '" +
                                     ifTrue->type().displayName() + "', '" +
                                     ifFalse->type().displayName() + "'");
@@ -2109,27 +2119,22 @@
 
 /**
  * Determines the cost of coercing the arguments of a function to the required types. Cost has no
- * particular meaning other than "lower costs are preferred". Returns INT_MAX if the call is not
- * valid.
+ * particular meaning other than "lower costs are preferred". Returns CoercionCost::Impossible() if
+ * the call is not valid.
  */
-int IRGenerator::callCost(const FunctionDeclaration& function,
-             const std::vector<std::unique_ptr<Expression>>& arguments) {
+CoercionCost IRGenerator::callCost(const FunctionDeclaration& function,
+                                   const std::vector<std::unique_ptr<Expression>>& arguments) {
     if (function.fParameters.size() != arguments.size()) {
-        return INT_MAX;
+        return CoercionCost::Impossible();
     }
-    int total = 0;
     std::vector<const Type*> types;
     const Type* ignored;
     if (!function.determineFinalTypes(arguments, &types, &ignored)) {
-        return INT_MAX;
+        return CoercionCost::Impossible();
     }
+    CoercionCost total = CoercionCost::Free();
     for (size_t i = 0; i < arguments.size(); i++) {
-        int cost = arguments[i]->coercionCost(*types[i]);
-        if (cost != INT_MAX) {
-            total += cost;
-        } else {
-            return INT_MAX;
-        }
+        total = total + arguments[i]->coercionCost(*types[i]);
     }
     return total;
 }
@@ -2169,11 +2174,11 @@
         }
         case Expression::Kind::kFunctionReference: {
             const FunctionReference& ref = functionValue->as<FunctionReference>();
-            int bestCost = INT_MAX;
+            CoercionCost bestCost = CoercionCost::Impossible();
             const FunctionDeclaration* best = nullptr;
             if (ref.fFunctions.size() > 1) {
                 for (const auto& f : ref.fFunctions) {
-                    int cost = this->callCost(*f, arguments);
+                    CoercionCost cost = this->callCost(*f, arguments);
                     if (cost < bestCost) {
                         bestCost = cost;
                         best = f;
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index a3011a5..ee3fa5b 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -103,11 +103,11 @@
     std::unique_ptr<Expression> call(int offset,
                                      const FunctionDeclaration& function,
                                      std::vector<std::unique_ptr<Expression>> arguments);
-    int callCost(const FunctionDeclaration& function,
-                 const std::vector<std::unique_ptr<Expression>>& arguments);
+    CoercionCost callCost(const FunctionDeclaration& function,
+                          const std::vector<std::unique_ptr<Expression>>& arguments);
     std::unique_ptr<Expression> call(int offset, std::unique_ptr<Expression> function,
                                      std::vector<std::unique_ptr<Expression>> arguments);
-    int coercionCost(const Expression& expr, const Type& type);
+    CoercionCost coercionCost(const Expression& expr, const Type& type);
     std::unique_ptr<Expression> coerce(std::unique_ptr<Expression> expr, const Type& type);
     std::unique_ptr<Block> convertBlock(const ASTNode& block);
     std::unique_ptr<Statement> convertBreak(const ASTNode& b);
diff --git a/src/sksl/ir/SkSLExpression.h b/src/sksl/ir/SkSLExpression.h
index 4e6c17c..7beb341 100644
--- a/src/sksl/ir/SkSLExpression.h
+++ b/src/sksl/ir/SkSLExpression.h
@@ -154,7 +154,7 @@
         return nullptr;
     }
 
-    virtual int coercionCost(const Type& target) const {
+    virtual CoercionCost coercionCost(const Type& target) const {
         return this->type().coercionCost(target);
     }
 
diff --git a/src/sksl/ir/SkSLFloatLiteral.h b/src/sksl/ir/SkSLFloatLiteral.h
index 0bbdc14..1a0ca30 100644
--- a/src/sksl/ir/SkSLFloatLiteral.h
+++ b/src/sksl/ir/SkSLFloatLiteral.h
@@ -39,9 +39,9 @@
         return true;
     }
 
-    int coercionCost(const Type& target) const override {
+    CoercionCost coercionCost(const Type& target) const override {
         if (target.isFloat()) {
-            return 0;
+            return CoercionCost::Free();
         }
         return INHERITED::coercionCost(target);
     }
diff --git a/src/sksl/ir/SkSLFunctionDeclaration.h b/src/sksl/ir/SkSLFunctionDeclaration.h
index 2635d9a..fb7afc7 100644
--- a/src/sksl/ir/SkSLFunctionDeclaration.h
+++ b/src/sksl/ir/SkSLFunctionDeclaration.h
@@ -74,6 +74,10 @@
      * does not guarantee that the function can be successfully called with those arguments, merely
      * indicates that an attempt should be made. If false is returned, the state of
      * outParameterTypes and outReturnType are undefined.
+     *
+     * This always assumes narrowing conversions are *allowed*. The calling code needs to verify
+     * that each argument can actually be coerced to the final parameter type, respecting the
+     * narrowing-conversions flag. This is handled in callCost(), or in convertCall() (via coerce).
      */
     bool determineFinalTypes(const std::vector<std::unique_ptr<Expression>>& arguments,
                              std::vector<const Type*>* outParameterTypes,
@@ -86,7 +90,7 @@
                 std::vector<const Type*> types = parameterType.coercibleTypes();
                 if (genericIndex == -1) {
                     for (size_t j = 0; j < types.size(); j++) {
-                        if (arguments[i]->type().canCoerceTo(*types[j])) {
+                        if (arguments[i]->type().canCoerceTo(*types[j], /*allowNarrowing=*/true)) {
                             genericIndex = j;
                             break;
                         }
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index 1ec16d5..27e3e0c 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -45,10 +45,10 @@
         return fValue == other.as<IntLiteral>().fValue;
     }
 
-    int coercionCost(const Type& target) const override {
+    CoercionCost coercionCost(const Type& target) const override {
         if (target.isSigned() || target.isUnsigned() || target.isFloat() ||
             target.typeKind() == Type::TypeKind::kEnum) {
-            return 0;
+            return CoercionCost::Free();
         }
         return INHERITED::coercionCost(target);
     }
diff --git a/src/sksl/ir/SkSLProgram.h b/src/sksl/ir/SkSLProgram.h
index b1520ef..6ae2d73 100644
--- a/src/sksl/ir/SkSLProgram.h
+++ b/src/sksl/ir/SkSLProgram.h
@@ -122,6 +122,9 @@
         int fInlineThreshold = 50;
         // true to enable optimization passes
         bool fOptimize = true;
+        // If true, implicit conversions to lower precision numeric types are allowed
+        // (eg, float to half)
+        bool fAllowNarrowingConversions = false;
     };
 
     struct Inputs {
diff --git a/src/sksl/ir/SkSLType.cpp b/src/sksl/ir/SkSLType.cpp
index bd2c497..81cfdbf 100644
--- a/src/sksl/ir/SkSLType.cpp
+++ b/src/sksl/ir/SkSLType.cpp
@@ -10,41 +10,45 @@
 
 namespace SkSL {
 
-int Type::coercionCost(const Type& other) const {
+CoercionCost Type::coercionCost(const Type& other) const {
     if (*this == other) {
-        return 0;
+        return CoercionCost::Free();
     }
     if (this->typeKind() == TypeKind::kNullable && other.typeKind() != TypeKind::kNullable) {
-        int result = this->componentType().coercionCost(other);
-        if (result != INT_MAX) {
-            ++result;
+        CoercionCost result = this->componentType().coercionCost(other);
+        if (result.isPossible(/*allowNarrowing=*/true)) {
+            ++result.fNormalCost;
         }
         return result;
     }
     if (this->fName == "null" && other.typeKind() == TypeKind::kNullable) {
-        return 0;
+        return CoercionCost::Free();
     }
     if (this->typeKind() == TypeKind::kVector && other.typeKind() == TypeKind::kVector) {
         if (this->columns() == other.columns()) {
             return this->componentType().coercionCost(other.componentType());
         }
-        return INT_MAX;
+        return CoercionCost::Impossible();
     }
     if (this->typeKind() == TypeKind::kMatrix) {
         if (this->columns() == other.columns() && this->rows() == other.rows()) {
             return this->componentType().coercionCost(other.componentType());
         }
-        return INT_MAX;
+        return CoercionCost::Impossible();
     }
-    if (this->isNumber() && other.isNumber() && other.priority() > this->priority()) {
-        return other.priority() - this->priority();
+    if (this->isNumber() && other.isNumber()) {
+        if (other.priority() >= this->priority()) {
+            return CoercionCost::Normal(other.priority() - this->priority());
+        } else {
+            return CoercionCost::Narrowing(this->priority() - other.priority());
+        }
     }
     for (size_t i = 0; i < fCoercibleTypes.size(); i++) {
         if (*fCoercibleTypes[i] == other) {
-            return (int) i + 1;
+            return CoercionCost::Normal((int) i + 1);
         }
     }
-    return INT_MAX;
+    return CoercionCost::Impossible();
 }
 
 const Type& Type::toCompound(const Context& context, int columns, int rows) const {
diff --git a/src/sksl/ir/SkSLType.h b/src/sksl/ir/SkSLType.h
index 88faf78..c5f99f7 100644
--- a/src/sksl/ir/SkSLType.h
+++ b/src/sksl/ir/SkSLType.h
@@ -13,6 +13,7 @@
 #include "src/sksl/ir/SkSLModifiers.h"
 #include "src/sksl/ir/SkSLSymbol.h"
 #include "src/sksl/spirv.h"
+#include <algorithm>
 #include <climits>
 #include <vector>
 #include <memory>
@@ -21,6 +22,34 @@
 
 class Context;
 
+struct CoercionCost {
+    static CoercionCost Free()              { return {    0,    0, false }; }
+    static CoercionCost Normal(int cost)    { return { cost,    0, false }; }
+    static CoercionCost Narrowing(int cost) { return {    0, cost, false }; }
+    static CoercionCost Impossible()        { return {    0,    0,  true }; }
+
+    bool isPossible(bool allowNarrowing) const {
+        return !fImpossible && (fNarrowingCost == 0 || allowNarrowing);
+    }
+
+    // Addition of two costs. Saturates at Impossible().
+    CoercionCost operator+(CoercionCost rhs) const {
+        if (fImpossible || rhs.fImpossible) {
+            return Impossible();
+        }
+        return { fNormalCost + rhs.fNormalCost, fNarrowingCost + rhs.fNarrowingCost, false };
+    }
+
+    bool operator<(CoercionCost rhs) const {
+        return std::tie(    fImpossible,     fNarrowingCost,     fNormalCost) <
+               std::tie(rhs.fImpossible, rhs.fNarrowingCost, rhs.fNormalCost);
+    }
+
+    int  fNormalCost;
+    int  fNarrowingCost;
+    bool fImpossible;
+};
+
 /**
  * Represents a type, such as int or float4.
  */
@@ -310,8 +339,8 @@
      * Returns true if an instance of this type can be freely coerced (implicitly converted) to
      * another type.
      */
-    bool canCoerceTo(const Type& other) const {
-        return coercionCost(other) != INT_MAX;
+    bool canCoerceTo(const Type& other, bool allowNarrowing) const {
+        return this->coercionCost(other).isPossible(allowNarrowing);
     }
 
     /**
@@ -319,7 +348,7 @@
      * is a number with no particular meaning other than that lower costs are preferable to higher
      * costs. Returns INT_MAX if the coercion is not possible.
      */
-    int coercionCost(const Type& other) const;
+    CoercionCost coercionCost(const Type& other) const;
 
     /**
      * For matrices and vectors, returns the type of individual cells (e.g. mat2 has a component
diff --git a/tests/SkRuntimeEffectTest.cpp b/tests/SkRuntimeEffectTest.cpp
index 2ddf507..0f91c0b 100644
--- a/tests/SkRuntimeEffectTest.cpp
+++ b/tests/SkRuntimeEffectTest.cpp
@@ -218,6 +218,10 @@
     effect.test(0xFF000000, 0xFF00007F, 0xFF007F00, 0xFF007F7F,
                 [](SkCanvas* canvas, SkPaint*) { canvas->rotate(45.0f); });
 
+    // Runtime effects should use relaxed precision rules by default
+    effect.build("", "return float4(p - 0.5, 0, 1);");
+    effect.test(0xFF000000, 0xFF0000FF, 0xFF00FF00, 0xFF00FFFF);
+
     //
     // Sampling children
     //