Reorganize how SkSL includes are parsed and stored

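This binds together the IntrinsicMap and SymbolTable for each include
into a single entity (ParsedModule), with helper functions that create
and return them. Uses a little macro trickery (MODULE_DATA) to move all
of the standalone/runtime logic into loadModule, which drastically
reduces boilerplate. As a side effect, call counts are now reset for
every module (not just gpu/vert/frag), and the rehydrated FP module now
releases its modifiers like the others.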

Change-Id: Ic70c0d67967cc614daeab5c50412ab69dcdf2fea
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/324124
Commit-Queue: Brian Osman <brianosman@google.com>
Reviewed-by: John Stiles <johnstiles@google.com>
diff --git a/src/sksl/SkSLCompiler.cpp b/src/sksl/SkSLCompiler.cpp
index 7f48645..4b9f11b 100644
--- a/src/sksl/SkSLCompiler.cpp
+++ b/src/sksl/SkSLCompiler.cpp
@@ -46,8 +46,15 @@
 #include "spirv-tools/libspirv.hpp"
 #endif
 
-#if !SKSL_STANDALONE
+#if defined(SKSL_STANDALONE)
 
+// In standalone mode, we load the textual sksl source files. GN generates or copies these files
+// to the skslc executable directory. The "data" is just the filename, e.g. "sksl_gpu.sksl".
+#define MODULE_DATA(name) MakeModulePath("sksl_" #name ".sksl")
+
+#else
+
+// At runtime, we load the dehydrated sksl data files. The data is a (pointer, size) pair.
 #include "src/sksl/generated/sksl_fp.dehydrated.sksl"
 #include "src/sksl/generated/sksl_frag.dehydrated.sksl"
 #include "src/sksl/generated/sksl_geom.dehydrated.sksl"
@@ -56,66 +63,13 @@
 #include "src/sksl/generated/sksl_pipeline.dehydrated.sksl"
 #include "src/sksl/generated/sksl_vert.dehydrated.sksl"
 
-#else
-
-// GN generates or copies all of these files to the skslc executable directory
-static const char SKSL_GPU_INCLUDE[]      = "sksl_gpu.sksl";
-static const char SKSL_INTERP_INCLUDE[]   = "sksl_interp.sksl";
-static const char SKSL_VERT_INCLUDE[]     = "sksl_vert.sksl";
-static const char SKSL_FRAG_INCLUDE[]     = "sksl_frag.sksl";
-static const char SKSL_GEOM_INCLUDE[]     = "sksl_geom.sksl";
-static const char SKSL_FP_INCLUDE[]       = "sksl_fp.sksl";
-static const char SKSL_PIPELINE_INCLUDE[] = "sksl_pipeline.sksl";
+#define MODULE_DATA(name) MakeModuleData(SKSL_INCLUDE_sksl_##name,\
+                                         SKSL_INCLUDE_sksl_##name##_LENGTH)
 
 #endif
 
 namespace SkSL {
 
-static void grab_intrinsics(std::vector<std::unique_ptr<ProgramElement>>* src,
-                            IRIntrinsicMap* target) {
-    for (std::unique_ptr<ProgramElement>& element : *src) {
-        switch (element->kind()) {
-            case ProgramElement::Kind::kFunction: {
-                const FunctionDefinition& f = element->as<FunctionDefinition>();
-                SkASSERT(f.fDeclaration.isBuiltin());
-                target->insertOrDie(f.fDeclaration.description(), std::move(element));
-                break;
-            }
-            case ProgramElement::Kind::kEnum: {
-                const Enum& e = element->as<Enum>();
-                SkASSERT(e.isBuiltin());
-                target->insertOrDie(e.typeName(), std::move(element));
-                break;
-            }
-            case ProgramElement::Kind::kGlobalVar: {
-                const Variable* var = element->as<GlobalVarDeclaration>().fDecl->fVar;
-                SkASSERT(var->isBuiltin());
-                target->insertOrDie(var->name(), std::move(element));
-                break;
-            }
-            case ProgramElement::Kind::kInterfaceBlock: {
-                const Variable* var = element->as<InterfaceBlock>().fVariable;
-                SkASSERT(var->isBuiltin());
-                target->insertOrDie(var->name(), std::move(element));
-                break;
-            }
-            default:
-                printf("Unsupported element: %s\n", element->description().c_str());
-                SkASSERT(false);
-                break;
-        }
-    }
-}
-
-static void reset_call_counts(std::vector<std::unique_ptr<ProgramElement>>* src) {
-    for (std::unique_ptr<ProgramElement>& element : *src) {
-        if (element->is<FunctionDefinition>()) {
-            const FunctionDeclaration& fnDecl = element->as<FunctionDefinition>().fDeclaration;
-            fnDecl.callCount() = 0;
-        }
-    }
-}
-
 Compiler::Compiler(Flags flags)
 : fFlags(flags)
 , fContext(std::make_shared<Context>())
@@ -256,186 +210,138 @@
                                                      skCapsName, fContext->fSkCaps_Type.get(),
                                                      /*builtin=*/false, Variable::kGlobal_Storage));
 
-    fIRGenerator->fIntrinsics = nullptr;
-    std::vector<std::unique_ptr<ProgramElement>> gpuElements;
-    std::vector<std::unique_ptr<ProgramElement>> vertElements;
-    std::vector<std::unique_ptr<ProgramElement>> fragElements;
-#if SKSL_STANDALONE
-    this->processIncludeFile(Program::kFragment_Kind, SKSL_GPU_INCLUDE, fRootSymbolTable,
-                             &gpuElements, &fGpuSymbolTable);
-    this->processIncludeFile(Program::kVertex_Kind, SKSL_VERT_INCLUDE, fGpuSymbolTable,
-                             &vertElements, &fVertexSymbolTable);
-    this->processIncludeFile(Program::kFragment_Kind, SKSL_FRAG_INCLUDE, fGpuSymbolTable,
-                             &fragElements, &fFragmentSymbolTable);
-#else
-    {
-        Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                              fRootSymbolTable, this, SKSL_INCLUDE_sksl_gpu,
-                              SKSL_INCLUDE_sksl_gpu_LENGTH);
-        fGpuSymbolTable = rehydrator.symbolTable();
-        gpuElements = rehydrator.elements();
-        fModifiers.push_back(fIRGenerator->releaseModifiers());
-    }
-    {
-        Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                              fGpuSymbolTable, this, SKSL_INCLUDE_sksl_vert,
-                              SKSL_INCLUDE_sksl_vert_LENGTH);
-        fVertexSymbolTable = rehydrator.symbolTable();
-        vertElements = rehydrator.elements();
-        fModifiers.push_back(fIRGenerator->releaseModifiers());
-    }
-    {
-        Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                              fGpuSymbolTable, this, SKSL_INCLUDE_sksl_frag,
-                              SKSL_INCLUDE_sksl_frag_LENGTH);
-        fFragmentSymbolTable = rehydrator.symbolTable();
-        fragElements = rehydrator.elements();
-        fModifiers.push_back(fIRGenerator->releaseModifiers());
-    }
-#endif
-    // Call counts are used to track dead-stripping and inlinability within the program being
-    // currently compiled, and always should start at zero for a new program. Zero out any call
-    // counts that were registered during the assembly of the intrinsics/include data. (If we
-    // actually use calls from inside the intrinsics, we will clone them into the program and they
-    // will get new call counts.)
-    reset_call_counts(&gpuElements);
-    reset_call_counts(&vertElements);
-    reset_call_counts(&fragElements);
+    fRootModule = {fRootSymbolTable, /*fIntrinsics=*/nullptr};
 
-    fGPUIntrinsics = std::make_unique<IRIntrinsicMap>(/*parent=*/nullptr);
-    grab_intrinsics(&gpuElements, fGPUIntrinsics.get());
-
-    fVertexIntrinsics = std::make_unique<IRIntrinsicMap>(fGPUIntrinsics.get());
-    grab_intrinsics(&vertElements, fVertexIntrinsics.get());
-
-    fFragmentIntrinsics = std::make_unique<IRIntrinsicMap>(fGPUIntrinsics.get());
-    grab_intrinsics(&fragElements, fFragmentIntrinsics.get());
+    fGPUModule = this->parseModule(Program::kFragment_Kind, MODULE_DATA(gpu), fRootModule);
+    fVertexModule = this->parseModule(Program::kVertex_Kind, MODULE_DATA(vert), fGPUModule);
+    fFragmentModule = this->parseModule(Program::kFragment_Kind, MODULE_DATA(frag), fGPUModule);
 }
 
 Compiler::~Compiler() {}
 
 void Compiler::loadGeometryIntrinsics() {
-    if (fGeometrySymbolTable) {
-        return;
+    if (!fGeometryModule.fSymbols) {  // Parsed lazily, on first use only.
+        fGeometryModule = this->parseModule(Program::kGeometry_Kind, MODULE_DATA(geom), fGPUModule);
     }
-    fGeometryIntrinsics = std::make_unique<IRIntrinsicMap>(fGPUIntrinsics.get());
-    std::vector<std::unique_ptr<ProgramElement>> geomElements;
-    #if !SKSL_STANDALONE
-        {
-            Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                                  fGpuSymbolTable, this, SKSL_INCLUDE_sksl_geom,
-                                  SKSL_INCLUDE_sksl_geom_LENGTH);
-            fGeometrySymbolTable = rehydrator.symbolTable();
-            geomElements = rehydrator.elements();
-            fModifiers.push_back(fIRGenerator->releaseModifiers());
-        }
-    #else
-        this->processIncludeFile(Program::kGeometry_Kind, SKSL_GEOM_INCLUDE, fGpuSymbolTable,
-                                 &geomElements, &fGeometrySymbolTable);
-    #endif
-    grab_intrinsics(&geomElements, fGeometryIntrinsics.get());
 }
 
 void Compiler::loadFPIntrinsics() {
-    if (fFPSymbolTable) {
-        return;
+    if (!fFPModule.fSymbols) {  // Parsed lazily, on first use only.
+        fFPModule =
+                this->parseModule(Program::kFragmentProcessor_Kind, MODULE_DATA(fp), fGPUModule);
     }
-    fFPIntrinsics = std::make_unique<IRIntrinsicMap>(fGPUIntrinsics.get());
-    std::vector<std::unique_ptr<ProgramElement>> fpElements;
-    #if !SKSL_STANDALONE
-        {
-            Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                                  fGpuSymbolTable, this, SKSL_INCLUDE_sksl_fp,
-                                  SKSL_INCLUDE_sksl_fp_LENGTH);
-            fFPSymbolTable = rehydrator.symbolTable();
-            fpElements = rehydrator.elements();
-        }
-    #else
-        this->processIncludeFile(Program::kFragmentProcessor_Kind, SKSL_FP_INCLUDE, fGpuSymbolTable,
-                                 &fpElements, &fFPSymbolTable);
-    #endif
-    grab_intrinsics(&fpElements, fFPIntrinsics.get());
 }
 
 void Compiler::loadPipelineIntrinsics() {
-    if (fPipelineSymbolTable) {
-        return;
+    if (!fPipelineModule.fSymbols) {  // Parsed lazily, on first use only.
+        fPipelineModule =
+                this->parseModule(Program::kPipelineStage_Kind, MODULE_DATA(pipeline), fGPUModule);
     }
-    fPipelineIntrinsics = std::make_unique<IRIntrinsicMap>(fGPUIntrinsics.get());
-    std::vector<std::unique_ptr<ProgramElement>> pipelineIntrinics;
-    #if !SKSL_STANDALONE
-        {
-            Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                                  fGpuSymbolTable, this, SKSL_INCLUDE_sksl_pipeline,
-                                  SKSL_INCLUDE_sksl_pipeline_LENGTH);
-            fPipelineSymbolTable = rehydrator.symbolTable();
-            pipelineIntrinics = rehydrator.elements();
-            fModifiers.push_back(fIRGenerator->releaseModifiers());
-        }
-    #else
-        this->processIncludeFile(Program::kPipelineStage_Kind, SKSL_PIPELINE_INCLUDE,
-                                 fGpuSymbolTable, &pipelineIntrinics, &fPipelineSymbolTable);
-    #endif
-    grab_intrinsics(&pipelineIntrinics, fPipelineIntrinsics.get());
 }
 
 void Compiler::loadInterpreterIntrinsics() {
-    if (fInterpreterSymbolTable) {
-        return;
+    if (!fInterpreterModule.fSymbols) {  // Lazy; parented to fRootModule, not fGPUModule.
+        fInterpreterModule =
+                this->parseModule(Program::kGeneric_Kind, MODULE_DATA(interp), fRootModule);
     }
-    fInterpreterIntrinsics = std::make_unique<IRIntrinsicMap>(/*parent=*/nullptr);
-    std::vector<std::unique_ptr<ProgramElement>> interpElements;
-    #if !SKSL_STANDALONE
-        {
-            Rehydrator rehydrator(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(),
-                                  fRootSymbolTable, this, SKSL_INCLUDE_sksl_interp,
-                                  SKSL_INCLUDE_sksl_interp_LENGTH);
-            fInterpreterSymbolTable = rehydrator.symbolTable();
-            interpElements = rehydrator.elements();
-            fModifiers.push_back(fIRGenerator->releaseModifiers());
-        }
-    #else
-        this->processIncludeFile(Program::kGeneric_Kind, SKSL_INTERP_INCLUDE,
-                                 fIRGenerator->fSymbolTable, &interpElements,
-                                 &fInterpreterSymbolTable);
-    #endif
-    grab_intrinsics(&interpElements, fInterpreterIntrinsics.get());
 }
 
-void Compiler::processIncludeFile(Program::Kind kind, const char* path,
-                                  std::shared_ptr<SymbolTable> base,
-                                  std::vector<std::unique_ptr<ProgramElement>>* outElements,
-                                  std::shared_ptr<SymbolTable>* outSymbolTable) {
-    std::ifstream in(path);
+LoadedModule Compiler::loadModule(Program::Kind kind,
+                                  ModuleData data,
+                                  std::shared_ptr<SymbolTable> base) {
+    LoadedModule module;  // Filled by parsing text (standalone) or by rehydration (runtime).
+    if (!base) {  // Default to the root symbol table.
+        base = fRootSymbolTable;
+    }
+
+#if defined(SKSL_STANDALONE)
+    SkASSERT(data.fPath);
+    std::ifstream in(data.fPath);
     std::unique_ptr<String> text = std::make_unique<String>(std::istreambuf_iterator<char>(in),
                                                             std::istreambuf_iterator<char>());
     if (in.rdstate()) {
-        printf("error reading %s\n", path);
+        printf("error reading %s\n", data.fPath);
         abort();
     }
     const String* source = fRootSymbolTable->takeOwnershipOfString(std::move(text));
     fSource = source;
     Program::Settings settings;
-#if !defined(SKSL_STANDALONE) & SK_SUPPORT_GPU
-    GrContextOptions opts;
-    GrShaderCaps caps(opts);
-    settings.fCaps = &caps;
-#endif
     SkASSERT(fIRGenerator->fCanInline);
     fIRGenerator->fCanInline = false;
-    fIRGenerator->start(&settings, base ? base : fRootSymbolTable, true);
-    fIRGenerator->convertProgram(kind, source->c_str(), source->length(), outElements);
+    fIRGenerator->start(&settings, {base, /*fIntrinsics=*/nullptr}, /*builtin=*/true);
+    fIRGenerator->convertProgram(kind, source->c_str(), source->length(), &module.fElements);
     fIRGenerator->fCanInline = true;
     if (this->fErrorCount) {
         printf("Unexpected errors: %s\n", this->fErrorText.c_str());
-        SkDEBUGFAILF("%s %s\n", path, this->fErrorText.c_str());
+        SkDEBUGFAILF("%s %s\n", data.fPath, this->fErrorText.c_str());
     }
-    *outSymbolTable = fIRGenerator->fSymbolTable;
-#ifdef SK_DEBUG
+    module.fSymbols = fIRGenerator->fSymbolTable;
     fSource = nullptr;
-#endif
     fModifiers.push_back(fIRGenerator->releaseModifiers());
     fIRGenerator->finish();
+#else
+    SkASSERT(data.fData && (data.fSize != 0));
+    Rehydrator rehydrator(fContext.get(), fIRGenerator->fModifiers.get(), base, this,
+                          data.fData, data.fSize);
+    module = { rehydrator.symbolTable(), rehydrator.elements() };
+    fModifiers.push_back(fIRGenerator->releaseModifiers());
+#endif
+
+    return module;
+}
+
+ParsedModule Compiler::parseModule(Program::Kind kind, ModuleData data, const ParsedModule& base) {
+    auto [symbols, elements] = this->loadModule(kind, data, base.fSymbols);
+
+    // For modules that just declare (but don't define) intrinsic functions, there will be no new
+    // program elements. In that case, we can share our parent's intrinsic map:
+    if (elements.empty()) {
+        return {symbols, base.fIntrinsics};
+    }
+
+    auto intrinsics = std::make_shared<IRIntrinsicMap>(base.fIntrinsics.get());
+
+    // Transfer every program element into a new intrinsic map, chained to our parent's map.
+    // Keys: a function's declaration description, an enum's type name, or a global's name.
+    for (std::unique_ptr<ProgramElement>& element : elements) {
+        switch (element->kind()) {
+            case ProgramElement::Kind::kFunction: {
+                const FunctionDefinition& f = element->as<FunctionDefinition>();
+                SkASSERT(f.fDeclaration.isBuiltin());
+                // Call counts are used to track dead-stripping and inlinability within the program
+                // being compiled, and should start at zero for a new program. Zero out any call
+                // counts internal to the include data. (If we actually use calls from inside the
+                // intrinsics, we will clone them into the program with new call counts.)
+                f.fDeclaration.callCount() = 0;
+                intrinsics->insertOrDie(f.fDeclaration.description(), std::move(element));
+                break;
+            }
+            case ProgramElement::Kind::kEnum: {
+                const Enum& e = element->as<Enum>();
+                SkASSERT(e.isBuiltin());
+                intrinsics->insertOrDie(e.typeName(), std::move(element));
+                break;
+            }
+            case ProgramElement::Kind::kGlobalVar: {
+                const Variable* var = element->as<GlobalVarDeclaration>().fDecl->fVar;
+                SkASSERT(var->isBuiltin());
+                intrinsics->insertOrDie(var->name(), std::move(element));
+                break;
+            }
+            case ProgramElement::Kind::kInterfaceBlock: {
+                const Variable* var = element->as<InterfaceBlock>().fVariable;
+                SkASSERT(var->isBuiltin());
+                intrinsics->insertOrDie(var->name(), std::move(element));
+                break;
+            }
+            default:
+                printf("Unsupported element: %s\n", element->description().c_str());
+                SkASSERT(false);  // Modules may only contain the element kinds above.
+                break;
+        }
+    }
+
+    return {symbols, std::move(intrinsics)};
 }
 
 // add the definition created by assigning to the lvalue to the definition set
@@ -1590,36 +1496,30 @@
 
     fErrorText = "";
     fErrorCount = 0;
-    fInliner.reset(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(), &settings);
+    fInliner.reset(fContext.get(), fIRGenerator->fModifiers.get(), &settings);
     std::vector<std::unique_ptr<ProgramElement>> elements;
     switch (kind) {
         case Program::kVertex_Kind:
-            fIRGenerator->fIntrinsics = fVertexIntrinsics.get();
-            fIRGenerator->start(&settings, fVertexSymbolTable);
+            fIRGenerator->start(&settings, fVertexModule);
             break;
         case Program::kFragment_Kind:
-            fIRGenerator->fIntrinsics = fFragmentIntrinsics.get();
-            fIRGenerator->start(&settings, fFragmentSymbolTable);
+            fIRGenerator->start(&settings, fFragmentModule);
             break;
         case Program::kGeometry_Kind:
             this->loadGeometryIntrinsics();
-            fIRGenerator->fIntrinsics = fGeometryIntrinsics.get();
-            fIRGenerator->start(&settings, fGeometrySymbolTable);
+            fIRGenerator->start(&settings, fGeometryModule);
             break;
         case Program::kFragmentProcessor_Kind:
             this->loadFPIntrinsics();
-            fIRGenerator->fIntrinsics = fFPIntrinsics.get();
-            fIRGenerator->start(&settings, fFPSymbolTable);
+            fIRGenerator->start(&settings, fFPModule);
             break;
         case Program::kPipelineStage_Kind:
             this->loadPipelineIntrinsics();
-            fIRGenerator->fIntrinsics = fPipelineIntrinsics.get();
-            fIRGenerator->start(&settings, fPipelineSymbolTable);
+            fIRGenerator->start(&settings, fPipelineModule);
             break;
         case Program::kGeneric_Kind:
             this->loadInterpreterIntrinsics();
-            fIRGenerator->fIntrinsics = fInterpreterIntrinsics.get();
-            fIRGenerator->start(&settings, fInterpreterSymbolTable);
+            fIRGenerator->start(&settings, fInterpreterModule);
             break;
     }
     if (externalValues) {
@@ -1717,8 +1617,7 @@
 #ifdef SK_ENABLE_SPIRV_VALIDATION
     StringStream buffer;
     fSource = program.fSource.get();
-    SPIRVCodeGenerator cg(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(), &program, this,
-                          &buffer);
+    SPIRVCodeGenerator cg(fContext.get(), fIRGenerator->fModifiers.get(), &program, this, &buffer);
     bool result = cg.generateCode();
     fSource = nullptr;
     if (result) {
@@ -1736,8 +1635,7 @@
     }
 #else
     fSource = program.fSource.get();
-    SPIRVCodeGenerator cg(&fIRGenerator->fContext, fIRGenerator->fModifiers.get(), &program, this,
-                          &out);
+    SPIRVCodeGenerator cg(fContext.get(), fIRGenerator->fModifiers.get(), &program, this, &out);
     bool result = cg.generateCode();
     fSource = nullptr;
 #endif