Migrate function-body finalization out of IRGenerator.

This is a first step towards replacing `finalizeFunction` with a
`FunctionDefinition::Convert` method living outside of the IRGenerator.

Previously this code would assert that we had no early returns from a
vertex-program main() method; this has been turned into an error.
(The original assertion was also tied to fRTFlip, because the *problem*
with early-returns in main is tied to the lack of RTFlip fixups, but
we fundamentally don't allow early returns, so it makes more sense to
just universally disallow it.)

Change-Id: Iba0742f7ef3cbc83995ea130fec1eb1ef2556c44
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/442691
Auto-Submit: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/gn/sksl.gni b/gn/sksl.gni
index 0c394b7..0a59ba5 100644
--- a/gn/sksl.gni
+++ b/gn/sksl.gni
@@ -145,6 +145,7 @@
   "$_src/sksl/ir/SkSLFunctionCall.h",
   "$_src/sksl/ir/SkSLFunctionDeclaration.cpp",
   "$_src/sksl/ir/SkSLFunctionDeclaration.h",
+  "$_src/sksl/ir/SkSLFunctionDefinition.cpp",
   "$_src/sksl/ir/SkSLFunctionDefinition.h",
   "$_src/sksl/ir/SkSLFunctionPrototype.h",
   "$_src/sksl/ir/SkSLFunctionReference.h",
diff --git a/gn/sksl_tests.gni b/gn/sksl_tests.gni
index 1309c0f..9f1a1c8 100644
--- a/gn/sksl_tests.gni
+++ b/gn/sksl_tests.gni
@@ -151,6 +151,7 @@
   "/sksl/errors/UnscopedVariableInWhile.sksl",
   "/sksl/errors/UsingInvalidValue.sksl",
   "/sksl/errors/VectorSlice.sksl",
+  "/sksl/errors/VertexEarlyReturn.vert",
   "/sksl/errors/WhileTypeMismatch.sksl",
 ]
 
@@ -414,7 +415,6 @@
   "/sksl/shared/VectorConstructors.sksl",
   "/sksl/shared/VectorScalarMath.sksl",
   "/sksl/shared/VectorToMatrixCast.sksl",
-  "/sksl/shared/VertexEarlyReturn.vert",
   "/sksl/shared/VertexID.vert",
   "/sksl/shared/WhileLoopControlFlow.sksl",
 ]
diff --git a/resources/sksl/shared/VertexEarlyReturn.vert b/resources/sksl/errors/VertexEarlyReturn.vert
similarity index 100%
rename from resources/sksl/shared/VertexEarlyReturn.vert
rename to resources/sksl/errors/VertexEarlyReturn.vert
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 5206457..1c8bc24 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -764,113 +764,6 @@
 std::unique_ptr<Block> IRGenerator::finalizeFunction(const FunctionDeclaration& funcDecl,
                                                      std::unique_ptr<Block> body,
                                                      IntrinsicSet* referencedIntrinsics) {
-    class Finalizer : public ProgramWriter {
-    public:
-        Finalizer(IRGenerator* irGenerator, const FunctionDeclaration* function,
-                  IntrinsicSet* referencedIntrinsics)
-            : fIRGenerator(irGenerator)
-            , fFunction(function)
-            , fReferencedIntrinsics(referencedIntrinsics) {}
-
-        ~Finalizer() override {
-            SkASSERT(!fBreakableLevel);
-            SkASSERT(!fContinuableLevel);
-        }
-
-        bool functionReturnsValue() const {
-            return !fFunction->returnType().isVoid();
-        }
-
-        bool visitExpression(Expression& expr) override {
-            if (expr.is<FunctionCall>()) {
-                const FunctionDeclaration& func = expr.as<FunctionCall>().function();
-                if (func.isBuiltin() && func.definition()) {
-                    fReferencedIntrinsics->insert(&func);
-                }
-            }
-            return INHERITED::visitExpression(expr);
-        }
-
-        bool visitStatement(Statement& stmt) override {
-            switch (stmt.kind()) {
-                case Statement::Kind::kReturn: {
-                    // early returns from a vertex main function will bypass the sk_Position
-                    // normalization, so SkASSERT that we aren't doing that. It is of course
-                    // possible to fix this by adding a normalization before each return, but it
-                    // will probably never actually be necessary.
-                    SkASSERT(fIRGenerator->programKind() != ProgramKind::kVertex ||
-                             !fIRGenerator->fRTAdjust ||
-                             !fFunction->isMain());
-
-                    // Verify that the return statement matches the function's return type.
-                    ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
-                    const Type& returnType = fFunction->returnType();
-                    if (returnStmt.expression()) {
-                        if (this->functionReturnsValue()) {
-                            // Coerce return expression to the function's return type.
-                            returnStmt.setExpression(fIRGenerator->coerce(
-                                    std::move(returnStmt.expression()), returnType));
-                        } else {
-                            // Returning something from a function with a void return type.
-                            returnStmt.setExpression(nullptr);
-                            fIRGenerator->errorReporter().error(returnStmt.fOffset,
-                                                     "may not return a value from a void function");
-                        }
-                    } else {
-                        if (this->functionReturnsValue()) {
-                            // Returning nothing from a function with a non-void return type.
-                            fIRGenerator->errorReporter().error(returnStmt.fOffset,
-                                  "expected function to return '" + returnType.displayName() + "'");
-                        }
-                    }
-                    break;
-                }
-                case Statement::Kind::kDo:
-                case Statement::Kind::kFor: {
-                    ++fBreakableLevel;
-                    ++fContinuableLevel;
-                    bool result = INHERITED::visitStatement(stmt);
-                    --fContinuableLevel;
-                    --fBreakableLevel;
-                    return result;
-                }
-                case Statement::Kind::kSwitch: {
-                    ++fBreakableLevel;
-                    bool result = INHERITED::visitStatement(stmt);
-                    --fBreakableLevel;
-                    return result;
-                }
-                case Statement::Kind::kBreak:
-                    if (!fBreakableLevel) {
-                        fIRGenerator->errorReporter().error(stmt.fOffset,
-                                                 "break statement must be inside a loop or switch");
-                    }
-                    break;
-                case Statement::Kind::kContinue:
-                    if (!fContinuableLevel) {
-                        fIRGenerator->errorReporter().error(stmt.fOffset,
-                                                        "continue statement must be inside a loop");
-                    }
-                    break;
-                default:
-                    break;
-            }
-            return INHERITED::visitStatement(stmt);
-        }
-
-    private:
-        IRGenerator* fIRGenerator;
-        const FunctionDeclaration* fFunction;
-        // which intrinsics have we encountered in this function
-        IntrinsicSet* fReferencedIntrinsics;
-        // how deeply nested we are in breakable constructs (for, do, switch).
-        int fBreakableLevel = 0;
-        // how deeply nested we are in continuable constructs (for, do).
-        int fContinuableLevel = 0;
-
-        using INHERITED = ProgramWriter;
-    };
-
     bool isMain = funcDecl.isMain();
     bool needInvocationIDWorkaround = fInvocations != -1 && isMain &&
                                       !this->caps().gsInvocationsSupport();
@@ -881,13 +774,7 @@
         body->children().push_back(this->getNormalizeSkPositionCode());
     }
 
-    Finalizer finalizer{this, &funcDecl, referencedIntrinsics};
-    finalizer.visitStatement(*body);
-
-    if (Analysis::CanExitWithoutReturningValue(funcDecl, *body)) {
-        this->errorReporter().error(funcDecl.fOffset, "function '" + funcDecl.name() +
-                                                      "' can exit without returning a value");
-    }
+    FunctionDefinition::FinalizeFunctionBody(fContext, funcDecl, body.get(), referencedIntrinsics);
     return body;
 }
 
diff --git a/src/sksl/ir/SkSLFunctionDefinition.cpp b/src/sksl/ir/SkSLFunctionDefinition.cpp
new file mode 100644
index 0000000..4fdaa28
--- /dev/null
+++ b/src/sksl/ir/SkSLFunctionDefinition.cpp
@@ -0,0 +1,137 @@
+/*
+ * Copyright 2021 Google LLC
+ *
+ * Use of this source code is governed by a BSD-style license that can be
+ * found in the LICENSE file.
+ */
+
+#include "src/sksl/SkSLAnalysis.h"
+#include "src/sksl/SkSLContext.h"
+#include "src/sksl/SkSLProgramSettings.h"
+#include "src/sksl/ir/SkSLFunctionCall.h"
+#include "src/sksl/ir/SkSLFunctionDefinition.h"
+#include "src/sksl/ir/SkSLReturnStatement.h"
+
+namespace SkSL {
+
+void FunctionDefinition::FinalizeFunctionBody(const Context& context,
+                                              const FunctionDeclaration& function,
+                                              Statement* body,
+                                              IntrinsicSet* referencedIntrinsics) {
+    class Finalizer : public ProgramWriter {
+    public:
+        Finalizer(const Context& context, const FunctionDeclaration& function,
+                  IntrinsicSet* referencedIntrinsics)
+            : fContext(context)
+            , fFunction(function)
+            , fReferencedIntrinsics(referencedIntrinsics) {}
+
+        ~Finalizer() override {
+            SkASSERT(!fBreakableLevel);
+            SkASSERT(!fContinuableLevel);
+        }
+
+        bool functionReturnsValue() const {
+            return !fFunction.returnType().isVoid();
+        }
+
+        bool visitExpression(Expression& expr) override {
+            if (expr.is<FunctionCall>()) {
+                const FunctionDeclaration& func = expr.as<FunctionCall>().function();
+                if (func.isBuiltin() && func.definition()) {
+                    fReferencedIntrinsics->insert(&func);
+                }
+            }
+            return INHERITED::visitExpression(expr);
+        }
+
+        bool visitStatement(Statement& stmt) override {
+            switch (stmt.kind()) {
+                case Statement::Kind::kReturn: {
+                    // Early returns from a vertex main() function will bypass sk_Position
+                    // normalization, so SkASSERT that we aren't doing that. If this becomes an
+                    // issue, we can add normalization before each return statement.
+                    if (fContext.fConfig->fKind == ProgramKind::kVertex && fFunction.isMain()) {
+                        fContext.fErrors->error(
+                                stmt.fOffset,
+                                "early returns from vertex programs are not supported");
+                    }
+
+                    // Verify that the return statement matches the function's return type.
+                    ReturnStatement& returnStmt = stmt.as<ReturnStatement>();
+                    if (returnStmt.expression()) {
+                        if (this->functionReturnsValue()) {
+                            // Coerce return expression to the function's return type.
+                            returnStmt.setExpression(fFunction.returnType().coerceExpression(
+                                    std::move(returnStmt.expression()), fContext));
+                        } else {
+                            // Returning something from a function with a void return type.
+                            returnStmt.setExpression(nullptr);
+                            fContext.fErrors->error(returnStmt.fOffset,
+                                                    "may not return a value from a void function");
+                        }
+                    } else {
+                        if (this->functionReturnsValue()) {
+                            // Returning nothing from a function with a non-void return type.
+                            fContext.fErrors->error(returnStmt.fOffset,
+                                                    "expected function to return '" +
+                                                    fFunction.returnType().displayName() + "'");
+                        }
+                    }
+                    break;
+                }
+                case Statement::Kind::kDo:
+                case Statement::Kind::kFor: {
+                    ++fBreakableLevel;
+                    ++fContinuableLevel;
+                    bool result = INHERITED::visitStatement(stmt);
+                    --fContinuableLevel;
+                    --fBreakableLevel;
+                    return result;
+                }
+                case Statement::Kind::kSwitch: {
+                    ++fBreakableLevel;
+                    bool result = INHERITED::visitStatement(stmt);
+                    --fBreakableLevel;
+                    return result;
+                }
+                case Statement::Kind::kBreak:
+                    if (!fBreakableLevel) {
+                        fContext.fErrors->error(stmt.fOffset,
+                                                "break statement must be inside a loop or switch");
+                    }
+                    break;
+                case Statement::Kind::kContinue:
+                    if (!fContinuableLevel) {
+                        fContext.fErrors->error(stmt.fOffset,
+                                                "continue statement must be inside a loop");
+                    }
+                    break;
+                default:
+                    break;
+            }
+            return INHERITED::visitStatement(stmt);
+        }
+
+    private:
+        const Context& fContext;
+        const FunctionDeclaration& fFunction;
+        // which intrinsics have we encountered in this function
+        IntrinsicSet* fReferencedIntrinsics;
+        // how deeply nested we are in breakable constructs (for, do, switch).
+        int fBreakableLevel = 0;
+        // how deeply nested we are in continuable constructs (for, do).
+        int fContinuableLevel = 0;
+
+        using INHERITED = ProgramWriter;
+    };
+
+    Finalizer(context, function, referencedIntrinsics).visitStatement(*body);
+
+    if (Analysis::CanExitWithoutReturningValue(function, *body)) {
+        context.fErrors->error(function.fOffset, "function '" + function.name() +
+                                                 "' can exit without returning a value");
+    }
+}
+
+}  // namespace SkSL
diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h
index c9c554d..7154093 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.h
+++ b/src/sksl/ir/SkSLFunctionDefinition.h
@@ -34,6 +34,18 @@
         , fReferencedIntrinsics(std::move(referencedIntrinsics))
         , fSource(nullptr) {}
 
+    /**
+     * Coerces `return` statements to the return type of the function, and reports errors in the
+     * function that can't be detected at the individual statement level:
+     *     - `break` and `continue` statements must be in reasonable places.
+     *     - non-void functions are required to return a value on all paths.
+     *     - vertex main() functions don't allow early returns.
+     */
+    static void FinalizeFunctionBody(const Context& context,
+                                     const FunctionDeclaration& function,
+                                     Statement* body,
+                                     IntrinsicSet* referencedIntrinsics);
+
     const FunctionDeclaration& declaration() const {
         return *fDeclaration;
     }
diff --git a/tests/sksl/errors/VertexEarlyReturn.glsl b/tests/sksl/errors/VertexEarlyReturn.glsl
new file mode 100644
index 0000000..8023605
--- /dev/null
+++ b/tests/sksl/errors/VertexEarlyReturn.glsl
@@ -0,0 +1,4 @@
+### Compilation failed:
+
+error: 5: early returns from vertex programs are not supported
+1 error
diff --git a/tests/sksl/shared/VertexEarlyReturn.asm.vert b/tests/sksl/shared/VertexEarlyReturn.asm.vert
deleted file mode 100644
index 0f6a758..0000000
--- a/tests/sksl/shared/VertexEarlyReturn.asm.vert
+++ /dev/null
@@ -1,58 +0,0 @@
-OpCapability Shader
-%1 = OpExtInstImport "GLSL.std.450"
-OpMemoryModel Logical GLSL450
-OpEntryPoint Vertex %main "main" %3
-OpName %sk_PerVertex "sk_PerVertex"
-OpMemberName %sk_PerVertex 0 "sk_Position"
-OpMemberName %sk_PerVertex 1 "sk_PointSize"
-OpName %_UniformBuffer "_UniformBuffer"
-OpMemberName %_UniformBuffer 0 "zoom"
-OpName %main "main"
-OpMemberDecorate %sk_PerVertex 0 BuiltIn Position
-OpMemberDecorate %sk_PerVertex 1 BuiltIn PointSize
-OpDecorate %sk_PerVertex Block
-OpMemberDecorate %_UniformBuffer 0 DescriptorSet 0
-OpMemberDecorate %_UniformBuffer 0 Offset 0
-OpMemberDecorate %_UniformBuffer 0 RelaxedPrecision
-OpDecorate %_UniformBuffer Block
-OpDecorate %8 Binding 0
-OpDecorate %8 DescriptorSet 0
-OpDecorate %22 RelaxedPrecision
-OpDecorate %30 RelaxedPrecision
-%float = OpTypeFloat 32
-%v4float = OpTypeVector %float 4
-%sk_PerVertex = OpTypeStruct %v4float %float
-%_ptr_Output_sk_PerVertex = OpTypePointer Output %sk_PerVertex
-%3 = OpVariable %_ptr_Output_sk_PerVertex Output
-%_UniformBuffer = OpTypeStruct %float
-%_ptr_Uniform__UniformBuffer = OpTypePointer Uniform %_UniformBuffer
-%8 = OpVariable %_ptr_Uniform__UniformBuffer Uniform
-%void = OpTypeVoid
-%12 = OpTypeFunction %void
-%float_1 = OpConstant %float 1
-%15 = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
-%int = OpTypeInt 32 1
-%int_0 = OpConstant %int 0
-%_ptr_Output_v4float = OpTypePointer Output %v4float
-%_ptr_Uniform_float = OpTypePointer Uniform %float
-%bool = OpTypeBool
-%main = OpFunction %void None %12
-%13 = OpLabel
-%18 = OpAccessChain %_ptr_Output_v4float %3 %int_0
-OpStore %18 %15
-%20 = OpAccessChain %_ptr_Uniform_float %8 %int_0
-%22 = OpLoad %float %20
-%23 = OpFOrdEqual %bool %22 %float_1
-OpSelectionMerge %26 None
-OpBranchConditional %23 %25 %26
-%25 = OpLabel
-OpReturn
-%26 = OpLabel
-%27 = OpAccessChain %_ptr_Output_v4float %3 %int_0
-%28 = OpLoad %v4float %27
-%29 = OpAccessChain %_ptr_Uniform_float %8 %int_0
-%30 = OpLoad %float %29
-%31 = OpVectorTimesScalar %v4float %28 %30
-OpStore %27 %31
-OpReturn
-OpFunctionEnd
diff --git a/tests/sksl/shared/VertexEarlyReturn.glsl b/tests/sksl/shared/VertexEarlyReturn.glsl
deleted file mode 100644
index 31cb675..0000000
--- a/tests/sksl/shared/VertexEarlyReturn.glsl
+++ /dev/null
@@ -1,7 +0,0 @@
-
-layout (set = 0) uniform float zoom;
-void main() {
-    gl_Position = vec4(1.0);
-    if (zoom == 1.0) return;
-    gl_Position *= zoom;
-}
diff --git a/tests/sksl/shared/VertexEarlyReturn.metal b/tests/sksl/shared/VertexEarlyReturn.metal
deleted file mode 100644
index edef044..0000000
--- a/tests/sksl/shared/VertexEarlyReturn.metal
+++ /dev/null
@@ -1,20 +0,0 @@
-#include <metal_stdlib>
-#include <simd/simd.h>
-using namespace metal;
-struct Uniforms {
-    float zoom;
-};
-struct Inputs {
-};
-struct Outputs {
-    float4 sk_Position [[position]];
-    float sk_PointSize [[point_size]];
-};
-vertex Outputs vertexMain(Inputs _in [[stage_in]], constant Uniforms& _uniforms [[buffer(0)]], uint sk_VertexID [[vertex_id]], uint sk_InstanceID [[instance_id]]) {
-    Outputs _out;
-    (void)_out;
-    _out.sk_Position = float4(1.0);
-    if (_uniforms.zoom == 1.0) return _out;
-    _out.sk_Position *= _uniforms.zoom;
-    return _out;
-}