Add GetConstantValueForVariable to constant-folder.

This can be used in places that simplify an expression at IR-generation
time. For example, `const bool x = true; @if (x) {...}` could be
evaluated at compile time by calling GetConstantValueForVariable before
checking for a BoolLiteral and flattening the if. It is also used at the
top of ConstantFolder::Simplify to reduce either side of the binary
expression being simplified, in case that side is a const variable.

We don't use a lot of const variables in our tests, but this does
improve the code generation in ConstVariableComparison.sksl
when constant-propagation is disabled: http://screen/gnWPhQG8Jc5cER9

Change-Id: I26309769cd16a6833b74b11a115b87c3dc312514
Bug: skia:11343
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/378017
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Brian Osman <brianosman@google.com>
diff --git a/src/sksl/SkSLConstantFolder.cpp b/src/sksl/SkSLConstantFolder.cpp
index 355b53c..57a5194 100644
--- a/src/sksl/SkSLConstantFolder.cpp
+++ b/src/sksl/SkSLConstantFolder.cpp
@@ -123,33 +123,21 @@
 }
 
 bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
-    switch (value.kind()) {
-        case Expression::Kind::kIntLiteral:
-            *out = value.as<IntLiteral>().value();
-            return true;
-        case Expression::Kind::kVariableReference: {
-            const Variable& var = *value.as<VariableReference>().variable();
-            return (var.modifiers().fFlags & Modifiers::kConst_Flag) &&
-                   var.initialValue() && GetConstantInt(*var.initialValue(), out);
-        }
-        default:
-            return false;
+    const Expression* expr = GetConstantValueForVariable(value);  // Sees through const vars.
+    if (!expr->is<IntLiteral>()) {  // Only an int literal yields a value.
+        return false;
     }
+    *out = expr->as<IntLiteral>().value();
+    return true;
 }
 
 bool ConstantFolder::GetConstantFloat(const Expression& value, SKSL_FLOAT* out) {
-    switch (value.kind()) {
-        case Expression::Kind::kFloatLiteral:
-            *out = value.as<FloatLiteral>().value();
-            return true;
-        case Expression::Kind::kVariableReference: {
-            const Variable& var = *value.as<VariableReference>().variable();
-            return (var.modifiers().fFlags & Modifiers::kConst_Flag) &&
-                   var.initialValue() && GetConstantFloat(*var.initialValue(), out);
-        }
-        default:
-            return false;
+    const Expression* expr = GetConstantValueForVariable(value);  // Sees through const vars.
+    if (!expr->is<FloatLiteral>()) {  // Only a float literal yields a value.
+        return false;
     }
+    *out = expr->as<FloatLiteral>().value();
+    return true;
 }
 
 static bool contains_constant_zero(const Expression& expr) {
@@ -161,10 +149,9 @@
         }
         return false;
     }
-    SKSL_INT intValue;
-    SKSL_FLOAT floatValue;
-    return (ConstantFolder::GetConstantInt(expr, &intValue) && intValue == 0) ||
-           (ConstantFolder::GetConstantFloat(expr, &floatValue) && floatValue == 0.0f);
+    const Expression* value = ConstantFolder::GetConstantValueForVariable(expr);
+    return (value->is<IntLiteral>()   && value->as<IntLiteral>().value()   == 0) ||
+           (value->is<FloatLiteral>() && value->as<FloatLiteral>().value() == 0.0);
 }
 
 bool ConstantFolder::ErrorOnDivideByZero(const Context& context, int offset, Operator op,
@@ -184,28 +171,58 @@
     }
 }
 
+// Follows a chain of read-only references to `const` variables and returns the compile-time
+// constant at the end of the chain. Returns `inExpr` unchanged if no such constant is found.
+const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
+    for (const Expression* expr = &inExpr; expr->is<VariableReference>();) {
+        const VariableReference& varRef = expr->as<VariableReference>();
+        if (varRef.refKind() != VariableRefKind::kRead) {
+            break;
+        }
+        const Variable& var = *varRef.variable();
+        if (!(var.modifiers().fFlags & Modifiers::kConst_Flag)) {
+            break;
+        }
+        expr = var.initialValue();
+        if (!expr) {
+            // A const variable may lack an initial value (e.g. a `const` function parameter);
+            // stop here instead of dereferencing null.
+            break;
+        }
+        if (expr->isCompileTimeConstant()) {
+            return expr;
+        }
+    }
+    // We didn't find a compile-time constant at the end. Return the expression as-is.
+    return &inExpr;
+}
+
 std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
                                                      int offset,
-                                                     const Expression& left,
+                                                     const Expression& leftExpr,
                                                      Operator op,
-                                                     const Expression& right) {
+                                                     const Expression& rightExpr) {
+    // Replace read-only refs to const variables with their compile-time-constant values.
+    const Expression* left = GetConstantValueForVariable(leftExpr);
+    const Expression* right = GetConstantValueForVariable(rightExpr);
+
     // If this is the comma operator, the left side is evaluated but not otherwise used in any way.
     // So if the left side has no side effects, it can just be eliminated entirely.
-    if (op.kind() == Token::Kind::TK_COMMA && !left.hasSideEffects()) {
-        return right.clone();
+    if (op.kind() == Token::Kind::TK_COMMA && !left->hasSideEffects()) {
+        return right->clone();
     }
 
     // If this is the assignment operator, and both sides are the same trivial expression, this is
     // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
     // This can happen when other parts of the assignment are optimized away.
-    if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSelfAssignment(left, right)) {
-        return right.clone();
+    if (op.kind() == Token::Kind::TK_EQ && Analysis::IsSelfAssignment(*left, *right)) {
+        return right->clone();
     }
 
     // Simplify the expression when both sides are constant Boolean literals.
-    if (left.is<BoolLiteral>() && right.is<BoolLiteral>()) {
-        bool leftVal  = left.as<BoolLiteral>().value();
-        bool rightVal = right.as<BoolLiteral>().value();
+    if (left->is<BoolLiteral>() && right->is<BoolLiteral>()) {
+        bool leftVal  = left->as<BoolLiteral>().value();
+        bool rightVal = right->as<BoolLiteral>().value();
         bool result;
         switch (op.kind()) {
             case Token::Kind::TK_LOGICALAND: result = leftVal && rightVal; break;
@@ -219,28 +236,28 @@
     }
 
     // If the left side is a Boolean literal, apply short-circuit optimizations.
-    if (left.is<BoolLiteral>()) {
-        return short_circuit_boolean(left, op, right);
+    if (left->is<BoolLiteral>()) {  // May come from a const bool variable.
+        return short_circuit_boolean(*left, op, *right);
     }
 
     // If the right side is a Boolean literal...
-    if (right.is<BoolLiteral>()) {
+    if (right->is<BoolLiteral>()) {
         // ... and the left side has no side effects...
-        if (!left.hasSideEffects()) {
+        if (!left->hasSideEffects()) {
             // We can reverse the expressions and short-circuit optimizations are still valid.
-            return short_circuit_boolean(right, op, left);
+            return short_circuit_boolean(*right, op, *left);
         }
 
         // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
-        return eliminate_no_op_boolean(left, op, right);
+        return eliminate_no_op_boolean(*left, op, *right);
     }
 
-    if (ErrorOnDivideByZero(context, offset, op, right)) {
+    if (ErrorOnDivideByZero(context, offset, op, *right)) {
         return nullptr;
     }
 
     // Other than the short-circuit cases above, constant folding requires both sides to be constant
-    if (!left.isCompileTimeConstant() || !right.isCompileTimeConstant()) {
+    if (!left->isCompileTimeConstant() || !right->isCompileTimeConstant()) {
         return nullptr;
     }
 
@@ -253,9 +270,9 @@
     #define URESULT(t, op) std::make_unique<t ## Literal>(context, offset,       \
                                                           (SKSL_UINT) leftVal op \
                                                           (SKSL_UINT) rightVal)
-    if (left.is<IntLiteral>() && right.is<IntLiteral>()) {
-        SKSL_INT leftVal  = left.as<IntLiteral>().value();
-        SKSL_INT rightVal = right.as<IntLiteral>().value();
+    if (left->is<IntLiteral>() && right->is<IntLiteral>()) {  // May come from const variables.
+        SKSL_INT leftVal  = left->as<IntLiteral>().value();
+        SKSL_INT rightVal = right->as<IntLiteral>().value();
         switch (op.kind()) {
             case Token::Kind::TK_PLUS:       return URESULT(Int, +);
             case Token::Kind::TK_MINUS:      return URESULT(Int, -);
@@ -302,9 +319,9 @@
     }
 
     // Perform constant folding on pairs of floating-point literals.
-    if (left.is<FloatLiteral>() && right.is<FloatLiteral>()) {
-        SKSL_FLOAT leftVal  = left.as<FloatLiteral>().value();
-        SKSL_FLOAT rightVal = right.as<FloatLiteral>().value();
+    if (left->is<FloatLiteral>() && right->is<FloatLiteral>()) {  // May come from const variables.
+        SKSL_FLOAT leftVal  = left->as<FloatLiteral>().value();
+        SKSL_FLOAT rightVal = right->as<FloatLiteral>().value();
         switch (op.kind()) {
             case Token::Kind::TK_PLUS:  return RESULT(Float, +);
             case Token::Kind::TK_MINUS: return RESULT(Float, -);
@@ -321,14 +338,14 @@
     }
 
     // Perform constant folding on pairs of vectors.
-    const Type& leftType = left.type();
-    const Type& rightType = right.type();
+    const Type& leftType = left->type();  // Operands may come from const variables.
+    const Type& rightType = right->type();
     if (leftType.isVector() && leftType == rightType) {
         if (leftType.componentType().isFloat()) {
-            return simplify_vector<SKSL_FLOAT>(context, left, op, right);
+            return simplify_vector<SKSL_FLOAT>(context, *left, op, *right);
         }
         if (leftType.componentType().isInteger()) {
-            return simplify_vector<SKSL_INT, SKSL_UINT>(context, left, op, right);
+            return simplify_vector<SKSL_INT, SKSL_UINT>(context, *left, op, *right);
         }
         return nullptr;
     }
@@ -336,11 +353,12 @@
     // Perform constant folding on vectors against scalars, e.g.: half4(2) + 2
     if (leftType.isVector() && leftType.componentType() == rightType) {
         if (rightType.isFloat()) {
-            return simplify_vector<SKSL_FLOAT>(context, left, op, splat_scalar(right, left.type()));
+            return simplify_vector<SKSL_FLOAT>(context, *left, op,
+                                               splat_scalar(*right, left->type()));
         }
         if (rightType.isInteger()) {
-            return simplify_vector<SKSL_INT, SKSL_UINT>(context, left, op,
-                                                        splat_scalar(right, left.type()));
+            return simplify_vector<SKSL_INT, SKSL_UINT>(context, *left, op,
+                                                        splat_scalar(*right, left->type()));
         }
         return nullptr;
     }
@@ -348,12 +366,12 @@
     // Perform constant folding on scalars against vectors, e.g.: 2 + half4(2)
     if (rightType.isVector() && rightType.componentType() == leftType) {
         if (leftType.isFloat()) {
-            return simplify_vector<SKSL_FLOAT>(context, splat_scalar(left, right.type()), op,
-                                               right);
+            return simplify_vector<SKSL_FLOAT>(context, splat_scalar(*left, right->type()), op,
+                                               *right);
         }
         if (leftType.isInteger()) {
-            return simplify_vector<SKSL_INT, SKSL_UINT>(context, splat_scalar(left, right.type()),
-                                                        op, right);
+            return simplify_vector<SKSL_INT, SKSL_UINT>(context, splat_scalar(*left, right->type()),
+                                                        op, *right);
         }
         return nullptr;
     }
@@ -372,7 +390,7 @@
                 return nullptr;
         }
 
-        switch (left.compareConstant(right)) {
+        switch (left->compareConstant(*right)) {
             case Expression::ComparisonResult::kNotEqual:
                 equality = !equality;
                 [[fallthrough]];