Migrate program finalization logic out of IRGenerator.

Most of the logic in IRGenerator::finish has moved to
Compiler::finalize. The @if/@switch pass has been combined with the pass
that verifies no dangling FunctionReference/TypeReference expressions,
saving one walk through the IR tree. Most program-finalization logic now
exists in Compiler and Analysis.

This change reorders our error generation logic slightly, and manages to
squeeze a few extra (valid) errors out of one of our fuzzer-generated
tests, but is not really intended to affect results in any significant
way.

Change-Id: I461de7c31f3980dedf74424e7826c032b1f40fd2
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/444757
Commit-Queue: John Stiles <johnstiles@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
Auto-Submit: John Stiles <johnstiles@google.com>
Reviewed-by: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLAnalysis.cpp b/src/sksl/SkSLAnalysis.cpp
index ad755f6..c5db6ad 100644
--- a/src/sksl/SkSLAnalysis.cpp
+++ b/src/sksl/SkSLAnalysis.cpp
@@ -1283,54 +1283,70 @@
     return !visitor.fFoundReturn;
 }
 
-void Analysis::VerifyStaticTests(const Program& program) {
-    class StaticTestVerifier : public ProgramVisitor {
+void Analysis::VerifyStaticTestsAndExpressions(const Program& program) {
+    class TestsAndExpressions : public ProgramVisitor {
     public:
-        StaticTestVerifier(ErrorReporter* r) : fReporter(r) {}
+        TestsAndExpressions(const Context& ctx) : fContext(ctx) {}
 
         using ProgramVisitor::visitProgramElement;
 
         bool visitStatement(const Statement& stmt) override {
-            switch (stmt.kind()) {
-                case Statement::Kind::kIf:
-                    if (stmt.as<IfStatement>().isStatic()) {
-                        fReporter->error(stmt.fOffset, "static if has non-static test");
-                    }
-                    break;
+            if (!fContext.fConfig->fSettings.fPermitInvalidStaticTests) {
+                switch (stmt.kind()) {
+                    case Statement::Kind::kIf:
+                        if (stmt.as<IfStatement>().isStatic()) {
+                            fContext.fErrors->error(stmt.fOffset, "static if has non-static test");
+                        }
+                        break;
 
-                case Statement::Kind::kSwitch:
-                    if (stmt.as<SwitchStatement>().isStatic()) {
-                        fReporter->error(stmt.fOffset, "static switch has non-static test");
-                    }
-                    break;
+                    case Statement::Kind::kSwitch:
+                        if (stmt.as<SwitchStatement>().isStatic()) {
+                            fContext.fErrors->error(stmt.fOffset,
+                                                    "static switch has non-static test");
+                        }
+                        break;
 
-                default:
-                    break;
+                    default:
+                        break;
+                }
             }
             return INHERITED::visitStatement(stmt);
         }
 
-        bool visitExpression(const Expression&) override {
-            // We aren't looking for anything inside an Expression, so skip them entirely.
-            return false;
+        bool visitExpression(const Expression& expr) override {
+            switch (expr.kind()) {
+                case Expression::Kind::kFunctionCall: {
+                    const FunctionDeclaration& decl = expr.as<FunctionCall>().function();
+                    if (!decl.isBuiltin() && !decl.definition()) {
+                        fContext.fErrors->error(expr.fOffset, "function '" + decl.description() +
+                                                              "' is not defined");
+                    }
+                    break;
+                }
+                case Expression::Kind::kExternalFunctionReference:
+                case Expression::Kind::kFunctionReference:
+                case Expression::Kind::kTypeReference:
+                    SkDEBUGFAIL("invalid reference-expr, should have been reported by coerce()");
+                    fContext.fErrors->error(expr.fOffset, "invalid expression");
+                    break;
+                default:
+                    if (expr.type() == *fContext.fTypes.fInvalid) {
+                        fContext.fErrors->error(expr.fOffset, "invalid expression");
+                    }
+                    break;
+            }
+            return INHERITED::visitExpression(expr);
         }
 
     private:
         using INHERITED = ProgramVisitor;
-        ErrorReporter* fReporter;
+        const Context& fContext;
     };
 
-    // If invalid static tests are permitted, we don't need to check anything.
-    if (program.fContext->fConfig->fSettings.fPermitInvalidStaticTests) {
-        return;
-    }
-
     // Check all of the program's owned elements. (Built-in elements are assumed to be valid.)
-    StaticTestVerifier visitor{program.fContext->fErrors};
+    TestsAndExpressions visitor{*program.fContext};
     for (const std::unique_ptr<ProgramElement>& element : program.ownedElements()) {
-        if (element->is<FunctionDefinition>()) {
-            visitor.visitProgramElement(*element);
-        }
+        visitor.visitProgramElement(*element);
     }
 }
 
diff --git a/src/sksl/SkSLAnalysis.h b/src/sksl/SkSLAnalysis.h
index 35e2e30..30c115c 100644
--- a/src/sksl/SkSLAnalysis.h
+++ b/src/sksl/SkSLAnalysis.h
@@ -158,9 +158,9 @@
     static bool CanExitWithoutReturningValue(const FunctionDeclaration& funcDecl,
                                              const Statement& body);
 
-    // Reports leftover @if and @switch statements in a program as errors. These should have been
-    // optimized away during compilation, as their tests should be constant-evaluatable.
-    static void VerifyStaticTests(const Program& program);
+    // Searches for @if/@switch statements that didn't optimize away, or dangling
+    // FunctionReference or TypeReference expressions, and reports them as errors.
+    static void VerifyStaticTestsAndExpressions(const Program& program);
 };
 
 /**
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index dbb7d5c..f514d11 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -473,7 +473,7 @@
                                              ir.fInputs);
     this->errorReporter().reportPendingErrors(PositionInfo());
     bool success = false;
-    if (this->errorCount()) {
+    if (!this->finalize(*program)) {
         // Do not return programs that failed to compile.
     } else if (!this->optimize(*program)) {
         // Do not return programs that failed to optimize.
@@ -804,8 +804,25 @@
         this->removeDeadGlobalVariables(program, usage);
     }
 
-    if (this->errorCount() == 0) {
-        Analysis::VerifyStaticTests(program);
+    return this->errorCount() == 0;
+}
+
+bool Compiler::finalize(Program& program) {
+    // Do a pass looking for @if/@switch statements that didn't optimize away, or dangling
+    // FunctionReference or TypeReference expressions. Report these as errors.
+    Analysis::VerifyStaticTestsAndExpressions(program);
+
+    // If we're in ES2 mode (runtime effects), do a pass to enforce Appendix A, Section 5 of the
+    // GLSL ES 1.00 spec -- Indexing. Don't bother if we've already found errors - this logic
+    // assumes that all loops meet the criteria of Section 4, and if they don't, could crash.
+    if (fContext->fConfig->strictES2Mode() && this->errorCount() == 0) {
+        for (const auto& pe : program.ownedElements()) {
+            Analysis::ValidateIndexingForES2(*pe, this->errorReporter());
+        }
+    }
+
+    if (fContext->fConfig->strictES2Mode()) {
+        Analysis::DetectStaticRecursion(SkMakeSpan(program.ownedElements()), this->errorReporter());
     }
 
     return this->errorCount() == 0;
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index dbe1621..107926b 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -228,6 +228,9 @@
     /** Optimize every function in the program. */
     bool optimize(Program& program);
 
+    /** Performs final checks to confirm that a fully-assembled/optimized is valid. */
+    bool finalize(Program& program);
+
     /** Optimize the module. */
     bool optimize(LoadedModule& module);
 
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 42819d3..a41e6a7 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -1226,30 +1226,6 @@
     return PostfixExpression::Convert(fContext, std::move(base), expression.getOperator());
 }
 
-void IRGenerator::checkValid(const Expression& expr) {
-    switch (expr.kind()) {
-        case Expression::Kind::kFunctionCall: {
-            const FunctionDeclaration& decl = expr.as<FunctionCall>().function();
-            if (!decl.isBuiltin() && !decl.definition()) {
-                this->errorReporter().error(expr.fOffset,
-                                            "function '" + decl.description() + "' is not defined");
-            }
-            break;
-        }
-        case Expression::Kind::kExternalFunctionReference:
-        case Expression::Kind::kFunctionReference:
-        case Expression::Kind::kTypeReference:
-            SkDEBUGFAIL("invalid reference-expression, should have been reported by coerce()");
-            this->errorReporter().error(expr.fOffset, "invalid expression");
-            break;
-        default:
-            if (expr.type() == *fContext.fTypes.fInvalid) {
-                this->errorReporter().error(expr.fOffset, "invalid expression");
-            }
-            break;
-    }
-}
-
 void IRGenerator::findAndDeclareBuiltinVariables() {
     class BuiltinVariableScanner : public ProgramVisitor {
     public:
@@ -1394,37 +1370,6 @@
         this->findAndDeclareBuiltinVariables();
     }
 
-    // Do a pass looking for dangling FunctionReference or TypeReference expressions
-    class FindIllegalExpressions : public ProgramVisitor {
-    public:
-        FindIllegalExpressions(IRGenerator* generator) : fGenerator(generator) {}
-
-        bool visitExpression(const Expression& e) override {
-            fGenerator->checkValid(e);
-            return INHERITED::visitExpression(e);
-        }
-
-        IRGenerator* fGenerator;
-        using INHERITED = ProgramVisitor;
-        using INHERITED::visitProgramElement;
-    };
-    for (const auto& pe : *fProgramElements) {
-        FindIllegalExpressions{this}.visitProgramElement(*pe);
-    }
-
-    // If we're in ES2 mode (runtime effects), do a pass to enforce Appendix A, Section 5 of the
-    // GLSL ES 1.00 spec -- Indexing. Don't bother if we've already found errors - this logic
-    // assumes that all loops meet the criteria of Section 4, and if they don't, could crash.
-    if (this->strictES2Mode() && this->errorReporter().errorCount() == 0) {
-        for (const auto& pe : *fProgramElements) {
-            Analysis::ValidateIndexingForES2(*pe, this->errorReporter());
-        }
-    }
-
-    if (this->strictES2Mode()) {
-        Analysis::DetectStaticRecursion(SkMakeSpan(*fProgramElements), this->errorReporter());
-    }
-
     return IRBundle{std::move(*fProgramElements),
                     std::move(*fSharedElements),
                     std::move(fSymbolTable),
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 0a23d40..10e1d0f 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -229,7 +229,6 @@
     /** Appends sk_Position fixup to the bottom of main() if this is a vertex program. */
     void appendRTAdjustFixupToVertexMain(const FunctionDeclaration& decl, Block* body);
 
-    void checkValid(const Expression& expr);
     bool setRefKind(Expression& expr, VariableReference::RefKind kind);
     void copyIntrinsicIfNeeded(const FunctionDeclaration& function);
     void findAndDeclareBuiltinVariables();
diff --git a/src/sksl/dsl/DSLCore.cpp b/src/sksl/dsl/DSLCore.cpp
index 0e33e73..199fc32 100644
--- a/src/sksl/dsl/DSLCore.cpp
+++ b/src/sksl/dsl/DSLCore.cpp
@@ -77,16 +77,17 @@
                                                       std::move(instance.fPool),
                                                       bundle.fInputs);
         bool success = false;
-        if (DSLWriter::Context().fErrors->errorCount()) {
-            DSLWriter::ReportErrors(PositionInfo());
+        if (!DSLWriter::Compiler().finalize(*result)) {
             // Do not return programs that failed to compile.
         } else if (!DSLWriter::Compiler().optimize(*result)) {
-            DSLWriter::ReportErrors(PositionInfo());
             // Do not return programs that failed to optimize.
         } else {
             // We have a successful program!
             success = true;
         }
+        if (!success) {
+            DSLWriter::ReportErrors(PositionInfo());
+        }
         if (pool) {
             pool->detachFromThread();
         }
diff --git a/tests/sksl/runtime_errors/Ossfuzz36655.skvm b/tests/sksl/runtime_errors/Ossfuzz36655.skvm
index af3e00d..fba9a1a 100644
--- a/tests/sksl/runtime_errors/Ossfuzz36655.skvm
+++ b/tests/sksl/runtime_errors/Ossfuzz36655.skvm
@@ -1,6 +1,11 @@
 ### Compilation failed:
 
+error: 20: static if has non-static test
+error: 24: static if has non-static test
+error: 31: static if has non-static test
+error: 38: static if has non-static test
+error: 44: static if has non-static test
 error: 13: potential recursion (function call cycle) not allowed:
 	void X()
 	void X()
-1 error
+6 errors