Eliminate inliner temporary variables for top-level-exit functions.

When we determine that a function contains only a single return
statement, and that statement is at the top level (i.e. not inside
any nested scope), there is no need to create a temporary variable
to hold the result expression. Instead, we can directly replace the
function-call expression with the return statement's expression.

Unlike my previous solution, this does not require variable
declarations to be rewritten. The no-scope limitation makes it
slightly less effective in theory, but in practice we still get
almost all of the benefit. The no-scope limitation bites us on
structures like:

@if (true) {
    return x;
} else {
    return y;
}

This will optimize away the `@if`, but leave the scope behind:

{
    return x;
}

However, this is not a big deal; the biggest wins are single-line
helper functions like `guarded_divide` and `unpremul` which retain
the full benefit.

Change-Id: I7fbb725e65db021b9795c04c816819669815578f
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/345167
Commit-Queue: John Stiles <johnstiles@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLInliner.cpp b/src/sksl/SkSLInliner.cpp
index c63707e..2f7f621 100644
--- a/src/sksl/SkSLInliner.cpp
+++ b/src/sksl/SkSLInliner.cpp
@@ -57,32 +57,6 @@
 
 static constexpr int kInlinedStatementLimit = 2500;
 
-static bool contains_returns_above_limit(const FunctionDefinition& funcDef, int limit) {
-    class CountReturnsWithLimit : public ProgramVisitor {
-    public:
-        CountReturnsWithLimit(const FunctionDefinition& funcDef, int limit) : fLimit(limit) {
-            this->visitProgramElement(funcDef);
-        }
-
-        bool visitStatement(const Statement& stmt) override {
-            switch (stmt.kind()) {
-                case Statement::Kind::kReturn:
-                    ++fNumReturns;
-                    return (fNumReturns > fLimit) || INHERITED::visitStatement(stmt);
-
-                default:
-                    return INHERITED::visitStatement(stmt);
-            }
-        }
-
-        int fNumReturns = 0;
-        int fLimit = 0;
-        using INHERITED = ProgramVisitor;
-    };
-
-    return CountReturnsWithLimit{funcDef, limit}.fNumReturns > limit;
-}
-
 static int count_returns_at_end_of_control_flow(const FunctionDefinition& funcDef) {
     class CountReturnsAtEndOfControlFlow : public ProgramVisitor {
     public:
@@ -155,11 +129,6 @@
     return CountReturnsInBreakableConstructs{funcDef}.fNumReturns;
 }
 
-static bool has_early_return(const FunctionDefinition& funcDef) {
-    int returnsAtEndOfControlFlow = count_returns_at_end_of_control_flow(funcDef);
-    return contains_returns_above_limit(funcDef, returnsAtEndOfControlFlow);
-}
-
 static bool contains_recursive_call(const FunctionDeclaration& funcDecl) {
     class ContainsRecursiveCall : public ProgramVisitor {
     public:
@@ -244,8 +213,53 @@
     return clone;
 }
 
+class CountReturnsWithLimit : public ProgramVisitor {
+public:
+    CountReturnsWithLimit(const FunctionDefinition& funcDef, int limit) : fLimit(limit) {
+        this->visitProgramElement(funcDef);  // traversal runs during construction
+    }
+
+    bool visitStatement(const Statement& stmt) override {
+        switch (stmt.kind()) {
+            case Statement::Kind::kReturn: {
+                ++fNumReturns;  // every return counts, whatever its depth
+                fDeepestReturn = std::max(fDeepestReturn, fScopedBlockDepth);
+                return (fNumReturns >= fLimit) || INHERITED::visitStatement(stmt);
+            }
+            case Statement::Kind::kBlock: {  // only isScope() blocks add depth
+                int depthIncrement = stmt.as<Block>().isScope() ? 1 : 0;
+                fScopedBlockDepth += depthIncrement;
+                bool result = INHERITED::visitStatement(stmt);
+                fScopedBlockDepth -= depthIncrement;  // restore depth on the way out
+                return result;
+            }
+            default:
+                return INHERITED::visitStatement(stmt);
+        }
+    }
+
+    int fNumReturns = 0;  // return statements seen so far
+    int fDeepestReturn = 0;  // largest fScopedBlockDepth observed at a return
+    int fLimit = 0;  // visitStatement yields true once fNumReturns reaches this
+    int fScopedBlockDepth = 0;  // count of enclosing isScope() blocks right now
+    using INHERITED = ProgramVisitor;
+};
+
 }  // namespace
 
+Inliner::ReturnComplexity Inliner::GetReturnComplexity(const FunctionDefinition& funcDef) {
+    int returnsAtEndOfControlFlow = count_returns_at_end_of_control_flow(funcDef);
+    CountReturnsWithLimit counter{funcDef, returnsAtEndOfControlFlow + 1};  // limit = one extra
+
+    if (counter.fNumReturns > returnsAtEndOfControlFlow) {  // some return is not at an end
+        return ReturnComplexity::kEarlyReturns;
+    }
+    if (counter.fNumReturns > 1 || counter.fDeepestReturn > 1) {  // several, or below depth 1
+        return ReturnComplexity::kScopedReturns;
+    }
+    return ReturnComplexity::kSingleTopLevelReturn;  // at most one return, at depth <= 1
+}
+
 void Inliner::ensureScopedBlocks(Statement* inlinedBody, Statement* parentStmt) {
     // No changes necessary if this statement isn't actually a block.
     if (!inlinedBody || !inlinedBody->is<Block>()) {
@@ -430,14 +444,14 @@
 std::unique_ptr<Statement> Inliner::inlineStatement(int offset,
                                                     VariableRewriteMap* varMap,
                                                     SymbolTable* symbolTableForStatement,
-                                                    const Expression* resultExpr,
-                                                    bool haveEarlyReturns,
+                                                    std::unique_ptr<Expression>* resultExpr,
+                                                    ReturnComplexity returnComplexity,
                                                     const Statement& statement,
                                                     bool isBuiltinCode) {
     auto stmt = [&](const std::unique_ptr<Statement>& s) -> std::unique_ptr<Statement> {
         if (s) {
             return this->inlineStatement(offset, varMap, symbolTableForStatement, resultExpr,
-                                         haveEarlyReturns, *s, isBuiltinCode);
+                                         returnComplexity, *s, isBuiltinCode);
         }
         return nullptr;
     };
@@ -506,33 +520,52 @@
             return statement.clone();
         case Statement::Kind::kReturn: {
             const ReturnStatement& r = statement.as<ReturnStatement>();
-            if (r.expression()) {
-                SkASSERT(resultExpr);
-                auto assignment =
-                        std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
-                                offset,
-                                clone_with_ref_kind(*resultExpr,
-                                                    VariableReference::RefKind::kWrite),
-                                Token::Kind::TK_EQ,
-                                expr(r.expression()),
-                                &resultExpr->type()));
-                if (haveEarlyReturns) {
-                    StatementArray block;
-                    block.reserve_back(2);
-                    block.push_back(std::move(assignment));
-                    block.push_back(std::make_unique<ContinueStatement>(offset));
-                    return std::make_unique<Block>(offset, std::move(block),
-                                                   /*symbols=*/nullptr, /*isScope=*/true);
-                } else {
-                    return std::move(assignment);
-                }
-            } else {
-                if (haveEarlyReturns) {
+            if (!r.expression()) {
+                if (returnComplexity >= ReturnComplexity::kEarlyReturns) {
+                    // This function doesn't return a value, but has early returns, so we've wrapped
+                    // it in a for loop. Use a continue to jump to the end of the loop and "leave"
+                    // the function.
                     return std::make_unique<ContinueStatement>(offset);
                 } else {
+                    // This function doesn't exit early or return a value. A return statement at the
+                    // end is a no-op and can be treated as such.
                     return std::make_unique<Nop>();
                 }
             }
+
+            // For a function that only contains a single top-level return, we don't need to store
+            // the result in a variable at all. Just move the return value right into the result
+            // expression.
+            SkASSERT(resultExpr);
+            SkASSERT(*resultExpr);
+            if (returnComplexity <= ReturnComplexity::kSingleTopLevelReturn) {
+                *resultExpr = expr(r.expression());
+                return std::make_unique<Nop>();
+            }
+
+            // For more complex functions, assign their result into a variable.
+            auto assignment =
+                    std::make_unique<ExpressionStatement>(std::make_unique<BinaryExpression>(
+                            offset,
+                            clone_with_ref_kind(**resultExpr, VariableReference::RefKind::kWrite),
+                            Token::Kind::TK_EQ,
+                            expr(r.expression()),
+                            &resultExpr->get()->type()));
+
+            // Early returns are wrapped in a for loop; we need to synthesize a continue statement
+            // to "leave" the function.
+            if (returnComplexity >= ReturnComplexity::kEarlyReturns) {
+                StatementArray block;
+                block.reserve_back(2);
+                block.push_back(std::move(assignment));
+                block.push_back(std::make_unique<ContinueStatement>(offset));
+                return std::make_unique<Block>(offset, std::move(block), /*symbols=*/nullptr,
+                                               /*isScope=*/true);
+            }
+            // Functions without early returns aren't wrapped in a for loop and don't need to worry
+            // about breaking out of the control flow.
+            return std::move(assignment);
+
         }
         case Statement::Kind::kSwitch: {
             const SwitchStatement& ss = statement.as<SwitchStatement>();
@@ -642,7 +675,8 @@
     ExpressionArray& arguments = call->arguments();
     const int offset = call->fOffset;
     const FunctionDefinition& function = *call->function().definition();
-    const bool hasEarlyReturn = has_early_return(function);
+    const ReturnComplexity returnComplexity = GetReturnComplexity(function);
+    bool hasEarlyReturn = (returnComplexity >= ReturnComplexity::kEarlyReturns);
 
     InlinedCall inlinedCall;
     inlinedCall.fInlinedBody = std::make_unique<Block>(offset, StatementArray{},
@@ -751,7 +785,7 @@
     inlineStatements->reserve_back(body.children().size() + argsToCopyBack.size());
     for (const std::unique_ptr<Statement>& stmt : body.children()) {
         inlineStatements->push_back(this->inlineStatement(offset, &varMap, symbolTable.get(),
-                                                          resultExpr.get(), hasEarlyReturn, *stmt,
+                                                          &resultExpr, returnComplexity, *stmt,
                                                           caller->isBuiltin()));
     }
 
@@ -770,8 +804,6 @@
 
     if (resultExpr != nullptr) {
         // Return our result variable as our replacement expression.
-        SkASSERT(resultExpr->as<VariableReference>().refKind() ==
-                 VariableReference::RefKind::kRead);
         inlinedCall.fReplacementExpr = std::move(resultExpr);
     } else {
         // It's a void function, so it doesn't actually result in anything, but we have to return
@@ -805,9 +837,6 @@
     // We don't have any mechanism to simulate early returns within a breakable construct
     // (switch/for/do/while), so we can't inline if there's a return inside one.
     bool hasReturnInBreakableConstruct = (count_returns_in_breakable_constructs(*functionDef) > 0);
-
-    // If we detected returns in breakable constructs, we should also detect an early return.
-    SkASSERT(!hasReturnInBreakableConstruct || has_early_return(*functionDef));
     return !hasReturnInBreakableConstruct;
 }