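SkSL now supports functions defined in sksl_gpu.inc

Function definitions and enums in the GPU include file are now collected
as intrinsics and copied into a program on demand, the first time they
are called or referenced.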

Change-Id: Ib29f41f6e71b176fec1ead26259ad1945a41e634
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/254677
Reviewed-by: Brian Osman <brianosman@google.com>
Commit-Queue: Ethan Nicholas <ethannicholas@google.com>
diff --git a/src/sksl/SkSLByteCodeGenerator.cpp b/src/sksl/SkSLByteCodeGenerator.cpp
index 560a3a7..e7c214b 100644
--- a/src/sksl/SkSLByteCodeGenerator.cpp
+++ b/src/sksl/SkSLByteCodeGenerator.cpp
@@ -1025,8 +1025,16 @@
 }
 
 void ByteCodeGenerator::writeFunctionCall(const FunctionCall& f) {
-    // Builtins have simple signatures...
-    if (f.fFunction.fBuiltin) {
+    // Find the index of the function we're calling. We explicitly do not allow calls to functions
+    // before they're defined. This is an easy-to-understand rule that prevents recursion.
+    int idx = -1;
+    for (size_t i = 0; i < fFunctions.size(); ++i) {
+        if (f.fFunction.matches(fFunctions[i]->fDeclaration)) {
+            idx = i;
+            break;
+        }
+    }
+    if (idx == -1) {
         for (const auto& arg : f.fArguments) {
             this->writeExpression(*arg);
         }
@@ -1034,18 +1042,11 @@
         return;
     }
 
-    // Find the index of the function we're calling. We explicitly do not allow calls to functions
-    // before they're defined. This is an easy-to-understand rule that prevents recursion.
-    size_t idx;
-    for (idx = 0; idx < fFunctions.size(); ++idx) {
-        if (f.fFunction.matches(fFunctions[idx]->fDeclaration)) {
-            break;
-        }
-    }
+
     if (idx > 255) {
         fErrors.error(f.fOffset, "Function count limit exceeded");
         return;
-    } else if (idx >= fFunctions.size()) {
+    } else if (idx >= (int) fFunctions.size()) {
         fErrors.error(f.fOffset, "Call to undefined function");
         return;
     }
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 7aa6b9f..bb8ac60 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -67,6 +67,32 @@
 
 namespace SkSL {
 
+// Moves the function definitions and enums out of 'src' (the IR of an include file) into
+// 'target', keyed by function name or enum type name, each flagged 'false' (not yet included
+// in the program being compiled). Moved entries leave null pointers behind in 'src'.
+// Entries are keyed by name alone, so overloads of one function cannot coexist: a duplicate
+// name trips the assert and, in builds without asserts, replaces the earlier entry.
+static void grab_intrinsics(std::vector<std::unique_ptr<ProgramElement>>* src,
+               std::map<StringFragment, std::pair<std::unique_ptr<ProgramElement>, bool>>* target) {
+    for (std::unique_ptr<ProgramElement>& element : *src) {
+        // Pick the lookup key; anything other than a function or an enum is not expected here.
+        StringFragment key;
+        if (element->fKind == ProgramElement::kFunction_Kind) {
+            key = ((FunctionDefinition&) *element).fDeclaration.fName;
+        } else if (element->fKind == ProgramElement::kEnum_Kind) {
+            key = ((Enum&) *element).fTypeName;
+        } else {
+            printf("unsupported include file element\n");
+            SkASSERT(false);
+            continue;
+        }
+        SkASSERT(target->find(key) == target->end());
+        (*target)[key] = std::make_pair(std::move(element), false);
+    }
+}
+
+
 Compiler::Compiler(Flags flags)
 : fFlags(flags)
 , fContext(new Context())
@@ -223,9 +249,14 @@
                                     *fContext->fSkArgs_Type, Variable::kGlobal_Storage);
     fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));
 
-    std::vector<std::unique_ptr<ProgramElement>> ignored;
+    fIRGenerator->fIntrinsics = &fGPUIntrinsics;
+    std::vector<std::unique_ptr<ProgramElement>> gpuIntrinsics;
     this->processIncludeFile(Program::kFragment_Kind, SKSL_GPU_INCLUDE, strlen(SKSL_GPU_INCLUDE),
-                             symbols, &ignored, &fGpuSymbolTable);
+                             symbols, &gpuIntrinsics, &fGpuSymbolTable);
+    grab_intrinsics(&gpuIntrinsics, &fGPUIntrinsics);
+    // need to hang on to the source so that FunctionDefinition.fSource pointers in this file
+    // remain valid
+    fGpuIncludeSource = std::move(fIRGenerator->fFile);
     this->processIncludeFile(Program::kVertex_Kind, SKSL_VERT_INCLUDE, strlen(SKSL_VERT_INCLUDE),
                              fGpuSymbolTable, &fVertexInclude, &fVertexSymbolTable);
     this->processIncludeFile(Program::kFragment_Kind, SKSL_FRAG_INCLUDE, strlen(SKSL_FRAG_INCLUDE),
@@ -235,9 +266,11 @@
     this->processIncludeFile(Program::kPipelineStage_Kind, SKSL_PIPELINE_INCLUDE,
                              strlen(SKSL_PIPELINE_INCLUDE), fGpuSymbolTable, &fPipelineInclude,
                              &fPipelineSymbolTable);
+    std::vector<std::unique_ptr<ProgramElement>> interpIntrinsics;
     this->processIncludeFile(Program::kGeneric_Kind, SKSL_INTERP_INCLUDE,
                              strlen(SKSL_INTERP_INCLUDE), symbols, &fInterpreterInclude,
                              &fInterpreterSymbolTable);
+    grab_intrinsics(&interpIntrinsics, &fInterpreterIntrinsics);
 }
 
 Compiler::~Compiler() {
@@ -1290,7 +1323,8 @@
     // check for missing return
     if (f.fDeclaration.fReturnType != *fContext->fVoid_Type) {
         if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
-            this->error(f.fOffset, String("function can exit without returning a value"));
+            this->error(f.fOffset, String("function '" + String(f.fDeclaration.fName) +
+                                          "' can exit without returning a value"));
         }
     }
 }
@@ -1313,22 +1347,26 @@
         case Program::kVertex_Kind:
             inherited = &fVertexInclude;
             fIRGenerator->fSymbolTable = fVertexSymbolTable;
+            fIRGenerator->fIntrinsics = &fGPUIntrinsics;
             fIRGenerator->start(&settings, inherited);
             break;
         case Program::kFragment_Kind:
             inherited = &fFragmentInclude;
             fIRGenerator->fSymbolTable = fFragmentSymbolTable;
+            fIRGenerator->fIntrinsics = &fGPUIntrinsics;
             fIRGenerator->start(&settings, inherited);
             break;
         case Program::kGeometry_Kind:
             inherited = &fGeometryInclude;
             fIRGenerator->fSymbolTable = fGeometrySymbolTable;
+            fIRGenerator->fIntrinsics = &fGPUIntrinsics;
             fIRGenerator->start(&settings, inherited);
             break;
         case Program::kFragmentProcessor_Kind:
             inherited = nullptr;
             fIRGenerator->fSymbolTable = fGpuSymbolTable;
             fIRGenerator->start(&settings, nullptr);
+            fIRGenerator->fIntrinsics = &fGPUIntrinsics;
             fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
                                          &elements);
             fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
@@ -1336,11 +1374,13 @@
         case Program::kPipelineStage_Kind:
             inherited = &fPipelineInclude;
             fIRGenerator->fSymbolTable = fPipelineSymbolTable;
+            fIRGenerator->fIntrinsics = &fGPUIntrinsics;
             fIRGenerator->start(&settings, inherited);
             break;
         case Program::kGeneric_Kind:
             inherited = &fInterpreterInclude;
             fIRGenerator->fSymbolTable = fInterpreterSymbolTable;
+            fIRGenerator->fIntrinsics = &fInterpreterIntrinsics;
             fIRGenerator->start(&settings, inherited);
             break;
     }
diff --git a/src/sksl/SkSLCompiler.h b/src/sksl/SkSLCompiler.h
index f2dd756..a4b564e 100644
--- a/src/sksl/SkSLCompiler.h
+++ b/src/sksl/SkSLCompiler.h
@@ -8,9 +8,11 @@
 #ifndef SKSL_COMPILER
 #define SKSL_COMPILER
 
+#include <map>
 #include <set>
 #include <unordered_set>
 #include <vector>
+#include "src/sksl/SkSLASTFile.h"
 #include "src/sksl/SkSLCFGGenerator.h"
 #include "src/sksl/SkSLContext.h"
 #include "src/sksl/SkSLErrorReporter.h"
@@ -214,6 +216,9 @@
 
     Position position(int offset);
 
+    std::map<StringFragment, std::pair<std::unique_ptr<ProgramElement>, bool>> fGPUIntrinsics;
+    std::map<StringFragment, std::pair<std::unique_ptr<ProgramElement>, bool>> fInterpreterIntrinsics;
+    std::unique_ptr<ASTFile> fGpuIncludeSource;
     std::shared_ptr<SymbolTable> fGpuSymbolTable;
     std::vector<std::unique_ptr<ProgramElement>> fVertexInclude;
     std::shared_ptr<SymbolTable> fVertexSymbolTable;
diff --git a/src/sksl/SkSLGLSLCodeGenerator.cpp b/src/sksl/SkSLGLSLCodeGenerator.cpp
index ffa85bd..c555b8a 100644
--- a/src/sksl/SkSLGLSLCodeGenerator.cpp
+++ b/src/sksl/SkSLGLSLCodeGenerator.cpp
@@ -153,6 +153,8 @@
             }
             break;
         }
+        case Type::kEnum_Kind:
+            return "int";
         default:
             return type.name();
     }
diff --git a/src/sksl/SkSLIRGenerator.cpp b/src/sksl/SkSLIRGenerator.cpp
index 2e78b99..5157235 100644
--- a/src/sksl/SkSLIRGenerator.cpp
+++ b/src/sksl/SkSLIRGenerator.cpp
@@ -169,6 +169,10 @@
             }
         }
     }
+    SkASSERT(fIntrinsics);
+    for (auto& pair : *fIntrinsics) {
+        pair.second.second = false;
+    }
 }
 
 std::unique_ptr<Extension> IRGenerator::convertExtension(int offset, StringFragment name) {
@@ -833,7 +837,7 @@
                             return;
                         }
                     }
-                    if (other->fDefined) {
+                    if (other->fDefined && !other->fBuiltin) {
                         fErrors.error(f.fOffset, "duplicate definition of " +
                                                  other->description());
                     }
@@ -893,8 +897,10 @@
         if (Program::kVertex_Kind == fKind && fd.fName == "main" && fRTAdjust) {
             body->fStatements.insert(body->fStatements.end(), this->getNormalizeSkPositionCode());
         }
-        fProgramElements->push_back(std::unique_ptr<FunctionDefinition>(
-                                        new FunctionDefinition(f.fOffset, *decl, std::move(body))));
+        std::unique_ptr<FunctionDefinition> result(new FunctionDefinition(f.fOffset, *decl,
+                                                                          std::move(body)));
+        result->fSource = &f;
+        fProgramElements->push_back(std::move(result));
     }
 }
 
@@ -1748,6 +1754,16 @@
 std::unique_ptr<Expression> IRGenerator::call(int offset,
                                               const FunctionDeclaration& function,
                                               std::vector<std::unique_ptr<Expression>> arguments) {
+    if (function.fBuiltin) {
+        auto found = fIntrinsics->find(function.fName);
+        if (found != fIntrinsics->end() && !found->second.second) {
+            found->second.second = true;
+            const FunctionDeclaration* old = fCurrentFunction;
+            fCurrentFunction = nullptr;
+            this->convertFunction(*((FunctionDefinition&) *found->second.first).fSource);
+            fCurrentFunction = old;
+        }
+    }
     if (function.fParameters.size() != arguments.size()) {
         String msg = "call to '" + function.fName + "' expected " +
                                  to_string((uint64_t) function.fParameters.size()) +
@@ -2246,10 +2262,23 @@
             fSymbolTable = ((Enum&) *e).fSymbols;
             result = convertIdentifier(ASTNode(&fFile->fNodes, offset, ASTNode::Kind::kIdentifier,
                                                field));
+            SkASSERT(result->fKind == Expression::kVariableReference_Kind);
+            const Variable& v = ((VariableReference&) *result).fVariable;
+            SkASSERT(v.fInitialValue);
+            SkASSERT(v.fInitialValue->fKind == Expression::kIntLiteral_Kind);
+            result.reset(new IntLiteral(offset, ((IntLiteral&) *v.fInitialValue).fValue, &type));
             fSymbolTable = old;
+            break;
         }
     }
     if (!result) {
+        auto found = fIntrinsics->find(type.fName);
+        if (found != fIntrinsics->end()) {
+            SkASSERT(!found->second.second);
+            found->second.second = true;
+            fProgramElements->push_back(found->second.first->clone());
+            return this->convertTypeField(offset, type, field);
+        }
         fErrors.error(offset, "type '" + type.fName + "' does not have a field named '" + field +
                               "'");
     }
diff --git a/src/sksl/SkSLIRGenerator.h b/src/sksl/SkSLIRGenerator.h
index 6ca6248..a088444 100644
--- a/src/sksl/SkSLIRGenerator.h
+++ b/src/sksl/SkSLIRGenerator.h
@@ -8,6 +8,8 @@
 #ifndef SKSL_IRGENERATOR
 #define SKSL_IRGENERATOR
 
+#include <map>
+
 #include "src/sksl/SkSLASTFile.h"
 #include "src/sksl/SkSLASTNode.h"
 #include "src/sksl/SkSLErrorReporter.h"
@@ -155,6 +157,9 @@
     std::unordered_map<String, Program::Settings::Value> fCapsMap;
     std::shared_ptr<SymbolTable> fRootSymbolTable;
     std::shared_ptr<SymbolTable> fSymbolTable;
+    // Symbols which have definitions in the include files. The bool tells us whether this
+    // intrinsic has been included already.
+    std::map<StringFragment, std::pair<std::unique_ptr<ProgramElement>, bool>>* fIntrinsics = nullptr;
     // holds extra temp variable declarations needed for the current function
     std::vector<std::unique_ptr<Statement>> fExtraVars;
     int fLoopLevel;
diff --git a/src/sksl/ir/SkSLFunctionDefinition.h b/src/sksl/ir/SkSLFunctionDefinition.h
index 7344373..511a0f8 100644
--- a/src/sksl/ir/SkSLFunctionDefinition.h
+++ b/src/sksl/ir/SkSLFunctionDefinition.h
@@ -14,6 +14,8 @@
 
 namespace SkSL {
 
+struct ASTNode;
+
 /**
  * A function definition (a declaration plus an associated block of code).
  */
@@ -35,6 +37,11 @@
 
     const FunctionDeclaration& fDeclaration;
     std::unique_ptr<Statement> fBody;
+    // This pointer may be null, and even when non-null is not guaranteed to remain valid for the
+    // entire lifespan of this object. The parse tree's lifespan is normally controlled by
+    // IRGenerator, so the IRGenerator being destroyed or being used to compile another file will
+    // invalidate this pointer.
+    const ASTNode* fSource = nullptr;
 
     typedef ProgramElement INHERITED;
 };
diff --git a/src/sksl/ir/SkSLIntLiteral.h b/src/sksl/ir/SkSLIntLiteral.h
index a95875c..fa45978 100644
--- a/src/sksl/ir/SkSLIntLiteral.h
+++ b/src/sksl/ir/SkSLIntLiteral.h
@@ -45,7 +45,8 @@
     }
 
     int coercionCost(const Type& target) const override {
-        if (target.isSigned() || target.isUnsigned() || target.isFloat()) {
+        if (target.isSigned() || target.isUnsigned() || target.isFloat() ||
+            target.kind() == Type::kEnum_Kind) {
             return 0;
         }
         return INHERITED::coercionCost(target);
diff --git a/src/sksl/ir/SkSLSymbolTable.cpp b/src/sksl/ir/SkSLSymbolTable.cpp
index 08bd6c2..ed2cb4d 100644
--- a/src/sksl/ir/SkSLSymbolTable.cpp
+++ b/src/sksl/ir/SkSLSymbolTable.cpp
@@ -110,9 +110,7 @@
     for (const auto& pair : fSymbols) {
         switch (pair.second->fKind) {
             case Symbol::kFunctionDeclaration_Kind:
-                if (!((FunctionDeclaration&)*pair.second).fDefined) {
-                    ((FunctionDeclaration&)*pair.second).fBuiltin = true;
-                }
+                ((FunctionDeclaration&)*pair.second).fBuiltin = true;
                 break;
             case Symbol::kUnresolvedFunction_Kind:
                 for (auto& f : ((UnresolvedFunction&) *pair.second).fFunctions) {